diff --git a/deepspeed/pt/deepspeed_checkpointing.py b/deepspeed/pt/deepspeed_checkpointing.py index 2a5bb2ab688b..746adae1c599 100755 --- a/deepspeed/pt/deepspeed_checkpointing.py +++ b/deepspeed/pt/deepspeed_checkpointing.py @@ -602,11 +602,11 @@ def reset(): size_offsets = [] -def _configure_using_config_file(deepspeed_config): +def _configure_using_config_file(deepspeed_config, mpu=None): global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \ PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME - config = DeepSpeedConfig(deepspeed_config).activation_checkpointing_config + config = DeepSpeedConfig(deepspeed_config, mpu=mpu).activation_checkpointing_config logger.info(config.repr()) PARTITION_ACTIVATIONS = config.partition_activations CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization @@ -684,12 +684,12 @@ def configure( _configure_defaults() - if deepspeed_config is not None: - _configure_using_config_file(deepspeed_config) - if mpu_ is not None: mpu = mpu_ + if deepspeed_config is not None: + _configure_using_config_file(deepspeed_config, mpu=mpu) + if partition_activations is not None: PARTITION_ACTIVATIONS = partition_activations