From c101885e70cdc6124911d67e10ae913361cf3140 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 28 Jun 2022 04:28:25 +0500 Subject: [PATCH 1/3] fix partition id for cpu offload --- deepspeed/runtime/zero/stage_1_and_2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index c6a2e26aac8f..44bbaa5b9a97 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1747,12 +1747,13 @@ def step(self, closure=None): from deepspeed.ops.adam import DeepSpeedCPUAdam if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half: bit16_param_groups = [[ - bit16_partitions[partition_id] - ] for bit16_partitions in self.parallel_partitioned_bit16_groups] + bit16_partitions[dist.get_rank(group=self.real_dp_process_group[group_id])] + ] for group_id, bit16_partitions in enumerate(self.parallel_partitioned_bit16_groups)] self.optimizer.step(fp16_param_groups=bit16_param_groups) else: self.optimizer.step() for bit16_partitions, fp32_partition in zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups): + partition_id = dist.get_rank(group=self.real_dp_process_group[group_id]) bit16_partitions[partition_id].data.copy_(fp32_partition.data) else: self.optimizer.step() From feb7270c31a2218e613a37c119a59c63300640cf Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 28 Jun 2022 04:38:51 +0500 Subject: [PATCH 2/3] formatting changes --- deepspeed/runtime/zero/stage_1_and_2.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 44bbaa5b9a97..b1f9dd2f1c7d 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1746,14 +1746,19 @@ def step(self, closure=None): if self.deepspeed_adam_offload: from deepspeed.ops.adam import DeepSpeedCPUAdam if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half: - bit16_param_groups = [[ - bit16_partitions[dist.get_rank(group=self.real_dp_process_group[group_id])] - ] for group_id, bit16_partitions in enumerate(self.parallel_partitioned_bit16_groups)] + bit16_param_groups = [ + [ + bit16_partitions[dist.get_rank( + group=self.real_dp_process_group[group_id])] + ] for group_id, + bit16_partitions in enumerate(self.parallel_partitioned_bit16_groups) + ] self.optimizer.step(fp16_param_groups=bit16_param_groups) else: self.optimizer.step() for bit16_partitions, fp32_partition in zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups): - partition_id = dist.get_rank(group=self.real_dp_process_group[group_id]) + partition_id = dist.get_rank( + group=self.real_dp_process_group[group_id]) bit16_partitions[partition_id].data.copy_(fp32_partition.data) else: self.optimizer.step() From a037f10083155b3c1c32acf7e634d7e2786495a1 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 7 Jul 2022 00:45:47 +0500 Subject: [PATCH 3/3] corrected bug in torch cpu adam code path --- deepspeed/runtime/zero/stage_1_and_2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index b1f9dd2f1c7d..7d5cc8865225 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1756,7 +1756,7 @@ def step(self, closure=None): self.optimizer.step(fp16_param_groups=bit16_param_groups) else: self.optimizer.step() - for bit16_partitions, fp32_partition in zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups): + for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)): partition_id = dist.get_rank( group=self.real_dp_process_group[group_id]) bit16_partitions[partition_id].data.copy_(fp32_partition.data)