Get the names of all parameters that weight decay will be applied to. This function filters out parameters in two ways: 1. By layer type (instances of `nn.LayerNorm`, the only layer type the code passes to `get_parameter_names`) 2. By parameter name patterns (names containing 'bias', or a variation of 'norm')
(self, model: nn.Module)
| 1287 | return handler(ctx) |
| 1288 | |
def get_decay_parameter_names(self, model: nn.Module) -> list[str]:
    """
    Return the names of the parameters of `model` that weight decay should be applied to.

    The selection is delegated to `get_parameter_names`, which is handed two exclusion criteria:

    1. Layer type: `nn.LayerNorm` is the only layer type passed as forbidden.
    2. Parameter name: regex patterns covering `bias`, `layernorm`, `rmsnorm`, a dot-delimited `norm`
       component, and a component ending in `_norm`.

    Args:
        model (`nn.Module`): The model whose parameters are inspected.

    Returns:
        `list[str]`: The parameter names returned by `get_parameter_names`.
    """
    # Name patterns marking a parameter as exempt from weight decay.
    # NOTE(review): how the patterns are applied (search vs. full match, case sensitivity) is decided
    # inside `get_parameter_names` -- confirm there before adding new patterns.
    no_decay_name_patterns = [
        r"bias",
        r"layernorm",
        r"rmsnorm",
        r"(?:^|\.)norm(?:$|\.)",
        r"_norm(?:$|\.)",
    ]
    no_decay_layer_types = [nn.LayerNorm]
    return get_parameter_names(model, no_decay_layer_types, no_decay_name_patterns)
| 1300 | |
| 1301 | def _get_learning_rate(self) -> float: |
| 1302 | """ |
no test coverage detected