hub / github.com/huggingface/transformers / main

Function main

examples/pytorch/image-pretraining/run_mae.py:181–384
()

Source from the content-addressed store, hash-verified

def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CustomTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_process_index}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Initialize our dataset.
    ds = load_dataset(
        data_args.dataset_name,
        data_args.dataset_config_name,
        data_files=data_args.data_files,
        cache_dir=model_args.cache_dir,
        token=model_args.token,
        trust_remote_code=data_args.trust_remote_code,
    )

    # If we don't have a validation split, split off a percentage of train as validation.
    data_args.train_val_split = None if "validation" in ds else data_args.train_val_split
    if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
        split = ds["train"].train_test_split(data_args.train_val_split)
        ds["train"] = split["train"]
        ds["validation"] = split["test"]

    # Load pretrained model and image processor
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently

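The validation-split step in the excerpt above can be tried in isolation. The following is a minimal sketch, not the script's implementation: `resolve_splits` is a hypothetical helper, a plain dict of lists stands in for the dataset object returned by `load_dataset`, and the shuffle-and-slice is an assumed simplification of what `train_test_split` does. It only illustrates the decision rule: an existing "validation" split disables the ratio, and otherwise a positive float ratio carves validation rows out of "train".

```python
import random


def resolve_splits(ds, train_val_split, seed=0):
    """Hypothetical helper mirroring the split decision in main().

    `ds` is a plain dict of lists standing in for the loaded dataset.
    """
    # An existing validation split always wins; the ratio is then ignored.
    if "validation" in ds:
        train_val_split = None

    # Only a strictly positive float ratio triggers a split.
    if isinstance(train_val_split, float) and train_val_split > 0.0:
        rows = list(ds["train"])
        random.Random(seed).shuffle(rows)  # assumption: shuffle before slicing
        n_val = int(len(rows) * train_val_split)
        ds = {**ds, "train": rows[n_val:], "validation": rows[:n_val]}
    return ds


# No validation split present: 25% of the 8 train rows become validation.
no_val = resolve_splits({"train": list(range(8))}, 0.25)
print(len(no_val["train"]), len(no_val["validation"]))

# Validation split already present: the ratio has no effect.
has_val = resolve_splits({"train": list(range(8)), "validation": [1, 2, 3]}, 0.25)
print(len(has_val["train"]), len(has_val["validation"]))
```

Unlike the excerpt, which overwrites `data_args.train_val_split` and mutates `ds` in place, this sketch returns a new dict so the two cases can be compared side by side.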
Callers 2

_mp_fn · Function · 0.70
run_mae.py · File · 0.70

Calls 15

parse_json_file · Method · 0.95
train · Method · 0.95
save_model · Method · 0.95
evaluate · Method · 0.95
push_to_hub · Method · 0.95
create_model_card · Method · 0.95
HfArgumentParser · Class · 0.90
ViTMAEConfig · Class · 0.90
ViTImageProcessor · Class · 0.90
Trainer · Class · 0.90

Tested by

no test coverage detected