diff --git a/deepspeed/pt/zero_optimizer_stage1.py b/deepspeed/pt/zero_optimizer_stage1.py index d527a17123d7..82812bccdb3a 100755 --- a/deepspeed/pt/zero_optimizer_stage1.py +++ b/deepspeed/pt/zero_optimizer_stage1.py @@ -40,8 +40,7 @@ def flatten_dense_tensors_sub_partition_aligned(tensor_list, dp, max_elements_per_comm, pg): - assert (max_elements_per_comm >= dp, - f"max_elements_per_comm {max_elements_per_comm} < dp {dp}") + assert max_elements_per_comm >= dp, f"max_elements_per_comm {max_elements_per_comm} < dp {dp}" num_elements = sum(t.numel() for t in tensor_list) log_dist("Total number of elements in model: {}, max elements per com: {}".format(