(self)
| 451 | """Tests for NEFTune noise injection during training.""" |
| 452 | |
def test_neftune(self):
    """Check NEFTune noise on the input embeddings before and after training.

    The test has two phases:

    1. ``activate_neftune`` is called directly on a freshly built trainer's
       model, and two embedding lookups of the same input are expected to
       differ.
    2. A second, fresh model is trained with ``neftune_noise_alpha`` set.
       After ``train()`` returns, the input embeddings must carry no forward
       hooks, and in eval mode two lookups of the same input must match.
    """
    config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
    tiny_gpt2 = GPT2LMHeadModel(config)
    x = torch.randint(0, 100, (128,))
    train_dataset = RepeatDataset(x)

    def build_trainer(model):
        # Both phases use identical arguments; each trainer requests its own
        # output directory via get_auto_remove_tmp_dir().
        # Trainer without inf/nan filter
        args = TrainingArguments(
            self.get_auto_remove_tmp_dir(),
            learning_rate=1e-9,
            logging_steps=5,
            logging_nan_inf_filter=False,
            neftune_noise_alpha=0.4,
        )
        return Trainer(model, args, train_dataset=train_dataset)

    # Phase 1: activate NEFTune by hand and look the same input up twice.
    trainer = build_trainer(tiny_gpt2)
    activate_neftune(trainer.model, trainer.args.neftune_noise_alpha)

    dummy_input = torch.LongTensor([[1, 0, 1]]).to(torch_device)

    input_embeddings = trainer.model.get_input_embeddings()
    emb1 = input_embeddings(dummy_input)
    emb2 = input_embeddings(dummy_input)

    self.assertFalse(torch.allclose(emb1, emb2), "Neftune noise is not applied!")

    # Phase 2: redefine the model and run a real training loop.
    tiny_gpt2 = GPT2LMHeadModel(config)
    trainer = build_trainer(tiny_gpt2)

    # Check that it trains without errors
    trainer.train()

    # Make sure forward pass works fine
    _ = trainer.model(dummy_input)

    # assertEqual (rather than assertTrue(len(...) == 0)) reports the actual
    # number of leftover hooks when the check fails.
    input_embeddings = trainer.model.get_input_embeddings()
    self.assertEqual(
        len(input_embeddings._forward_hooks),
        0,
        "Input embeddings still have forward hooks after training",
    )

    trainer.model.eval()

    # Check that we get identical embeddings just in case
    emb1 = input_embeddings(dummy_input)
    emb2 = input_embeddings(dummy_input)
    torch.testing.assert_close(emb1, emb2)
| 503 | |
| 504 | |
| 505 | # --------------------------------------------------------------------------- |
nothing calls this directly
no test coverage detected