diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index efd8d0fc3a4e..9549822d92fc 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -97,6 +97,7 @@
     distributed_broadcast_scalars,
     distributed_concat,
     find_batch_size,
+    get_model_param_count,
     get_module_class_from_name,
     get_parameter_names,
     nested_concat,
@@ -1744,9 +1745,7 @@ def _inner_training_loop(
         logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
         logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
         logger.info(f" Total optimization steps = {max_steps}")
-        logger.info(
-            f" Number of trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
-        )
+        logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True)}")
 
         self.state.epoch = 0
         start_time = time.time()
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index eefbb5268345..dee1dce0f6f7 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -35,6 +35,7 @@
 from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
 from torch.utils.data.distributed import DistributedSampler
 
+from .deepspeed import is_deepspeed_zero3_enabled
 from .tokenization_utils_base import BatchEncoding
 from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging
 
@@ -1032,6 +1033,33 @@ def save_state(self):
         self.state.save_to_json(path)
 
 
+def get_model_param_count(model, trainable_only=False):
+    """
+    Calculate model's total param count. If trainable_only is True then count only those requiring grads.
+
+    Args:
+        model: Object whose `parameters()` method yields the parameters to count.
+        trainable_only (`bool`, *optional*, defaults to `False`):
+            Whether to count only the parameters whose `requires_grad` is truthy.
+
+    Returns:
+        The sum of the element counts of the selected parameters.
+    """
+    if is_deepspeed_zero3_enabled():
+
+        def numel(p):
+            # Read the size from `ds_numel` when DeepSpeed has attached it; parameters without that
+            # attribute fall back to `numel()` instead of raising AttributeError.
+            return p.ds_numel if hasattr(p, "ds_numel") else p.numel()
+
+    else:
+
+        def numel(p):
+            return p.numel()
+
+    return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)
+
+
 def get_parameter_names(model, forbidden_layer_types):
     """
     Returns the names of the model parameters that are not inside a forbidden layer.