Test that the quantized model works properly with `device_map="auto"`. Note that CPU/disk offloading is also not supported with bnb (bitsandbytes).
(self)
| 580 | del pipeline_4bit |
| 581 | |
| 582 | def test_device_map(self): |
| 583 | """ |
| 584 | Test if the quantized model is working properly with "auto". |
| 585 | cpu/disk offloading as well doesn't work with bnb. |
| 586 | """ |
| 587 | |
| 588 | def get_dummy_tensor_inputs(device=None, seed: int = 0): |
| 589 | batch_size = 1 |
| 590 | num_latent_channels = 4 |
| 591 | num_image_channels = 3 |
| 592 | height = width = 4 |
| 593 | sequence_length = 48 |
| 594 | embedding_dim = 32 |
| 595 | |
| 596 | torch.manual_seed(seed) |
| 597 | hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to( |
| 598 | device, dtype=torch.bfloat16 |
| 599 | ) |
| 600 | torch.manual_seed(seed) |
| 601 | encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( |
| 602 | device, dtype=torch.bfloat16 |
| 603 | ) |
| 604 | |
| 605 | torch.manual_seed(seed) |
| 606 | pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16) |
| 607 | |
| 608 | torch.manual_seed(seed) |
| 609 | text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16) |
| 610 | |
| 611 | torch.manual_seed(seed) |
| 612 | image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16) |
| 613 | |
| 614 | timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size) |
| 615 | |
| 616 | return { |
| 617 | "hidden_states": hidden_states, |
| 618 | "encoder_hidden_states": encoder_hidden_states, |
| 619 | "pooled_projections": pooled_prompt_embeds, |
| 620 | "txt_ids": text_ids, |
| 621 | "img_ids": image_ids, |
| 622 | "timestep": timestep, |
| 623 | } |
| 624 | |
| 625 | inputs = get_dummy_tensor_inputs(torch_device) |
| 626 | expected_slice = np.array( |
| 627 | [0.47070312, 0.00390625, -0.03662109, -0.19628906, -0.53125, 0.5234375, -0.17089844, -0.59375, 0.578125] |
| 628 | ) |
| 629 | |
| 630 | # non sharded |
| 631 | quantization_config = BitsAndBytesConfig( |
| 632 | load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16 |
| 633 | ) |
| 634 | quantized_model = FluxTransformer2DModel.from_pretrained( |
| 635 | "hf-internal-testing/tiny-flux-pipe", |
| 636 | subfolder="transformer", |
| 637 | quantization_config=quantization_config, |
| 638 | device_map="auto", |
| 639 | torch_dtype=torch.bfloat16, |
nothing calls this directly
no test coverage detected