Skip to content

Commit 535a908

Browse files
authored
Set tp world size to 1 in ckpt load, if MPU is not provided (#5243)
If an MPU (model-parallel unit) is not provided, set the tensor-parallel (tp) world size to 1 when loading the (universal) checkpoint.
1 parent 74910a9 commit 535a908

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

deepspeed/runtime/zero/stage_1_and_2.py

+5-1
Original file line number | Diff line number | Diff line change
@@ -2303,7 +2303,11 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
23032303
self._load_global_state(optim_sd)
23042304

23052305
tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
2306-
tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
2306+
if self.mpu is None:
2307+
logger.warn("MPU is not provided, setting tp size to 1 in checkpoint loading.")
2308+
tp_world_size = 1
2309+
else:
2310+
tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
23072311
else self.mpu.get_tensor_model_parallel_world_size()
23082312

23092313
for i, _ in enumerate(self.optimizer.param_groups):

0 commit comments

Comments
 (0)