MCPcopy
hub / github.com/huggingface/transformers / main

Function main

examples/pytorch/translation/run_translation.py:274–682  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

272
273
274def main():
275 # See all possible arguments in src/transformers/training_args.py
276 # or by passing the --help flag to this script.
277 # We now keep distinct sets of args, for a cleaner separation of concerns.
278
279 parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
280 if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
281 # If we pass only one argument to the script and it's the path to a json file,
282 # let's parse it to get our arguments.
283 model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
284 else:
285 model_args, data_args, training_args = parser.parse_args_into_dataclasses()
286
287 # Setup logging
288 logging.basicConfig(
289 format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
290 datefmt="%m/%d/%Y %H:%M:%S",
291 handlers=[logging.StreamHandler(sys.stdout)],
292 )
293
294 if training_args.should_log:
295 # The default of training_args.log_level is passive, so we set log level at info here to have that default.
296 transformers.utils.logging.set_verbosity_info()
297
298 log_level = training_args.get_process_log_level()
299 logger.setLevel(log_level)
300 datasets.utils.logging.set_verbosity(log_level)
301 transformers.utils.logging.set_verbosity(log_level)
302 transformers.utils.logging.enable_default_handler()
303 transformers.utils.logging.enable_explicit_format()
304
305 # Log on each process the small summary:
306 logger.warning(
307 f"Process rank: {training_args.local_process_index}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
308 + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
309 )
310 logger.info(f"Training/evaluation parameters {training_args}")
311
312 if data_args.source_prefix is None and model_args.model_name_or_path in [
313 "google-t5/t5-small",
314 "google-t5/t5-base",
315 "google-t5/t5-large",
316 "google-t5/t5-3b",
317 "google-t5/t5-11b",
318 ]:
319 logger.warning(
320 "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
321 "`--source_prefix 'translate English to German: ' `"
322 )
323
324 # Set seed before initializing model.
325 set_seed(training_args.seed)
326
327 # Get the datasets: you can either provide your own JSON training and evaluation files (see below)
328 # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
329 # (the dataset will be downloaded automatically from the datasets Hub).
330 #
331 # For translation, only JSON files are supported, with one field named "translation" containing two keys for the

Callers 2

_mp_fn · Function · 0.70
run_translation.py · File · 0.70

Calls 15

parse_json_file · Method · 0.95
evaluate · Method · 0.95
predict · Method · 0.95
HfArgumentParser · Class · 0.90
set_seed · Function · 0.90
Seq2SeqTrainer · Class · 0.90
get_process_log_level · Method · 0.80
setLevel · Method · 0.80
warning · Method · 0.80
split · Method · 0.80

Tested by

no test coverage detected