(
model,
valid_dataloader,
post_process_class,
eval_class,
model_type=None,
extra_input=False,
scaler=None,
amp_level="O2",
amp_custom_black_list=[],
amp_custom_white_list=[],
amp_dtype="float16",
)
| 707 | |
| 708 | |
| 709 | def eval( |
| 710 | model, |
| 711 | valid_dataloader, |
| 712 | post_process_class, |
| 713 | eval_class, |
| 714 | model_type=None, |
| 715 | extra_input=False, |
| 716 | scaler=None, |
| 717 | amp_level="O2", |
| 718 | amp_custom_black_list=[], |
| 719 | amp_custom_white_list=[], |
| 720 | amp_dtype="float16", |
| 721 | ): |
| 722 | model.eval() |
| 723 | with paddle.no_grad(): |
| 724 | total_frame = 0.0 |
| 725 | total_time = 0.0 |
| 726 | pbar = tqdm( |
| 727 | total=len(valid_dataloader), desc="eval model:", position=0, leave=True |
| 728 | ) |
| 729 | max_iter = ( |
| 730 | len(valid_dataloader) - 1 |
| 731 | if platform.system() == "Windows" |
| 732 | else len(valid_dataloader) |
| 733 | ) |
| 734 | sum_images = 0 |
| 735 | for idx, batch in enumerate(valid_dataloader): |
| 736 | if idx >= max_iter: |
| 737 | break |
| 738 | images = batch[0] |
| 739 | start = time.time() |
| 740 | |
| 741 | # use amp |
| 742 | if scaler: |
| 743 | with paddle.amp.auto_cast( |
| 744 | level=amp_level, |
| 745 | custom_black_list=amp_custom_black_list, |
| 746 | dtype=amp_dtype, |
| 747 | ): |
| 748 | if model_type == "table" or extra_input: |
| 749 | preds = model(images, data=batch[1:]) |
| 750 | elif model_type in ["kie"]: |
| 751 | preds = model(batch) |
| 752 | elif model_type in ["can"]: |
| 753 | preds = model(batch[:3]) |
| 754 | elif model_type in ["latexocr"]: |
| 755 | preds = model(batch) |
| 756 | elif model_type in ["sr"]: |
| 757 | preds = model(batch) |
| 758 | sr_img = preds["sr_img"] |
| 759 | lr_img = preds["lr_img"] |
| 760 | else: |
| 761 | preds = model(images) |
| 762 | preds = to_float32(preds) |
| 763 | else: |
| 764 | if model_type == "table" or extra_input: |
| 765 | preds = model(images, data=batch[1:]) |
| 766 | elif model_type in ["kie"]: |
no test coverage detected
searching dependent graphs…