200 changes: 0 additions & 200 deletions deepspeed/pt/deepspeed_cpu_adam.py

This file was deleted.

19 changes: 13 additions & 6 deletions deepspeed/pt/deepspeed_light.py
@@ -21,7 +21,6 @@
 from deepspeed.pt.fp16_optimizer import FP16_Optimizer
 from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.pt.deepspeed_fused_lamb import FusedLamb
-from deepspeed.pt.deepspeed_cpu_adam import CPUAdam
 from deepspeed.pt.deepspeed_config import DeepSpeedConfig, \
     ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS

@@ -170,6 +169,8 @@ def __init__(self,
         self.optimizer = None
         self.lr_scheduler = None
         if model_parameters or optimizer:
+            if "torch.optim" in self.optimizer_name():
+                self.zero_set_cpu_offload()
             self._configure_optimizer(optimizer, model_parameters)
             self._configure_lr_scheduler(lr_scheduler)
             self._report_progress(0)
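For context, a minimal standalone sketch (not part of the PR) of what the two added lines above do: when the configured optimizer name is a dotted torch.optim path, ZeRO CPU offload is switched on before _configure_optimizer runs. The config values below, including the "torch.optim.Adam" type string, are assumptions for illustration only, based on the `"torch.optim" in self.optimizer_name()` check.

# Assumed DeepSpeed-style config snippet, for illustration only.
config = {
    "optimizer": {"type": "torch.optim.Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": 2, "cpu_offload": False},
}

optimizer_name = config["optimizer"]["type"]
cpu_offload = config["zero_optimization"]["cpu_offload"]

# Mirrors the added guard: a torch.optim optimizer name forces CPU offload on.
if "torch.optim" in optimizer_name:
    cpu_offload = True

print(cpu_offload)  # True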
@@ -263,10 +264,13 @@ def sparse_gradients_enabled(self):
     def train_batch_size(self):
         return self._config.train_batch_size
 
+
     def train_micro_batch_size_per_gpu(self):
         return self._config.train_micro_batch_size_per_gpu
 
     def optimizer_name(self):
+        print("===========================")
+        print(self._config.optimizer_name)
         return self._config.optimizer_name
 
     def optimizer_params(self):
@@ -296,6 +300,9 @@ def zero_overlap_comm(self):
     def zero_cpu_offload(self):
         return self._config.zero_config.cpu_offload
 
+    def zero_set_cpu_offload(self):
+        self._config.zero_config.cpu_offload = True
+
     def zero_optimization_stage(self):
         return self._config.zero_optimization_stage

@@ -498,9 +505,9 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
         #jie:
         if self.zero_cpu_offload():
             optimizer_parameters = self.optimizer_params()
-            basic_optimizer = CPUAdam(client_optimizer.param_groups,
+            basic_optimizer = torch.optim.Adam(client_optimizer.param_groups,
                                       **optimizer_parameters)
-            logger.info('Using CPU Optimizer as basic optimizer')
+            logger.info('Using CPU Optimizer as basic optimizer')
         elif client_optimizer is not None:
             basic_optimizer = client_optimizer
             logger.info('Using client Optimizer as basic optimizer')
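For reference, a minimal standalone sketch (not part of the PR) of the substitution in the hunk above; the model and optimizer_parameters values are made up for illustration. Reusing the client optimizer's param_groups lets per-group settings (for example each group's lr) carry over into the replacement torch.optim.Adam.

import torch

# Stand-in model and client optimizer, for illustration only.
model = torch.nn.Linear(4, 2)
client_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Assumed shape of what self.optimizer_params() would return for Adam.
optimizer_parameters = {"lr": 1e-3, "betas": (0.9, 0.999), "eps": 1e-8}

# Same construction as the '+' line above: build a plain torch.optim.Adam
# from the client optimizer's param_groups, so per-group options are kept.
basic_optimizer = torch.optim.Adam(client_optimizer.param_groups,
                                   **optimizer_parameters)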
@@ -542,9 +549,6 @@ def _configure_basic_optimizer(self, model_parameters):
"'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
)
if self.optimizer_name() == ADAM_OPTIMIZER:
if self.zero_cpu_offload():
optimizer = CPUAdam(model_parameters, **optimizer_parameters)
else:
from apex.optimizers.fused_adam import FusedAdam
optimizer = FusedAdam(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == LAMB_OPTIMIZER:
@@ -958,6 +962,9 @@ def _get_optimizer_param(self, param_name):
     def get_lr(self):
         return self._get_optimizer_param('lr')
 
+    def get_type(self):
+        return self._get_optimizer_param('type')
+
     def get_mom(self):
         return self._get_optimizer_param('betas')
