Function main (github.com/huggingface/transformers)
examples/pytorch/image-pretraining/run_mim.py, lines 245–464

```python
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_process_index}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Initialize our dataset.
    ds = load_dataset(
        data_args.dataset_name,
        data_args.dataset_config_name,
        data_files=data_args.data_files,
        cache_dir=model_args.cache_dir,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )

    # If we don't have a validation split, split off a percentage of train as validation.
    data_args.train_val_split = None if "validation" in ds else data_args.train_val_split
    if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
        split = ds["train"].train_test_split(data_args.train_val_split)
        ds["train"] = split["train"]
        ds["validation"] = split["test"]

    # Create config
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    # ... (excerpt ends here; main() continues through line 464)
```
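`main()` accepts its arguments in two ways: a single command-line argument ending in `.json` is treated as a config file and read with `parse_json_file`, and anything else goes through `parse_args_into_dataclasses`. The sketch below reproduces that dispatch with only the standard library (`argparse`, `dataclasses`, `json`) so the control flow can be run in isolation. `ToyModelArguments`, `ToyDataArguments`, `parse_args` and their fields are hypothetical stand-ins, not the script's real dataclasses, and the sketch makes no attempt to match `HfArgumentParser` in detail (boolean flags, help text, richer type handling).

```python
import argparse
import json
import os
from dataclasses import dataclass, fields
from typing import List, Tuple


@dataclass
class ToyModelArguments:
    # Hypothetical stand-in for run_mim.py's ModelArguments.
    model_type: str = "vit"
    cache_dir: str = ""


@dataclass
class ToyDataArguments:
    # Hypothetical stand-in for run_mim.py's DataTrainingArguments.
    dataset_name: str = "cifar10"
    train_val_split: float = 0.15


ARG_CLASSES = (ToyModelArguments, ToyDataArguments)


def parse_args(argv: List[str]) -> Tuple[ToyModelArguments, ToyDataArguments]:
    """Same dispatch as main(): a lone *.json argument is read as a config file,
    anything else is parsed as --flag value pairs."""
    if len(argv) == 2 and argv[1].endswith(".json"):
        with open(os.path.abspath(argv[1])) as fh:
            config = json.load(fh)
        known = {fld.name for cls in ARG_CLASSES for fld in fields(cls)}
        unknown = set(config) - known
        if unknown:
            # This sketch rejects keys that no dataclass declares.
            raise ValueError(f"Unknown keys in {argv[1]}: {sorted(unknown)}")
        # Route each JSON key to the dataclass that declares it; missing keys keep defaults.
        return tuple(
            cls(**{fld.name: config[fld.name] for fld in fields(cls) if fld.name in config})
            for cls in ARG_CLASSES
        )
    parser = argparse.ArgumentParser()
    for cls in ARG_CLASSES:
        for fld in fields(cls):
            # Every field becomes an optional --flag whose default is the dataclass default.
            parser.add_argument(f"--{fld.name}", type=fld.type, default=fld.default)
    values = vars(parser.parse_args(argv[1:]))
    return tuple(cls(**{fld.name: values[fld.name] for fld in fields(cls)}) for cls in ARG_CLASSES)
```

For example, `parse_args(["run_mim.py", "config.json"])` reads the file, while `parse_args(["run_mim.py", "--dataset_name", "foo"])` parses flags. Like `sys.argv`, the list's first element is the program name. In both paths, fields that are not supplied keep their dataclass defaults.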

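In `main()`, the main process (where `training_args.should_log` is true) first raises the library verbosity to INFO, and `get_process_log_level()` then chooses a level for each process. My understanding of the `TrainingArguments` documentation is that `log_level` defaults to `"passive"` (keep whatever level is already set) and `log_level_replica` defaults to `"warning"`. Under those defaults, the main process ends up at INFO and the other ranks at WARNING. The sketch below reproduces that decision with the standard `logging` module only. The logger name, `PASSIVE`, `process_log_level` and `setup_logging` are hypothetical and are not part of the transformers API.

```python
import logging
from typing import Optional

# None plays the role of the "passive" setting: keep whatever level is already active.
PASSIVE: Optional[int] = None


def process_log_level(
    should_log: bool,
    current_level: int,
    log_level: Optional[int] = PASSIVE,
    log_level_replica: Optional[int] = logging.WARNING,
) -> int:
    """Main process uses log_level, replicas use log_level_replica; PASSIVE keeps current_level."""
    main_level = current_level if log_level is None else log_level
    replica_level = current_level if log_level_replica is None else log_level_replica
    return main_level if should_log else replica_level


def setup_logging(should_log: bool, name: str = "toy_library") -> int:
    """Mirror the order of operations in main() on a stdlib logger."""
    lib_logger = logging.getLogger(name)
    lib_logger.setLevel(logging.WARNING)  # assumed library default verbosity
    if should_log:
        # Counterpart of transformers.utils.logging.set_verbosity_info() on the main process.
        lib_logger.setLevel(logging.INFO)
    level = process_log_level(should_log, lib_logger.level)
    lib_logger.setLevel(level)
    return level
```

The "passive" default explains the comment in `main()`. Without the explicit `set_verbosity_info()` call, a passive main process would simply inherit the library's default level instead of INFO.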
Callers (1)
- run_mim.py (file, 0.70)

Calls (15; 11 listed)
- parse_json_file (method, 0.95)
- train (method, 0.95)
- save_model (method, 0.95)
- evaluate (method, 0.95)
- push_to_hub (method, 0.95)
- create_model_card (method, 0.95)
- HfArgumentParser (class, 0.90)
- Trainer (class, 0.90)
- get_process_log_level (method, 0.80)
- setLevel (method, 0.80)
- warning (method, 0.80)

Tested by: no test coverage detected
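The validation-split step in `main()` gives an existing `validation` split priority. In that case `train_val_split` is reset to `None`, and a split is carved out of `train` via `ds["train"].train_test_split(...)` only when the value is a positive float. Below is a list-based sketch of the same rule. It assumes a plain dict of lists in place of a `datasets.DatasetDict`. `ensure_validation_split` is a hypothetical helper, and its choices (shuffle with a fixed seed, round the validation size up) belong to this sketch and are not a description of `train_test_split` internals.

```python
import math
import random
from typing import Dict, List, Optional


def ensure_validation_split(
    ds: Dict[str, List[int]], train_val_split: Optional[float], seed: int = 42
) -> Dict[str, List[int]]:
    """Carve a validation split out of 'train' when none exists (toy version of the step in main())."""
    out = dict(ds)  # shallow copy so the caller's dict is left untouched
    # An existing validation split takes precedence; the fraction is then ignored.
    if "validation" in out:
        train_val_split = None
    if isinstance(train_val_split, float) and train_val_split > 0.0:
        rows = list(out["train"])
        random.Random(seed).shuffle(rows)  # shuffle before splitting
        n_val = math.ceil(train_val_split * len(rows))  # this sketch rounds up
        out["validation"] = rows[:n_val]
        out["train"] = rows[n_val:]
    return out
```

The `isinstance(..., float)` guard mirrors the script. `None`, `0.0`, or a non-float value leaves the dataset untouched.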