16 changes: 11 additions & 5 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -1750,14 +1750,20 @@ def step(self, closure=None):
         if self.deepspeed_adam_offload:
             from deepspeed.ops.adam import DeepSpeedCPUAdam
             if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
-                bit16_param_groups = [[
-                    bit16_partitions[partition_id]
-                ] for bit16_partitions in self.parallel_partitioned_bit16_groups]
+                bit16_param_groups = [
+                    [
+                        bit16_partitions[dist.get_rank(
+                            group=self.real_dp_process_group[group_id])]
+                    ] for group_id,
+                    bit16_partitions in enumerate(self.parallel_partitioned_bit16_groups)
+                ]
                 self.optimizer.step(fp16_param_groups=bit16_param_groups)
             else:
                 self.optimizer.step()
                 # after step(), single_partition_of_fp32_groups has the local optimizer's own partition of updated params
-                for bit16_partitions, fp32_partition in zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups):
+                for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
+                    partition_id = dist.get_rank(
+                        group=self.real_dp_process_group[group_id])
+
                     bit16_partitions[partition_id].data.copy_(fp32_partition.data)
         else:
             self.optimizer.step()
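Note on the change: the old code indexed every group's partitions with a single `partition_id` computed outside these loops, which is only correct when all param groups are sharded over the same data-parallel process group. The fix recomputes the local rank per group via `dist.get_rank(group=self.real_dp_process_group[group_id])`, inlined in the list comprehension and hoisted into `partition_id` in the copy loop. Below is a minimal sketch of the corrected lookup; `parallel_partitioned_bit16_groups` and `real_dp_process_group` stand in for the optimizer's fields, and the helper name is hypothetical, not DeepSpeed API.

    # Toy sketch, not DeepSpeed's actual code: pick this rank's bit16
    # partition out of every param group.
    import torch.distributed as dist

    def gather_local_bit16_partitions(parallel_partitioned_bit16_groups,
                                      real_dp_process_group):
        """Return [[partition]] per group, mirroring bit16_param_groups above.

        A single partition_id computed once, outside the loop, is wrong
        whenever param groups were partitioned over different process
        groups: the same global rank can map to different local ranks.
        """
        bit16_param_groups = []
        for group_id, bit16_partitions in enumerate(parallel_partitioned_bit16_groups):
            # Recompute the local rank per group: the fix in the diff above.
            partition_id = dist.get_rank(group=real_dp_process_group[group_id])
            bit16_param_groups.append([bit16_partitions[partition_id]])
        return bit16_param_groups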