Parse the content of `test_file` to detect what's in `all_model_classes`. This detects the models that inherit from the common test class. Args: test_file (`str`): The path to the test file to check; it is joined onto `PATH_TO_TESTS` before being opened. Returns: `set[str]`: The set of models tested in that file.
(test_file: str)
| 774 | # This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class |
| 775 | # for the all_model_classes variable. |
| 776 | def find_tested_models(test_file: str) -> set[str]: |
| 777 | """ |
| 778 | Parse the content of test_file to detect what's in `all_model_classes`. This detects the models that inherit from |
| 779 | the common test class. |
| 780 | |
| 781 | Args: |
| 782 | test_file (`str`): The path to the test file to check |
| 783 | |
| 784 | Returns: |
| 785 | `Set[str]`: The set of models tested in that file. |
| 786 | """ |
| 787 | # TODO Matt: Some of the regexes here are ugly / hacky, and we can probably parse the content better. |
| 788 | # Also we should be clear about exactly what rules we're enforcing and which classes |
| 789 | # are actually mandatory. |
| 790 | with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f: |
| 791 | content = f.read() |
| 792 | |
| 793 | model_tested = set() |
| 794 | |
| 795 | all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content) |
| 796 | # Check with one less parenthesis as well |
| 797 | all_models += re.findall(r"all_model_classes\s+=\s+\(([^\)]*)\)", content) |
| 798 | if len(all_models) > 0: |
| 799 | for entry in all_models: |
| 800 | for line in entry.split(","): |
| 801 | name = line.strip() |
| 802 | if len(name) > 0: |
| 803 | model_tested.add(name) |
| 804 | |
| 805 | # Models that inherit from `CausalLMModelTester` don't need to set `all_model_classes` -- it is built from other |
| 806 | # attributes by default. |
| 807 | if "CausalLMModelTester" in content: |
| 808 | base_model_class = re.findall(r"base_model_class\s+=.*", content) # Required attribute |
| 809 | base_class = base_model_class[0].split("=")[1].strip() |
| 810 | model_tested.add(base_class) |
| 811 | |
| 812 | model_name = base_class.replace("Model", "") |
| 813 | # Optional attributes: if not set explicitly, the tester will attempt to infer and use the corresponding class |
| 814 | for test_class_type in [ |
| 815 | "causal_lm_class", |
| 816 | "sequence_classification_class", |
| 817 | "question_answering_class", |
| 818 | "token_classification_class", |
| 819 | ]: |
| 820 | tested_class = re.findall(rf"{test_class_type}\s+=.*", content) |
| 821 | if tested_class: |
| 822 | tested_class = tested_class[0].split("=")[1].strip() |
| 823 | else: |
| 824 | tested_class = model_name + _COMMON_MODEL_NAMES_MAP[test_class_type] |
| 825 | model_tested.add(tested_class) |
| 826 | # Same as above, but for VLMModelTester. We scope the search to the VLMModelTester subclass body, as some |
| 827 | # files may contain both a CausalLMModelTester and a VLMModelTester (e.g. gemma3). |
| 828 | vlm_class_match = re.search(r"class \w+\(VLMModelTester\)", content) |
| 829 | if vlm_class_match is not None: |
| 830 | vlm_content = content[vlm_class_match.start() :] |
| 831 | base_model_class = re.findall(r"base_model_class\s+=.*", vlm_content) # Required attribute |
| 832 | base_class = base_model_class[0].split("=")[1].strip() |
| 833 | model_tested.add(base_class) |
no test coverage detected