diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 580b72ee119f..021e6317bc51 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -158,7 +158,8 @@ def __init__(self, max_live_parameters=max_live_parameters, param_persistence_threshold=param_persistence_threshold, model_persistence_threshold=model_persistence_threshold, - offload_param_config=offload_optimizer_config) + offload_param_config=offload_optimizer_config, + mpu=mpu) self.persistent_parameters = self.parameter_offload.persistent_parameters self._configure_offloading(offload_optimizer_config, offload_param_config)