| 3669 | |
| 3670 | @wraps(torch.nn.Module.to) |
| 3671 | def to(self, *args, **kwargs): |
| 3672 | class="cm"># For BNB/GPTQ models, we prevent users from casting the model to another dtype to restrict unwanted behaviours. |
| 3673 | class="cm"># the correct API should be to load the model with the desired dtype directly through `from_pretrained`. |
| 3674 | dtype_present_in_args = class="st">"dtype" in kwargs |
| 3675 | |
| 3676 | if not dtype_present_in_args: |
| 3677 | for arg in args: |
| 3678 | if isinstance(arg, torch.dtype): |
| 3679 | dtype_present_in_args = True |
| 3680 | break |
| 3681 | |
| 3682 | if getattr(self, class="st">"quantization_method", None) == QuantizationMethod.HQQ: |
| 3683 | from hqq.core.quantize import HQQLinear |
| 3684 | |
| 3685 | class="cm"># Since HQQLinear stores some tensors in the class="st">'meta' attribute, we must |
| 3686 | class="cm"># explicitly move the parameters to the target device for each HQQLinear layer after `to`. |
| 3687 | super().to(*args, **kwargs) |
| 3688 | for module in self.modules(): |
| 3689 | if isinstance(module, HQQLinear): |
| 3690 | if class="st">"device" in kwargs: |
| 3691 | device = kwargs[class="st">"device"] |
| 3692 | else: |
| 3693 | device = args[0] |
| 3694 | if class="st">"dtype" in kwargs: |
| 3695 | dtype = kwargs[class="st">"dtype"] |
| 3696 | elif dtype_present_in_args: |
| 3697 | dtype = arg |
| 3698 | else: |
| 3699 | dtype = None |
| 3700 | class="cm"># Due to the current messy implementation of HQQLinear, updating `compute_dtype` |
| 3701 | class="cm"># followed by calling the `cuda` method achieves the intended behavior of `to`, |
| 3702 | class="cm"># even when the target device is CPU. |
| 3703 | if dtype is not None: |
| 3704 | module.compute_dtype = dtype |
| 3705 | module.cuda(device) |
| 3706 | return self |
| 3707 | |
| 3708 | if dtype_present_in_args and getattr(self, class="st">"quantization_method", None) == QuantizationMethod.QUARK: |
| 3709 | raise ValueError(class="st">"Casting a Quark quantized model to a new `dtype` is not supported.") |
| 3710 | |
| 3711 | class="cm"># Checks if the model has been loaded in 4-bit or 8-bit with BNB |
| 3712 | if getattr(self, class="st">"quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: |
| 3713 | if dtype_present_in_args: |
| 3714 | raise ValueError( |
| 3715 | class="st">"You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the" |
| 3716 | class="st">" desired `dtype` by passing the correct `dtype` argument." |
| 3717 | ) |
| 3718 | |
| 3719 | if getattr(self, class="st">"is_loaded_in_8bit", False) and not is_bitsandbytes_available(class="st">"0.48"): |
| 3720 | raise ValueError( |
| 3721 | class="st">"You need to install `pip install bitsandbytes>=0.48.0` if you want to move a 8-bit model across devices using to()." |
| 3722 | ) |
| 3723 | elif getattr(self, class="st">"quantization_method", None) == QuantizationMethod.GPTQ: |
| 3724 | if dtype_present_in_args: |
| 3725 | raise ValueError( |
| 3726 | class="st">"You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired" |
| 3727 | class="st">" `dtype` by passing the correct `dtype` argument." |
| 3728 | ) |