Calculates and returns the following values: - `num_train_epochs` - `num_update_steps_per_epoch` - `num_examples` - `num_train_samples` - `total_train_batch_size` - `steps_in_epoch` (total batches per epoch) - `max_steps`
(
self, args: TrainingArguments, dataloader: DataLoader
)
| 2309 | return contextlib.nullcontext, inputs |
| 2310 | |
| 2311 | def set_initial_training_values( |
| 2312 | self, args: TrainingArguments, dataloader: DataLoader |
| 2313 | ) -> tuple[int, int, int, int, int, int | None, int]: |
| 2314 | """ |
| 2315 | Calculates and returns the following values: |
| 2316 | - `num_train_epochs` |
| 2317 | - `num_update_steps_per_epoch` |
| 2318 | - `num_examples` |
| 2319 | - `num_train_samples` |
| 2320 | - `total_train_batch_size` |
| 2321 | - `steps_in_epoch` (total batches per epoch) |
| 2322 | - `max_steps` |
| 2323 | """ |
| 2324 | # Case 1: we rely on `args.max_steps` first |
| 2325 | max_steps = args.max_steps |
| 2326 | # If max_steps is negative, we use the number of epochs to determine the number of total steps later |
| 2327 | epoch_based = max_steps < 0 |
| 2328 | len_dataloader = len(dataloader) if has_length(dataloader) else None |
| 2329 | total_train_batch_size = self.get_total_train_batch_size(args) |
| 2330 | |
| 2331 | # Account for Sequence Parallelism (SP) dataloader adapter's effect |
| 2332 | sp_size = self.get_sp_size() |
| 2333 | if sp_size > 1 and len_dataloader is not None: |
| 2334 | len_dataloader = len_dataloader * sp_size |
| 2335 | |
| 2336 | # Case 2: We have a dataloader length and can extrapolate |
| 2337 | if len_dataloader is not None: |
| 2338 | num_update_steps_per_epoch = max( |
| 2339 | len_dataloader // args.gradient_accumulation_steps |
| 2340 | + int(len_dataloader % args.gradient_accumulation_steps > 0), |
| 2341 | 1, |
| 2342 | ) |
| 2343 | # Case 3: We have a length but are using epochs, so we can extrapolate the number of steps |
| 2344 | if epoch_based: |
| 2345 | max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) |
| 2346 | # Now we figure out `num_examples`, `num_train_epochs`, and `num_train_samples` |
| 2347 | if len_dataloader: |
| 2348 | num_examples = self.num_examples(dataloader) |
| 2349 | if args.max_steps > 0: |
| 2350 | num_train_epochs = max_steps // num_update_steps_per_epoch + int( |
| 2351 | max_steps % num_update_steps_per_epoch > 0 |
| 2352 | ) |
| 2353 | # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's |
| 2354 | # the best we can do. |
| 2355 | num_train_samples = max_steps * total_train_batch_size |
| 2356 | else: |
| 2357 | num_train_epochs = math.ceil(args.num_train_epochs) |
| 2358 | num_train_samples = self.num_examples(dataloader) * args.num_train_epochs |
| 2359 | elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size |
| 2360 | # Setting a very large number of epochs so we go as many times as necessary over the iterator. |
| 2361 | num_train_epochs = sys.maxsize |
| 2362 | num_update_steps_per_epoch = max_steps |
| 2363 | num_examples = total_train_batch_size * args.max_steps |
| 2364 | num_train_samples = args.max_steps * total_train_batch_size |
| 2365 | else: |
| 2366 | raise ValueError( |
| 2367 | "args.max_steps must be set to a positive value if dataloader does not have a length, was" |
| 2368 | f" {args.max_steps}" |