fix sequence parallel (Ulysses) grad scale for zero0 (#5555)
Use the data-parallel world size (dp_world_size) for gradient reduction instead of the sequence-data-parallel world size (seq_dp_world_size).
Currently, for zero0, only sparse tensors use the correct world size; dense gradients are averaged over the full sequence-data-parallel group, so they end up scaled down by an extra factor of sequence_parallel_size.

Grad-norm test on a tiny model with sp=4:
| grad_norm          | step1  | step2  | step3  | step4  | step5  | step100 |
| --                 | --     | --     | --     | --     | --     | --      |
| zero1              | 15.825 | 16.646 | 15.853 | 16.159 | 17.333 | 15.555  |
| zero0              | 3.956  | 4.161  | 3.963  | 4.040  | 4.333  | 3.889   |
| zero0 (this patch) | 15.825 | 16.646 | 15.853 | 16.159 | 17.333 | 15.554  |
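
With sp=4, the pre-patch zero0 norms are almost exactly the zero1 norms divided by 4, which is consistent with the dense gradients being divided by an extra factor of sequence_parallel_size. A quick consistency check (plain Python, values copied from the table above):

```python
# Values copied from the grad-norm table above; the test ran with sp = 4.
sp = 4
zero1 = [15.825, 16.646, 15.853, 16.159, 17.333, 15.555]
zero0_before_patch = [3.956, 4.161, 3.963, 4.040, 4.333, 3.889]

for ref, buggy in zip(zero1, zero0_before_patch):
    # each pre-patch zero0 norm is ~ref / sp, i.e. shrunk by the extra sequence-parallel factor
    assert abs(ref / sp - buggy) < 0.01
```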
inkcherry authored Jun 5, 2024
1 parent af4356b commit 6b6d641
1 changed file: deepspeed/runtime/engine.py (9 additions, 5 deletions)
```diff
@@ -2407,18 +2407,22 @@ def _reduce_non_expert_gradients(self, grads, elements_per_buffer):
         split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse(grads)
         if self.pipeline_parallelism:
             dp_group = self.mpu.get_data_parallel_group()
+            dp_world_size = dist.get_world_size(dp_group)
         else:
             dp_group = groups._get_sequence_data_parallel_group()
+            dp_world_size = dist.get_world_size(dp_group) / float(self.sequence_parallel_size)

         for _, sparse_bucket_tuple in enumerate(split_sparse_tensor_buckets):
             if sparse_bucket_tuple:
                 bucket_type, sparse_bucket = sparse_bucket_tuple
-                self.sparse_allreduce_no_retain(sparse_bucket, dp_group=dp_group)
+                self.sparse_allreduce_no_retain(sparse_bucket, dp_group=dp_group, dp_world_size=dp_world_size)

         for _, dense_bucket_tuple in enumerate(split_dense_tensor_buckets):
             if dense_bucket_tuple:
                 bucket_type, dense_bucket = dense_bucket_tuple
-                self.allreduce_no_retain(dense_bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer)
+                self.allreduce_no_retain(dense_bucket,
+                                         dp_group=dp_group,
+                                         numel_per_bucket=elements_per_buffer,
+                                         dp_world_size=dp_world_size)

     def _reduce_expert_gradients(self, expert_grads, elements_per_buffer):
         # to maintain the gradients value unaffected by ep_size setting,
@@ -2490,9 +2494,9 @@ def sparse_allreduce(self, sparse, dp_group, dp_world_size=None):
             dp_world_size = dist.get_world_size(group=dp_group)
         if self.postscale_gradients():
             if self.gradient_average:
-                values.mul_(self.gradient_predivide_factor() / (dp_world_size / float(self.sequence_parallel_size)))
+                values.mul_(self.gradient_predivide_factor() / (dp_world_size))
             else:
-                values.mul_(1. / (dp_world_size / float(self.sequence_parallel_size)))
+                values.mul_(1. / (dp_world_size))

         indices_device_list = self.sparse_all_gather(indices, dp_group)
         values_device_list = self.sparse_all_gather(values, dp_group)
```
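
The usual reasoning behind this scaling is that ranks in the same sequence-parallel group each hold the gradient contribution of a different sequence shard, so those contributions must be summed, while true data-parallel replicas are averaged. A minimal standalone sketch (not DeepSpeed code; the group layout and tensor values are made up for illustration) of why the divisor after an allreduce over the combined sequence-data-parallel group is group_size / sp rather than group_size:

```python
import torch

# Hypothetical layout for illustration: dp = 2 data-parallel replicas, sp = 4 sequence shards.
dp, sp = 2, 4
group_size = dp * sp  # size of the combined sequence-data-parallel group

# Each rank holds the partial gradient of its own sequence shard (value 1.0 here);
# both data-parallel replicas see identical data so the expected result is easy to read off.
local_grads = [torch.full((3,), 1.0) for _ in range(group_size)]

# Simulate allreduce(SUM) over the whole sequence-data-parallel group.
summed = torch.stack(local_grads).sum(dim=0)  # 8.0 everywhere

wrong = summed / group_size         # pre-patch dense path: 1.0, shrunk by an extra factor of sp
right = summed / (group_size / sp)  # this patch: divide only by dp -> 4.0, the full per-replica gradient
print(wrong[0].item(), right[0].item())  # 1.0 4.0
```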
