
Commit 8a63754

zarzen (Zhen Zhang) and Zhen Zhang authored
save_non_zero_checkpoint on first partition group (#3787)
Co-authored-by: Zhen Zhang <[email protected]>
Parent: 82c498d

2 files changed: 23 additions and 1 deletion

deepspeed/runtime/engine.py (+9 −1)

@@ -719,6 +719,13 @@ def zero_optimization_partition_gradients(self):
     def zero_optimization_partition_weights(self):
         return self.zero_optimization_stage() >= ZeroStageEnum.weights
 
+    def is_first_weights_partition_group(self):
+        ret = True if self.mics_shard_size() < 0 \
+            and self.zero_optimization_partition_weights() else False
+        if self.mics_shard_size() > 0 and self.global_rank < self.mics_shard_size():
+            ret = True
+        return ret
+
     def zero_contiguous_gradients(self):
         return self._config.zero_config.contiguous_gradients

@@ -898,7 +905,8 @@ def _configure_checkpointing(self, dist_init_required):
         # only the first data parallel process needs to store the model checkpoint
         # if you want to use node local storage this must be done by rank 0 on each
         # node
-        self.save_non_zero_checkpoint = (rank == 0) or self.zero_optimization_partition_weights()
+        self.save_non_zero_checkpoint = (rank == 0) or (self.zero_optimization_partition_weights()
+                                                        and self.is_first_weights_partition_group())
 
         if self.zero_optimization() or self.bfloat16_enabled():
             param_rank = dist.get_rank(group=self.optimizer.dp_process_group)
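
Before this change, save_non_zero_checkpoint was set on every rank whenever ZeRO-3 weight partitioning was active, so under MiCS each replica group wrote identical model files. The new gate keeps the old behavior when MiCS is off (mics_shard_size() is negative) and otherwise lets only the first partition group, i.e. ranks below the shard size, save. A minimal standalone sketch of that gating, not the DeepSpeed API (the function name and flat-rank signature here are mine):

    # Sketch of the rank gating above; assumes mics_shard_size < 0 means
    # MiCS is disabled, as the diff implies.
    def should_save_non_zero_checkpoint(rank, mics_shard_size, partitions_weights):
        # MiCS off: every ZeRO-3 (weight-partitioning) rank saves, as before.
        first_group = mics_shard_size < 0 and partitions_weights
        # MiCS on: only ranks in the first shard group save.
        if mics_shard_size > 0 and rank < mics_shard_size:
            first_group = True
        return rank == 0 or (partitions_weights and first_group)

    # With 8 ranks and a MiCS shard size of 4, ranks 0-3 save and 4-7 skip.
    assert [should_save_non_zero_checkpoint(r, 4, True) for r in range(8)] == \
           [True] * 4 + [False] * 4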

tests/unit/checkpoint/test_mics_optimizer.py (+14 −0)

@@ -64,3 +64,17 @@ def test_not_load_optimizer_state(self, tmpdir, shard_size):
     def test_load_module_only(self, tmpdir, shard_size):
         config_dict, hidden_dim, models = self._toy_model_config(shard_size)
         checkpoint_correctness_verification(config_dict, models, hidden_dim, tmpdir, load_module_only=True)
+
+    @pytest.mark.parametrize('shard_size', [1, 2, 4])
+    def test_save_checkpoint_on_first_partition_group(self, tmpdir, shard_size):
+        config_dict, _, models = self._toy_model_config(shard_size)
+        ds_engine, _, _, _ = deepspeed.initialize(config=config_dict,
+                                                  model=models[0],
+                                                  model_parameters=models[0].parameters(),
+                                                  optimizer=None)
+
+        ds_engine.save_checkpoint(tmpdir)
+        if ds_engine.global_rank < shard_size:
+            assert ds_engine.save_non_zero_checkpoint == True
+        else:
+            assert ds_engine.save_non_zero_checkpoint == False
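
For intuition, the parametrized assertion reduces to a per-rank expectation like the following hypothetical sketch (world size and values invented for illustration):

    # With 4 ranks and shard_size=2, only the first MiCS partition group
    # (ranks 0 and 1) should set save_non_zero_checkpoint.
    world_size, shard_size = 4, 2
    expected = {rank: rank < shard_size for rank in range(world_size)}
    assert expected == {0: True, 1: True, 2: False, 3: False}

Presumably this avoids redundant writes: a single MiCS partition group already holds a complete copy of the weights, so one group's checkpoint is sufficient.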
