hub / github.com/huggingface/transformers / main

Function main()

examples/pytorch/object-detection/run_object_detection.py:338–516


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_process_index}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # ------------------------------------------------------------------------------------------------
    # Load dataset, prepare splits
    # ------------------------------------------------------------------------------------------------

    dataset = load_dataset(
        data_args.dataset_name, cache_dir=model_args.cache_dir, trust_remote_code=model_args.trust_remote_code
    )

    # If we don't have a validation split, split off a percentage of train as validation
    data_args.train_val_split = None if "validation" in dataset else data_args.train_val_split
    if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
        split = dataset["train"].train_test_split(data_args.train_val_split, seed=training_args.seed)
        dataset["train"] = split["train"]
        dataset["validation"] = split["test"]

    # Get dataset categories and prepare mappings for label_name <-> label_id
    if isinstance(dataset["train"].features["objects"], dict):
        categories = dataset["train"].features["objects"]["category"].feature.names
    else:  # (for old versions of `datasets` that used Sequence({...}) of the objects)
        categories = dataset["train"].features["objects"].feature["category"].names
    id2label = dict(enumerate(categories))
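
The two data-preparation decisions in the listing (whether to carve a validation split out of train, and how category names map to integer ids) can be tried in isolation with plain Python. The sketch below is illustrative only and does not use `datasets`. The helper names `resolve_val_split` and `build_label_maps`, the example category names, and the inverse `label2id` mapping are assumptions; none of them appear in the excerpt shown above.

```python
from typing import Dict, List, Optional, Set, Tuple


def resolve_val_split(split_names: Set[str], train_val_split: Optional[float]) -> Optional[float]:
    """Return the fraction of train to hold out, or None when no split should happen."""
    # An existing "validation" split takes precedence over the requested fraction.
    if "validation" in split_names:
        return None
    # Same guard as the listing: only a strictly positive float triggers a split.
    if isinstance(train_val_split, float) and train_val_split > 0.0:
        return train_val_split
    return None


def build_label_maps(categories: List[str]) -> Tuple[Dict[int, str], Dict[str, int]]:
    """Build id -> name and name -> id mappings from an ordered list of category names."""
    id2label = dict(enumerate(categories))
    # Assumed inverse mapping; the excerpt above stops before any such line.
    label2id = {name: idx for idx, name in id2label.items()}
    return id2label, label2id


# Hypothetical category names, for illustration only.
id2label, label2id = build_label_maps(["cat", "dog", "bird"])
print(id2label)
print(label2id)
print(resolve_val_split({"train"}, 0.15))
print(resolve_val_split({"train", "validation"}, 0.15))
```

Because `dict(enumerate(...))` assigns ids by position, the mapping is only stable while the order of the category list is stable, which is presumably why the script reads the names from the dataset's feature metadata rather than from the examples themselves.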

Callers 1

Calls 15

parse_json_file · Method · 0.95
train · Method · 0.95
save_model · Method · 0.95
evaluate · Method · 0.95
push_to_hub · Method · 0.95
create_model_card · Method · 0.95
HfArgumentParser · Class · 0.90
Trainer · Class · 0.90
get_process_log_level · Method · 0.80
setLevel · Method · 0.80
warning · Method · 0.80

Tested by

no test coverage detected