(self)
| 133 | |
| 134 | class BnB4BitBasicTests(Base4bitTests): |
| 135 | def setUp(self): |
| 136 | gc.collect() |
| 137 | backend_empty_cache(torch_device) |
| 138 | |
| 139 | # Models |
| 140 | self.model_fp16 = SD3Transformer2DModel.from_pretrained( |
| 141 | self.model_name, subfolder="transformer", torch_dtype=torch.float16 |
| 142 | ) |
| 143 | nf4_config = BitsAndBytesConfig( |
| 144 | load_in_4bit=True, |
| 145 | bnb_4bit_quant_type="nf4", |
| 146 | bnb_4bit_compute_dtype=torch.float16, |
| 147 | ) |
| 148 | self.model_4bit = SD3Transformer2DModel.from_pretrained( |
| 149 | self.model_name, subfolder="transformer", quantization_config=nf4_config, device_map=torch_device |
| 150 | ) |
| 151 | |
| 152 | def tearDown(self): |
| 153 | if hasattr(self, "model_fp16"): |
nothing calls this directly
no test coverage detected