MCPcopy
hub / github.com/huggingface/transformers / find_tested_models

Function find_tested_models

utils/check_repo.py:776–868  ·  view source on GitHub ↗

Parse the content of `test_file` to detect what's in `all_model_classes`. This detects the models that inherit from the common test class. Args: `test_file` (`str`): the path to the test file to check. Returns: `Set[str]`: the set of models tested in that file.

(test_file: str)

Source from the content-addressed store, hash-verified

774# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class
775# for the all_model_classes variable.
776def find_tested_models(test_file: str) -> set[str]:
777 """
778 Parse the content of test_file to detect what's in `all_model_classes`. This detects the models that inherit from
779 the common test class.
780
781 Args:
782 test_file (`str`): The path to the test file to check
783
784 Returns:
785 `Set[str]`: The set of models tested in that file.
786 """
787 # TODO Matt: Some of the regexes here are ugly / hacky, and we can probably parse the content better.
788 # Also we should be clear about exactly what rules we're enforcing and which classes
789 # are actually mandatory.
790 with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f:
791 content = f.read()
792
793 model_tested = set()
794
795 all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
796 # Check with one less parenthesis as well
797 all_models += re.findall(r"all_model_classes\s+=\s+\(([^\)]*)\)", content)
798 if len(all_models) > 0:
799 for entry in all_models:
800 for line in entry.split(","):
801 name = line.strip()
802 if len(name) > 0:
803 model_tested.add(name)
804
805 # Models that inherit from `CausalLMModelTester` don't need to set `all_model_classes` -- it is built from other
806 # attributes by default.
807 if "CausalLMModelTester" in content:
808 base_model_class = re.findall(r"base_model_class\s+=.*", content) # Required attribute
809 base_class = base_model_class[0].split("=")[1].strip()
810 model_tested.add(base_class)
811
812 model_name = base_class.replace("Model", "")
813 # Optional attributes: if not set explicitly, the tester will attempt to infer and use the corresponding class
814 for test_class_type in [
815 "causal_lm_class",
816 "sequence_classification_class",
817 "question_answering_class",
818 "token_classification_class",
819 ]:
820 tested_class = re.findall(rf"{test_class_type}\s+=.*", content)
821 if tested_class:
822 tested_class = tested_class[0].split("=")[1].strip()
823 else:
824 tested_class = model_name + _COMMON_MODEL_NAMES_MAP[test_class_type]
825 model_tested.add(tested_class)
826 # Same as above, but for VLMModelTester. We scope the search to the VLMModelTester subclass body, as some
827 # files may contain both a CausalLMModelTester and a VLMModelTester (e.g. gemma3).
828 vlm_class_match = re.search(r"class \w+\(VLMModelTester\)", content)
829 if vlm_class_match is not None:
830 vlm_content = content[vlm_class_match.start() :]
831 base_model_class = re.findall(r"base_model_class\s+=.*", vlm_content) # Required attribute
832 base_class = base_model_class[0].split("=")[1].strip()
833 model_tested.add(base_class)

Callers 1

check_models_are_tested · Function · 0.85

Calls 4

join · Method · 0.80
split · Method · 0.80
add · Method · 0.45
start · Method · 0.45

Tested by

no test coverage detected