This is the same as [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict), but for a sharded checkpoint. This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being loaded in the model.
(model, folder, strict=True, prefer_safe=True)
| 1055 | |
| 1056 | |
| 1057 | def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True): |
| 1058 | """ |
| 1059 | This is the same as |
| 1060 | [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict) |
| 1061 | but for a sharded checkpoint. |
| 1062 | |
| 1063 | This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being |
| 1064 | loaded in the model. |
| 1065 | |
| 1066 | Args: |
| 1067 | model (`torch.nn.Module`): The model in which to load the checkpoint. |
| 1068 | folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint. |
| 1069 | strict (`bool`, *optional*, defaults to `True`): |
| 1070 | Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint. |
| 1071 | prefer_safe (`bool`, *optional*, defaults to `True`): |
| 1072 | If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the |
| 1073 | safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible. |
| 1074 | |
| 1075 | Returns: |
| 1076 | `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields |
| 1077 | - `missing_keys` is a list of str containing the missing keys |
| 1078 | - `unexpected_keys` is a list of str containing the unexpected keys |
| 1079 | """ |
| 1080 | # Load the index |
| 1081 | index_file = os.path.join(folder, WEIGHTS_INDEX_NAME) |
| 1082 | safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME) |
| 1083 | |
| 1084 | index_present = os.path.isfile(index_file) |
| 1085 | safe_index_present = os.path.isfile(safe_index_file) |
| 1086 | |
| 1087 | if not index_present and not safe_index_present: |
| 1088 | filenames = (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) |
| 1089 | raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.") |
| 1090 | |
| 1091 | load_safe = safe_index_present and (prefer_safe or not index_present) |
| 1092 | load_index = safe_index_file if load_safe else index_file |
| 1093 | |
| 1094 | with open(load_index, "r", encoding="utf-8") as f: |
| 1095 | index = json.load(f) |
| 1096 | |
| 1097 | shard_files = list(set(index["weight_map"].values())) |
| 1098 | |
| 1099 | # If strict=True, error before loading any of the state dicts. |
| 1100 | # TODO: Here, update the weight map with the config.dynamic_weight_conversion |
| 1101 | loaded_keys = index["weight_map"].keys() |
| 1102 | model_keys = model.state_dict().keys() |
| 1103 | missing_keys = [key for key in model_keys if key not in loaded_keys] |
| 1104 | unexpected_keys = [key for key in loaded_keys if key not in model_keys] |
| 1105 | if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0): |
| 1106 | error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}" |
| 1107 | if len(missing_keys) > 0: |
| 1108 | str_missing_keys = ",".join([f'"{k}"' for k in missing_keys]) |
| 1109 | error_message += f"\nMissing key(s): {str_missing_keys}." |
| 1110 | if len(unexpected_keys) > 0: |
| 1111 | str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys]) |
| 1112 | error_message += f"\nUnexpected key(s): {str_unexpected_keys}." |
| 1113 | raise RuntimeError(error_message) |
| 1114 |
no test coverage detected