hub / github.com/huggingface/transformers / main

Function main

examples/pytorch/text-classification/run_classification.py:284–735
()

def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_process_index}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files, or specify a dataset name
    # to load from huggingface/datasets. In either case, you can specify the key of the column(s) containing the text and
    # the key of the column containing the label. If multiple columns are specified for the text, they will be joined together
    # for the actual text value.
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
        )
        # Try printing some info about the dataset
        logger.info(f"Dataset loaded: {raw_datasets}")
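
The argument-handling branch at the top of main() can be tried in isolation. What follows is a minimal standard-library sketch of that dispatch, not the HfArgumentParser implementation: `DemoArguments` and `parse_demo_args` are hypothetical names, and the two dataclass fields are chosen only for illustration. It assumes, as the listing does, that a single argument ending in ".json" is a config file and that anything else is ordinary command-line flags.

```python
import argparse
import json
import os
import tempfile
from dataclasses import dataclass, fields


@dataclass
class DemoArguments:
    # Hypothetical fields, chosen only for illustration.
    dataset_name: str = "none"
    seed: int = 42


def parse_demo_args(argv):
    """Mimic the dispatch in main(): a lone *.json path is read as a config
    file; anything else is parsed as ordinary command-line flags."""
    if len(argv) == 2 and argv[1].endswith(".json"):
        with open(os.path.abspath(argv[1]), encoding="utf-8") as handle:
            return DemoArguments(**json.load(handle))
    parser = argparse.ArgumentParser()
    for field in fields(DemoArguments):
        # One --flag per dataclass field, typed and defaulted from the field.
        parser.add_argument(f"--{field.name}", type=field.type, default=field.default)
    return DemoArguments(**vars(parser.parse_args(argv[1:])))


# Flag-style invocation: unspecified fields keep their defaults.
cli_args = parse_demo_args(["script.py", "--seed", "7"])

# JSON-style invocation: write a throwaway config file and pass only its path.
with tempfile.TemporaryDirectory() as tmp:
    config_path = os.path.join(tmp, "config.json")
    with open(config_path, "w", encoding="utf-8") as handle:
        json.dump({"dataset_name": "imdb", "seed": 1}, handle)
    json_args = parse_demo_args(["script.py", config_path])
```

Keeping both entry points behind one function means the rest of the script sees the same dataclass regardless of how the run was configured, which appears to be the intent of the branch in the listing.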

Callers 2

_mp_fn · Function · 0.70

Calls 15

parse_json_file · Method · 0.95
train · Method · 0.95
save_model · Method · 0.95
evaluate · Method · 0.95
predict · Method · 0.95
is_world_process_zero · Method · 0.95
push_to_hub · Method · 0.95
create_model_card · Method · 0.95
HfArgumentParser · Class · 0.90
set_seed · Function · 0.90

Tested by

no test coverage detected