()
| 328 | |
| 329 | |
| 330 | def main(): |
| 331 | args = parse_args() |
| 332 | |
| 333 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. |
| 334 | # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers |
| 335 | # in the environment |
| 336 | accelerator_log_kwargs = {} |
| 337 | |
| 338 | if args.with_tracking: |
| 339 | accelerator_log_kwargs["log_with"] = args.report_to |
| 340 | accelerator_log_kwargs["project_dir"] = args.output_dir |
| 341 | |
| 342 | accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs) |
| 343 | |
| 344 | # Make one log on every process with the configuration for debugging. |
| 345 | logging.basicConfig( |
| 346 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
| 347 | datefmt="%m/%d/%Y %H:%M:%S", |
| 348 | level=logging.INFO, |
| 349 | ) |
| 350 | logger.info(accelerator.state, main_process_only=False) |
| 351 | if accelerator.is_local_main_process: |
| 352 | datasets.utils.logging.set_verbosity_warning() |
| 353 | transformers.utils.logging.set_verbosity_info() |
| 354 | else: |
| 355 | datasets.utils.logging.set_verbosity_error() |
| 356 | transformers.utils.logging.set_verbosity_error() |
| 357 | |
| 358 | # If passed along, set the training seed now. |
| 359 | if args.seed is not None: |
| 360 | set_seed(args.seed) |
| 361 | # Set a numpy random state for FIM transformations |
| 362 | np_rng = np.random.RandomState(seed=args.seed) |
| 363 | else: |
| 364 | # Still set a random state for FIM transformations |
| 365 | np_rng = np.random.RandomState(seed=42) |
| 366 | |
| 367 | # Handle the repository creation |
| 368 | if accelerator.is_main_process: |
| 369 | if args.push_to_hub: |
| 370 | # Retrieve or infer repo_name |
| 371 | repo_name = args.hub_model_id |
| 372 | if repo_name is None: |
| 373 | repo_name = Path(args.output_dir).absolute().name |
| 374 | # Create repo and retrieve repo_id |
| 375 | repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id |
| 376 | # Clone repo locally |
| 377 | repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token) |
| 378 | |
| 379 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: |
| 380 | if "step_*" not in gitignore: |
| 381 | gitignore.write("step_*\n") |
| 382 | if "epoch_*" not in gitignore: |
| 383 | gitignore.write("epoch_*\n") |
| 384 | elif args.output_dir is not None: |
| 385 | os.makedirs(args.output_dir, exist_ok=True) |
| 386 | accelerator.wait_for_everyone() |
| 387 |
no test coverage detected