diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 20f92956a3a5..211e9dac7380 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -1750,14 +1750,20 @@ def step(self, closure=None):
         if self.deepspeed_adam_offload:
             from deepspeed.ops.adam import DeepSpeedCPUAdam
             if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
-                bit16_param_groups = [[
-                    bit16_partitions[partition_id]
-                ] for bit16_partitions in self.parallel_partitioned_bit16_groups]
+                bit16_param_groups = [
+                    [
+                        bit16_partitions[dist.get_rank(
+                            group=self.real_dp_process_group[group_id])]
+                    ] for group_id,
+                    bit16_partitions in enumerate(self.parallel_partitioned_bit16_groups)
+                ]
                 self.optimizer.step(fp16_param_groups=bit16_param_groups)
             else:
                 self.optimizer.step()
-                # after step(), single_partition_of_fp32_groups has the local optimizer's own partition of updated params
-                for bit16_partitions, fp32_partition in zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups):
+                for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
+                    partition_id = dist.get_rank(
+                        group=self.real_dp_process_group[group_id])
+
                     bit16_partitions[partition_id].data.copy_(fp32_partition.data)
         else:
             self.optimizer.step()
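
For context, here is a minimal standalone sketch (not the DeepSpeed implementation itself) of the pattern the hunk moves to: the partition index is looked up per parameter group from that group's own data-parallel process group via `dist.get_rank(group=...)`, instead of reusing one cached `partition_id` for every group. The helper name `copy_fp32_partitions_to_bit16`, the single-process gloo setup, and the toy tensors are illustrative assumptions; the field names mirror the attributes touched in the diff.

```python
# Sketch only: mimics the per-group rank lookup adopted in the diff above.
import os

import torch
import torch.distributed as dist


def copy_fp32_partitions_to_bit16(parallel_partitioned_bit16_groups,
                                  single_partition_of_fp32_groups,
                                  real_dp_process_group):
    # For each parameter group, find this rank's partition inside that group's
    # own process group, then overwrite the bit16 shard with the freshly
    # stepped fp32 master copy. Groups may use different process groups
    # (e.g. expert-parallel groups), so the rank must be queried per group.
    for group_id, (bit16_partitions, fp32_partition) in enumerate(
            zip(parallel_partitioned_bit16_groups, single_partition_of_fp32_groups)):
        partition_id = dist.get_rank(group=real_dp_process_group[group_id])
        bit16_partitions[partition_id].data.copy_(fp32_partition.data)


if __name__ == "__main__":
    # Single-process gloo setup so the sketch runs standalone (hypothetical setup).
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    dp_group = dist.new_group(ranks=[0])  # stands in for one group's DP process group
    bit16_groups = [[torch.zeros(4, dtype=torch.half)]]   # one group, one partition
    fp32_groups = [torch.ones(4, dtype=torch.float32)]    # updated fp32 master shard

    copy_fp32_partitions_to_bit16(bit16_groups, fp32_groups, [dp_group])
    print(bit16_groups[0][0])  # now holds ones, cast to fp16

    dist.destroy_process_group()
```

The key design point the diff captures is that `partition_id` is a property of a specific process group, so it has to be recomputed with `group=self.real_dp_process_group[group_id]` for each parameter group rather than assumed constant across groups.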