MCPcopy
hub / github.com/huggingface/transformers / __init__

Method __init__

src/transformers/data/datasets/glue.py:75–147  ·  view source on GitHub ↗
(
        self,
        args: GlueDataTrainingArguments,
        tokenizer: PreTrainedTokenizerBase,
        limit_length: int | None = None,
        mode: str | Split = Split.train,
        cache_dir: str | None = None,
    )

Source from the content-addressed store, hash-verified

73 features: list[InputFeatures]
74
75 def __init__(
76 self,
77 args: GlueDataTrainingArguments,
78 tokenizer: PreTrainedTokenizerBase,
79 limit_length: int | None = None,
80 mode: str | Split = Split.train,
81 cache_dir: str | None = None,
82 ):
83 warnings.warn(
84 "This dataset will be removed from the library soon, preprocessing should be handled with the Hugging Face Datasets "
85 "library. You can have a look at this example script for pointers: "
86 "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py",
87 FutureWarning,
88 )
89 self.args = args
90 self.processor = glue_processors[args.task_name]()
91 self.output_mode = glue_output_modes[args.task_name]
92 if isinstance(mode, str):
93 try:
94 mode = Split[mode]
95 except KeyError:
96 raise KeyError("mode is not a valid split name")
97 # Load data features from cache or dataset file
98 cached_features_file = os.path.join(
99 cache_dir if cache_dir is not None else args.data_dir,
100 f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{args.task_name}",
101 )
102 label_list = self.processor.get_labels()
103 if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__.__name__ in (
104 "RobertaTokenizer",
105 "XLMRobertaTokenizer",
106 "BartTokenizer",
107 "BartTokenizerFast",
108 ):
109 # HACK(label indices are swapped in RoBERTa pretrained model)
110 label_list[1], label_list[2] = label_list[2], label_list[1]
111 self.label_list = label_list
112
113 # Make sure only the first process in distributed training processes the dataset,
114 # and the others will use the cache.
115 lock_path = cached_features_file + ".lock"
116 with FileLock(lock_path):
117 if os.path.exists(cached_features_file) and not args.overwrite_cache:
118 start = time.time()
119 check_torch_load_is_safe()
120 self.features = torch.load(cached_features_file, weights_only=True)
121 logger.info(
122 f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
123 )
124 else:
125 logger.info(f"Creating features from dataset file at {args.data_dir}")
126
127 if mode == Split.dev:
128 examples = self.processor.get_dev_examples(args.data_dir)
129 elif mode == Split.test:
130 examples = self.processor.get_test_examples(args.data_dir)
131 else:
132 examples = self.processor.get_train_examples(args.data_dir)

Callers

nothing calls this directly

Calls 11

check_torch_load_is_safe · Function · 0.85
warn · Method · 0.80
join · Method · 0.80
get_labels · Method · 0.45
time · Method · 0.45
info · Method · 0.45
get_dev_examples · Method · 0.45
get_test_examples · Method · 0.45
get_train_examples · Method · 0.45
save · Method · 0.45

Tested by

no test coverage detected