(self, *args, **kwargs)
| 3643 | |
@wraps(torch.nn.Module.cuda)
def cuda(self, *args, **kwargs):
    # Quantization-aware override of `torch.nn.Module.cuda`.
    #   * HQQ models: run the regular move, then call `cuda` on every
    #     HQQLinear layer explicitly, and return `self`.
    #   * bitsandbytes models loaded in 8-bit: refuse with a ValueError.
    #   * everything else: defer to the parent implementation.
    quant_method = getattr(self, "quantization_method", None)

    if quant_method == QuantizationMethod.HQQ:
        from hqq.core.quantize import HQQLinear

        # HQQLinear keeps some of its tensors in the 'meta' attribute, so the
        # parent `cuda` call alone is not enough: each HQQLinear layer has to
        # have its own `cuda` method invoked as well.
        super().cuda(*args, **kwargs)

        # The device is the first positional argument when one is given,
        # otherwise the `device` keyword, falling back to plain "cuda".
        target_device = args[0] if args else kwargs.get("device", "cuda")
        for layer in self.modules():
            if isinstance(layer, HQQLinear):
                layer.cuda(target_device)
        return self

    # A model loaded in 8-bit with bitsandbytes cannot be moved with `cuda()`;
    # any other bitsandbytes model falls through to the parent call below.
    loaded_with_bnb = quant_method == QuantizationMethod.BITS_AND_BYTES
    if loaded_with_bnb and getattr(self, "is_loaded_in_8bit", False):
        raise ValueError(
            "Calling `cuda()` is not supported for `8-bit` quantized models. "
            " Please use the model as it is, since the model has already been set to the correct devices."
        )

    return super().cuda(*args, **kwargs)
| 3669 | |
| 3670 | @wraps(torch.nn.Module.to) |
| 3671 | def to(self, *args, **kwargs): |
no test coverage detected