hub / github.com/huggingface/transformers / load_sharded_checkpoint

Function load_sharded_checkpoint

src/transformers/trainer_utils.py:1057–1130

This is the same as [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict) but for a sharded checkpoint. This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being loaded in the model.
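The one-shard-at-a-time pattern in this description can be sketched roughly as follows. This is a minimal illustration, not the library's code: `load_shards_one_by_one`, `load_shard`, and `apply_shard` are hypothetical names, and plain dicts stand in for tensors so the example runs with the standard library alone.

```python
import gc
from typing import Callable, Iterable, Mapping


def load_shards_one_by_one(
    shard_paths: Iterable[str],
    load_shard: Callable[[str], Mapping[str, object]],
    apply_shard: Callable[[Mapping[str, object]], None],
) -> int:
    """Load each shard, hand it to the model, then release it before the next.

    `load_shard` and `apply_shard` are hypothetical callables standing in for
    a real tensor loader and for a non-strict `model.load_state_dict` call.
    """
    loaded = 0
    for path in shard_paths:
        state_dict = load_shard(path)  # only this shard is referenced here
        apply_shard(state_dict)        # copy its entries into the model
        loaded += 1
        del state_dict                 # drop the local reference ...
        gc.collect()                   # ... and collect before the next shard
    return loaded


# Toy usage: dicts of lists in place of real shard files and tensors.
fake_shards = {
    "shard-1.bin": {"encoder.weight": [1.0, 2.0]},
    "shard-2.bin": {"decoder.weight": [3.0]},
}
model_weights = {}
count = load_shards_one_by_one(
    fake_shards, fake_shards.__getitem__, model_weights.update
)
```

Passing the loader in as a callable keeps the sketch independent of any particular file format. With real tensors, peak memory is bounded by the model plus one shard, rather than the model plus the whole checkpoint.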

(model, folder, strict=True, prefer_safe=True)
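How `prefer_safe` interacts with the index files found in `folder` follows from the `load_safe` expression in the source. The sketch below isolates that decision so each case can be checked. `choose_index` and its string return values are hypothetical names used only for illustration.

```python
def choose_index(
    index_present: bool, safe_index_present: bool, prefer_safe: bool = True
) -> str:
    """Return which index would be read: "safetensors" or "pytorch".

    Mirrors the boolean expression used in the source; the function name and
    return values are illustrative, not part of the library.
    """
    if not index_present and not safe_index_present:
        raise ValueError("Can't find a checkpoint index.")
    load_safe = safe_index_present and (prefer_safe or not index_present)
    return "safetensors" if load_safe else "pytorch"


both_prefer_safe = choose_index(True, True, prefer_safe=True)
both_prefer_torch = choose_index(True, True, prefer_safe=False)
only_safe = choose_index(False, True, prefer_safe=False)
only_torch = choose_index(True, False, prefer_safe=True)
```

Note that `prefer_safe=False` does not forbid safetensors: when only the safetensors index exists, it is still used.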


def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
    """
    This is the same as
    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
    but for a sharded checkpoint.

    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
    loaded in the model.

    Args:
        model (`torch.nn.Module`): The model in which to load the checkpoint.
        folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
        strict (`bool`, *optional*, defaults to `True`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
        prefer_safe (`bool`, *optional*, defaults to `True`):
            If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
            safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.

    Returns:
        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
            - `missing_keys` is a list of str containing the missing keys
            - `unexpected_keys` is a list of str containing the unexpected keys
    """
    # Load the index
    index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
    safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)

    index_present = os.path.isfile(index_file)
    safe_index_present = os.path.isfile(safe_index_file)

    if not index_present and not safe_index_present:
        filenames = (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME)
        raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")

    load_safe = safe_index_present and (prefer_safe or not index_present)
    load_index = safe_index_file if load_safe else index_file

    with open(load_index, "r", encoding="utf-8") as f:
        index = json.load(f)

    shard_files = list(set(index["weight_map"].values()))

    # If strict=True, error before loading any of the state dicts.
    # TODO: Here, update the weight map with the config.dynamic_weight_conversion
    loaded_keys = index["weight_map"].keys()
    model_keys = model.state_dict().keys()
    missing_keys = [key for key in model_keys if key not in loaded_keys]
    unexpected_keys = [key for key in loaded_keys if key not in model_keys]
    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
        if len(missing_keys) > 0:
            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
            error_message += f"\nMissing key(s): {str_missing_keys}."
        if len(unexpected_keys) > 0:
            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
        raise RuntimeError(error_message)
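The key check in the listing works purely from the index's `weight_map`, so it runs before any shard is read. Below is a small sketch of that comparison on a made-up index. The parameter names, the shard file names, and the `compare_keys` helper are all illustrative assumptions, not taken from a real checkpoint.

```python
import json

# Hypothetical index content: "weight_map" maps each parameter name to the
# shard file that stores it (names here are invented for illustration).
index = json.loads(
    """
    {
      "weight_map": {
        "embed.weight": "model-00001-of-00002.safetensors",
        "layer.0.weight": "model-00001-of-00002.safetensors",
        "layer.1.weight": "model-00002-of-00002.safetensors",
        "extra.bias": "model-00002-of-00002.safetensors"
      }
    }
    """
)


def compare_keys(model_keys, weight_map):
    """Return (missing, unexpected, shard_files) using only the index."""
    missing = [k for k in model_keys if k not in weight_map]
    unexpected = [k for k in weight_map if k not in model_keys]
    # Several parameters share a shard, so deduplicate the file names.
    # Sorting is added here only to make the result deterministic.
    shard_files = sorted(set(weight_map.values()))
    return missing, unexpected, shard_files


# Keys the (hypothetical) model expects.
model_keys = ["embed.weight", "layer.0.weight", "layer.1.weight", "lm_head.weight"]
missing, unexpected, shard_files = compare_keys(model_keys, index["weight_map"])
```

With `strict=True`, a non-empty `missing` or `unexpected` list at this point would lead to the `RuntimeError` built in the listing, before any shard has been loaded.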

Callers 2

_load_from_checkpoint · Method · 0.85
_load_best_model · Method · 0.85

Calls 7

check_torch_load_is_safe · Function · 0.85
join · Method · 0.80
collect · Method · 0.80
values · Method · 0.45
keys · Method · 0.45
state_dict · Method · 0.45
load_state_dict · Method · 0.45

Tested by

no test coverage detected