Create all models for a certain model type. Args: config_class (`PreTrainedConfig`): A subclass of `PreTrainedConfig` that is used to determine `models_to_create`. models_to_create (`dict`): A dictionary containing the processor/model classes that we want to create instances of. These models are all of the same model type, which is the one associated with `config_class`.
(config_class, models_to_create, output_dir, keep_model=False)
| 1332 | |
| 1333 | |
| 1334 | def build(config_class, models_to_create, output_dir, keep_model=False): |
| 1335 | """Create all models for a certain model type. |
| 1336 | |
| 1337 | Args: |
| 1338 | config_class (`PreTrainedConfig`): |
| 1339 | A subclass of `PreTrainedConfig` that is used to determine `models_to_create`. |
| 1340 | models_to_create (`dict`): |
| 1341 | A dictionary containing the processor/model classes that we want to create the instances. These models are |
| 1342 | of the same model type which is associated to `config_class`. |
| 1343 | output_dir (`str`): |
| 1344 | The directory to save all the checkpoints. Each model architecture will be saved in a subdirectory under |
| 1345 | it. |
| 1346 | """ |
| 1347 | if data["training_ds"] is None or data["testing_ds"] is None: |
| 1348 | ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1") |
| 1349 | data["training_ds"] = ds["train"] |
| 1350 | data["testing_ds"] = ds["test"] |
| 1351 | |
| 1352 | if config_class.model_type in [ |
| 1353 | "encoder-decoder", |
| 1354 | "vision-encoder-decoder", |
| 1355 | "speech-encoder-decoder", |
| 1356 | "vision-text-dual-encoder", |
| 1357 | ]: |
| 1358 | return build_composite_models(config_class, output_dir) |
| 1359 | |
| 1360 | result = {k: {} for k in models_to_create} |
| 1361 | |
| 1362 | # These will be removed at the end if they are empty |
| 1363 | result["error"] = None |
| 1364 | result["warnings"] = [] |
| 1365 | |
| 1366 | # Build processors |
| 1367 | processor_classes = models_to_create["processor"] |
| 1368 | |
| 1369 | # AutoTokenizer can't load from hub repo ... |
| 1370 | if config_class.__name__ in ["FastSpeech2ConformerWithHifiGanConfig"]: |
| 1371 | processor_classes = (FastSpeech2ConformerTokenizer,) + processor_classes |
| 1372 | |
| 1373 | if len(processor_classes) == 0: |
| 1374 | error = f"No processor class could be found in {config_class.__name__}." |
| 1375 | fill_result_with_error(result, error, None, models_to_create) |
| 1376 | logger.error(result["error"][0]) |
| 1377 | processor_names = [p.__name__ if not isinstance(p, str) else p for p in result["processor"]] |
| 1378 | result["processor"] = {p: p for p in processor_names} |
| 1379 | |
| 1380 | return result |
| 1381 | |
| 1382 | traces = [] |
| 1383 | errors = [] |
| 1384 | for processor_class in processor_classes: |
| 1385 | try: |
| 1386 | processor = build_processor(config_class, processor_class, allow_no_checkpoint=True) |
| 1387 | if processor is not None: |
| 1388 | if type(processor) not in result["processor"]: |
| 1389 | result["processor"][type(processor)] = processor |
| 1390 | except Exception: |
| 1391 | error = f"Failed to build processor for {processor_class.__name__}." |
no test coverage detected