diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index 175860c7a27e..8f20b390934f 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -270,6 +270,12 @@ def gather_partitioned_activations(tensors, device=None): inputs.append(item) continue + # don't need to do all_gather if model parallel size is 1 + if mp_size == 1: + item = item.view(list(size.numpy())) + inputs.append(item) + continue + partition_size = item.numel() tensor_size = partition_size * mp_size if device is not None: