MCPcopy
hub / github.com/huggingface/transformers / to

Method to

src/transformers/modeling_utils.py:3671–3729  ·  src/transformers/modeling_utils.py::PreTrainedModel.to
(self, *args, **kwargs)

Source from the content-addressed store, hash-verified

3669
3670 @wraps(torch.nn.Module.to)
3671 def to(self, *args, **kwargs):
3672 class="cm"># For BNB/GPTQ models, we prevent users from casting the model to another dtype to restrict unwanted behaviours.
3673 class="cm"># the correct API should be to load the model with the desired dtype directly through `from_pretrained`.
3674 dtype_present_in_args = class="st">"dtype" in kwargs
3675
3676 if not dtype_present_in_args:
3677 for arg in args:
3678 if isinstance(arg, torch.dtype):
3679 dtype_present_in_args = True
3680 break
3681
3682 if getattr(self, class="st">"quantization_method", None) == QuantizationMethod.HQQ:
3683 from hqq.core.quantize import HQQLinear
3684
3685 class="cm"># Since HQQLinear stores some tensors in the class="st">'meta' attribute, we must
3686 class="cm"># explicitly move the parameters to the target device for each HQQLinear layer after `to`.
3687 super().to(*args, **kwargs)
3688 for module in self.modules():
3689 if isinstance(module, HQQLinear):
3690 if class="st">"device" in kwargs:
3691 device = kwargs[class="st">"device"]
3692 else:
3693 device = args[0]
3694 if class="st">"dtype" in kwargs:
3695 dtype = kwargs[class="st">"dtype"]
3696 elif dtype_present_in_args:
3697 dtype = arg
3698 else:
3699 dtype = None
3700 class="cm"># Due to the current messy implementation of HQQLinear, updating `compute_dtype`
3701 class="cm"># followed by calling the `cuda` method achieves the intended behavior of `to`,
3702 class="cm"># even when the target device is CPU.
3703 if dtype is not None:
3704 module.compute_dtype = dtype
3705 module.cuda(device)
3706 return self
3707
3708 if dtype_present_in_args and getattr(self, class="st">"quantization_method", None) == QuantizationMethod.QUARK:
3709 raise ValueError(class="st">"Casting a Quark quantized model to a new `dtype` is not supported.")
3710
3711 class="cm"># Checks if the model has been loaded in 4-bit or 8-bit with BNB
3712 if getattr(self, class="st">"quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
3713 if dtype_present_in_args:
3714 raise ValueError(
3715 class="st">"You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
3716 class="st">" desired `dtype` by passing the correct `dtype` argument."
3717 )
3718
3719 if getattr(self, class="st">"is_loaded_in_8bit", False) and not is_bitsandbytes_available(class="st">"0.48"):
3720 raise ValueError(
3721 class="st">"You need to install `pip install bitsandbytes>=0.48.0` if you want to move a 8-bit model across devices using to()."
3722 )
3723 elif getattr(self, class="st">"quantization_method", None) == QuantizationMethod.GPTQ:
3724 if dtype_present_in_args:
3725 raise ValueError(
3726 class="st">"You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
3727 class="st">" `dtype` by passing the correct `dtype` argument."
3728 )

Callers 15

setup_benchmark · Method · 0.45
run_benchmark · Function · 0.45
load_state_dict · Function · 0.45
invert_attention_mask · Method · 0.45
to_4d · Method · 0.45
_make_causal_mask · Method · 0.45
_expand_mask · Method · 0.45

Calls 2

cuda · Method · 0.45