MCPcopy
hub / github.com/huggingface/transformers / cuda

Method cuda

src/transformers/modeling_utils.py:3645–3668  ·  view source on GitHub ↗
(self, *args, **kwargs)

Source from the content-addressed store, hash-verified

3643
@wraps(torch.nn.Module.cuda)
def cuda(self, *args, **kwargs):
    # Move the model to a CUDA device, with special handling for quantized models.
    #
    # Arguments are forwarded unchanged to `torch.nn.Module.cuda`. The target
    # device may be given as the first positional argument or as `device=...`.
    #
    # Raises ValueError when the model was quantized with bitsandbytes and is
    # flagged as loaded in 8-bit.
    #
    # (Plain comments are used instead of a docstring because `@wraps` replaces
    # this function's `__doc__` with the one from `torch.nn.Module.cuda`.)
    quant_method = getattr(self, "quantization_method", None)

    if quant_method == QuantizationMethod.HQQ:
        from hqq.core.quantize import HQQLinear

        # HQQLinear keeps some of its tensors in the 'meta' attribute, so each
        # HQQLinear layer needs its own explicit `cuda` call after the regular
        # module-level move.
        super().cuda(*args, **kwargs)

        # The destination is the same for every layer: first positional
        # argument if present, otherwise the `device` keyword (default "cuda").
        target_device = args[0] if args else kwargs.get("device", "cuda")
        for layer in self.modules():
            if isinstance(layer, HQQLinear):
                layer.cuda(target_device)
        return self

    # Models loaded in 8-bit with bitsandbytes cannot be moved.
    is_bnb = quant_method == QuantizationMethod.BITS_AND_BYTES
    if is_bnb and getattr(self, "is_loaded_in_8bit", False):
        raise ValueError(
            "Calling `cuda()` is not supported for `8-bit` quantized models. "
            " Please use the model as it is, since the model has already been set to the correct devices."
        )

    return super().cuda(*args, **kwargs)
3669
3670 @wraps(torch.nn.Module.to)
3671 def to(self, *args, **kwargs):

Callers 5

toMethod · 0.45
_quantizeMethod · 0.45
generate_without_cbFunction · 0.45

Calls 1

getMethod · 0.45

Tested by

no test coverage detected