github.com/huggingface/transformers / main

Function main()

examples/pytorch/language-modeling/run_fim_no_trainer.py:330–914 (this excerpt shows lines 330–387)

```python
def main():
    args = parse_args()

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
    # in the environment
    accelerator_log_kwargs = {}

    if args.with_tracking:
        accelerator_log_kwargs["log_with"] = args.report_to
        accelerator_log_kwargs["project_dir"] = args.output_dir

    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)
        # Set a numpy random state for FIM transformations
        np_rng = np.random.RandomState(seed=args.seed)
    else:
        # Still set a random state for FIM transformations
        np_rng = np.random.RandomState(seed=42)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            # Retrieve or infer repo_name
            repo_name = args.hub_model_id
            if repo_name is None:
                repo_name = Path(args.output_dir).absolute().name
            # Create repo and retrieve repo_id
            repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
            # Clone repo locally
            repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                # Note: "w+" truncates the file, so both `not in` checks below are
                # always True and both patterns are always written.
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()
```
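The `.gitignore` step in the listing opens the file in `"w+"` mode, which truncates it before the `not in` checks run, so both patterns are always written and any entries already in the file are lost. If existing entries should survive, one option is a read-then-append update. A minimal sketch, assuming that goal; `ensure_gitignore_patterns` is a hypothetical helper, not part of the script or of any library:

```python
import os


def ensure_gitignore_patterns(output_dir, patterns=("step_*", "epoch_*")):
    """Append each pattern to output_dir/.gitignore unless it is already listed.

    Hypothetical helper: a sketch, not the transformers implementation.
    """
    path = os.path.join(output_dir, ".gitignore")
    content = ""
    if os.path.exists(path):
        with open(path) as f:
            content = f.read()
    # Compare stripped lines so trailing newlines do not defeat the check.
    existing = {line.strip() for line in content.splitlines()}
    missing = [p for p in patterns if p not in existing]
    if missing:
        # Append mode keeps whatever the file already contains.
        with open(path, "a") as f:
            # Start on a fresh line if the file lacks a trailing newline.
            if content and not content.endswith("\n"):
                f.write("\n")
            for pattern in missing:
                f.write(pattern + "\n")
    return path
```

Calling it twice leaves a single copy of each pattern, and a pre-existing file keeps its original lines.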

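The repository-name fallback in the listing can be isolated as a pure function, which makes the rule easy to check: an explicit `hub_model_id` wins, otherwise the last component of the absolute `output_dir` is used. `infer_repo_name` is a hypothetical name used only for this sketch:

```python
from pathlib import Path


def infer_repo_name(hub_model_id, output_dir):
    """Mirror the fallback in main(): prefer hub_model_id, else use the
    final component of output_dir resolved to an absolute path."""
    if hub_model_id is not None:
        return hub_model_id
    return Path(output_dir).absolute().name
```

Resolving to an absolute path first matters for relative inputs such as `.`: `Path(".").name` is the empty string, while `Path(".").absolute().name` is the name of the current directory.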
Callers 1

Calls 15

set_seed · Function · 0.90
is_torch_xla_available · Function · 0.90
get_parameter_names · Function · 0.90
get_scheduler · Function · 0.90
join · Method · 0.80
split · Method · 0.80
warning · Method · 0.80
main_process_first · Method · 0.80
accumulate · Method · 0.80
float · Method · 0.80
detach · Method · 0.80
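`main()` builds a dedicated `np.random.RandomState` for the FIM transformations, seeded with `args.seed` when one is given and with 42 otherwise, so the FIM draws are reproducible even when no seed is passed. The sketch below shows how such a generator could drive one fill-in-the-middle rewrite in prefix-suffix-middle (PSM) order. It is an illustration under assumptions, not the script's transformation: it works on characters rather than token IDs, and the sentinel strings and the helpers `make_rng` and `fim_psm` are hypothetical.

```python
import numpy as np

# Hypothetical sentinel strings; real FIM setups use tokenizer-specific special tokens.
PREFIX, SUFFIX, MIDDLE = "<fim_prefix>", "<fim_suffix>", "<fim_middle>"


def make_rng(seed=None):
    # Same fallback as main(): use the given seed, else a fixed seed of 42.
    return np.random.RandomState(seed=seed if seed is not None else 42)


def fim_psm(text, np_rng):
    """Split text at two random points and reorder it prefix-suffix-middle."""
    # Draw two cut points in [0, len(text)] (high is exclusive) and sort them.
    lo, hi = sorted(np_rng.randint(0, len(text) + 1, size=2))
    prefix, middle, suffix = text[:lo], text[lo:hi], text[hi:]
    return f"{PREFIX}{prefix}{SUFFIX}{suffix}{MIDDLE}{middle}"
```

Because the generator is a separate `RandomState` object, its draws are independent of NumPy's global random state, and two generators built with the same seed produce the same split points.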

Tested by

no test coverage detected
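The logging block in `main()` applies a per-process policy: the local main process logs `transformers` at INFO and `datasets` at WARNING, while every other process is restricted to errors, so a multi-process run does not print one copy of each library message per worker. A stdlib-only sketch of that policy follows. It assumes the two libraries log under stdlib loggers named `"datasets"` and `"transformers"`; the script itself calls each library's `set_verbosity_*` helpers instead, and `configure_library_verbosity` is hypothetical:

```python
import logging


def configure_library_verbosity(is_local_main_process):
    """Sketch of main()'s per-process verbosity policy using stdlib loggers.

    Assumes the libraries log under loggers named "datasets" and "transformers".
    """
    if is_local_main_process:
        levels = {"datasets": logging.WARNING, "transformers": logging.INFO}
    else:
        # Non-main processes only surface errors.
        levels = {"datasets": logging.ERROR, "transformers": logging.ERROR}
    for name, level in levels.items():
        logging.getLogger(name).setLevel(level)
    return levels
```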