Kernelize the model with Hub kernels, temporarily registering hidden kernel wrappers as child modules so `kernelize` can discover and replace them.
(self, mode=None)
| 4646 | self._loss_function = value |
| 4647 | |
def kernelize(self, mode=None):
    """Swap eligible layers of this model for Hub kernels using `kernels.kernelize`.

    Some modules keep kernel-capable callables in a `_hidden_kernels` mapping
    rather than as child modules, so a walk over the module tree would not find
    them. They are registered as child modules for the duration of the call and
    removed again afterwards, even if kernelization fails.

    Args:
        mode: Optional `kernels.Mode` to kernelize for. When `None`, it is
            derived from `self.training` (`Mode.TRAINING` while training,
            `Mode.INFERENCE` otherwise). An explicitly passed value is always
            used as-is.

    Raises:
        ValueError: If the `kernels` package is not available, or if a hidden
            kernel is not a `torch.nn.Module`.
    """
    if not is_kernels_available():
        raise ValueError(
            "Kernels are not available. To use kernels, please install kernels using `pip install -U kernels`"
        )
    from contextlib import nullcontext

    from kernels import Device, Mode, kernelize

    # Every (module, name) pair registered by *this* call. Only these are removed
    # in `finally`, so children or attributes that already existed under the same
    # name are never deleted -- including when registration fails part-way.
    attached = []

    def attach_hidden_kernels(module):
        hidden_kernels = getattr(module, "_hidden_kernels", {})
        if not hidden_kernels:
            return
        # Snapshot once per module instead of rebuilding the dict for every name.
        existing_children = dict(module.named_children())
        for name, fn in hidden_kernels.items():
            if name in existing_children:
                continue
            if not isinstance(fn, nn.Module):
                raise ValueError(
                    f"Attempted to register a kernel for {name}, but it was not a `torch.nn.Module`. "
                    "This means the underlying function needs to be decorated with `@use_kernel_func_from_hub`. "
                    "Please submit an issue to the transformers repo: `https://github.com/huggingface/transformers/issues`."
                )
            module.register_module(name, fn)
            attached.append((module, name))

    try:
        self.apply(attach_hidden_kernels)

        # An explicit `mode` always wins; the training flag is only a fallback
        # for when the caller did not choose one.
        if mode is None:
            mode = Mode.TRAINING if self.training else Mode.INFERENCE

        if self.kernel_config is not None:
            from kernels import use_kernel_mapping

            # With `use_local_kernel` set, the configured mapping does not inherit the
            # ambient one. NOTE(review): presumably so only the configured (local)
            # kernels are used; confirm against `kernels.use_kernel_mapping` semantics.
            mapping_context = use_kernel_mapping(
                self.kernel_config.kernel_mapping,
                inherit_mapping=not self.kernel_config.use_local_kernel,
            )
        else:
            mapping_context = nullcontext()

        with mapping_context:
            kernelize(self, device=Device(type=self.device.type), mode=mode)
        self._use_kernels = True
    finally:
        # Undo only what this call registered, in reverse order of attachment.
        for module, name in reversed(attached):
            delattr(module, name)
| 4690 | |
| 4691 | @property |
| 4692 | def use_kernels(self) -> bool: |
no test coverage detected