diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index dd995b9bfe70..b57770f33b29 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -1043,7 +1043,7 @@ def get_model_param_count(model, trainable_only=False): if is_deepspeed_zero3_enabled(): def numel(p): - return p.ds_numel + return p.ds_numel if hasattr(p, "ds_numel") else p.numel() else: