Calculate the model's total parameter count. If `trainable_only` is True, count only the parameters that require gradients.
(model, trainable_only=False)
| 1007 | |
| 1008 | |
def get_model_param_count(model, trainable_only=False):
    """
    Return the total number of parameter elements in ``model``.

    Args:
        model: Object whose ``parameters()`` method yields the parameters to count.
        trainable_only: If truthy, skip every parameter whose ``requires_grad`` is falsy.

    Returns:
        The summed element count over the selected parameters (``0`` if none are selected).
    """
    # Queried once, up front, before any parameter is visited.
    zero3 = is_deepspeed_zero3_enabled()

    def _count(param):
        # Under ZeRO-3 a parameter may carry ``ds_numel``; when present it is used
        # instead of ``numel()`` (presumably because a partitioned parameter's local
        # ``numel()`` does not reflect its full size — confirm against DeepSpeed docs).
        if zero3 and hasattr(param, "ds_numel"):
            return param.ds_numel
        return param.numel()

    total = 0
    for param in model.parameters():
        # ``requires_grad`` is only consulted when filtering was requested.
        if trainable_only and not param.requires_grad:
            continue
        total += _count(param)
    return total
| 1024 | |
| 1025 | |
| 1026 | def get_parameter_names(model, forbidden_layer_types, forbidden_layer_names=None): |
no test coverage detected