()
| 336 | |
| 337 | |
| 338 | def main(): |
| 339 | # See all possible arguments in src/transformers/training_args.py |
| 340 | # or by passing the --help flag to this script. |
| 341 | # We now keep distinct sets of args, for a cleaner separation of concerns. |
| 342 | |
| 343 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) |
| 344 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): |
| 345 | # If we pass only one argument to the script and it's the path to a json file, |
| 346 | # let's parse it to get our arguments. |
| 347 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) |
| 348 | else: |
| 349 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() |
| 350 | |
| 351 | # Setup logging |
| 352 | logging.basicConfig( |
| 353 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| 354 | datefmt="%m/%d/%Y %H:%M:%S", |
| 355 | handlers=[logging.StreamHandler(sys.stdout)], |
| 356 | ) |
| 357 | |
| 358 | if training_args.should_log: |
| 359 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. |
| 360 | transformers.utils.logging.set_verbosity_info() |
| 361 | |
| 362 | log_level = training_args.get_process_log_level() |
| 363 | logger.setLevel(log_level) |
| 364 | transformers.utils.logging.set_verbosity(log_level) |
| 365 | transformers.utils.logging.enable_default_handler() |
| 366 | transformers.utils.logging.enable_explicit_format() |
| 367 | |
| 368 | # Log on each process the small summary: |
| 369 | logger.warning( |
| 370 | f"Process rank: {training_args.local_process_index}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " |
| 371 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" |
| 372 | ) |
| 373 | logger.info(f"Training/evaluation parameters {training_args}") |
| 374 | |
| 375 | # ------------------------------------------------------------------------------------------------ |
| 376 | # Load dataset, prepare splits |
| 377 | # ------------------------------------------------------------------------------------------------ |
| 378 | |
| 379 | dataset = load_dataset( |
| 380 | data_args.dataset_name, cache_dir=model_args.cache_dir, trust_remote_code=model_args.trust_remote_code |
| 381 | ) |
| 382 | |
| 383 | # If we don't have a validation split, split off a percentage of train as validation |
| 384 | data_args.train_val_split = None if "validation" in dataset else data_args.train_val_split |
| 385 | if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0: |
| 386 | split = dataset["train"].train_test_split(data_args.train_val_split, seed=training_args.seed) |
| 387 | dataset["train"] = split["train"] |
| 388 | dataset["validation"] = split["test"] |
| 389 | |
| 390 | # Get dataset categories and prepare mappings for label_name <-> label_id |
| 391 | if isinstance(dataset["train"].features["objects"], dict): |
| 392 | categories = dataset["train"].features["objects"]["category"].feature.names |
| 393 | else: # (for old versions of `datasets` that used Sequence({...}) of the objects) |
| 394 | categories = dataset["train"].features["objects"].feature["category"].names |
| 395 | id2label = dict(enumerate(categories)) |
no test coverage detected