MCPcopy
hub / github.com/huggingface/transformers / test_neftune

Method test_neftune

tests/trainer/test_trainer.py:453–502  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

451 """Tests for NEFTune noise injection during training."""
452
def test_neftune(self):
    """Check NEFTune noise on the input embeddings before and after training.

    Two things are verified:

    1. After ``activate_neftune`` is called on the trainer's model, two
       embedding lookups of the same input differ.
    2. After a full ``trainer.train()`` run with ``neftune_noise_alpha`` set,
       no forward hook is left on the input embeddings and, in eval mode,
       repeated lookups of the same input match.
    """

    def make_trainer(model):
        # Each trainer gets its own temporary output dir; the inf/nan
        # logging filter is switched off.
        args = TrainingArguments(
            self.get_auto_remove_tmp_dir(),
            learning_rate=1e-9,
            logging_steps=5,
            logging_nan_inf_filter=False,
            neftune_noise_alpha=0.4,
        )
        return Trainer(model, args, train_dataset=train_dataset)

    config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
    tiny_gpt2 = GPT2LMHeadModel(config)
    train_dataset = RepeatDataset(torch.randint(0, 100, (128,)))

    # --- Part 1: noise is applied once NEFTune is activated explicitly ---
    trainer = make_trainer(tiny_gpt2)
    activate_neftune(trainer.model, trainer.args.neftune_noise_alpha)

    dummy_input = torch.tensor([[1, 0, 1]], dtype=torch.long, device=torch_device)

    embed = trainer.model.get_input_embeddings()
    noisy_first = embed(dummy_input)
    noisy_second = embed(dummy_input)
    self.assertFalse(torch.allclose(noisy_first, noisy_second), "Neftune noise is not applied!")

    # --- Part 2: a full training run leaves the model clean afterwards ---
    # Start again from a freshly initialised model.
    trainer = make_trainer(GPT2LMHeadModel(config))

    # Check that it trains without errors
    trainer.train()

    # Make sure forward pass works fine
    _ = trainer.model(dummy_input)

    # NOTE(review): presumably NEFTune is installed as a forward hook on the
    # embedding layer, which is why an empty hook dict is taken to mean it was
    # deactivated — confirm against activate_neftune.
    embed = trainer.model.get_input_embeddings()
    self.assertEqual(len(embed._forward_hooks), 0)

    trainer.model.eval()

    # Check that we get identical embeddings just in case
    clean_first = embed(dummy_input)
    clean_second = embed(dummy_input)
    torch.testing.assert_close(clean_first, clean_second)
503
504
505# ---------------------------------------------------------------------------

Callers

nothing calls this directly

Calls 12

trainMethod · 0.95
GPT2ConfigClass · 0.90
GPT2LMHeadModelClass · 0.90
TrainingArgumentsClass · 0.90
TrainerClass · 0.90
activate_neftuneFunction · 0.90
RepeatDatasetClass · 0.85
evalMethod · 0.80
toMethod · 0.45
get_input_embeddingsMethod · 0.45
modelMethod · 0.45

Tested by

no test coverage detected