diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 13d58ff11438..e245b6f8f1e6 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -20,6 +20,7 @@ import deepspeed from deepspeed.runtime.utils import see_memory_usage, DummyOptim +from .zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException @@ -646,7 +647,9 @@ def zero_offload_param(self): return self._config.zero_config.offload_param def zero_cpu_offload(self): - return self._config.zero_config.offload_optimizer is not None + if self._config.zero_config.offload_optimizer is not None: + return self._config.zero_config.offload_optimizer.device == OffloadDeviceEnum.cpu + return False def zero_sub_group_size(self): return self._config.zero_config.sub_group_size