2 changes: 1 addition & 1 deletion deepspeed/__init__.py
@@ -7,7 +7,7 @@
 from . import ops
 
 from .runtime.engine import DeepSpeedEngine
-from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
+from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_ADAM
 from .runtime.lr_schedules import add_tuning_arguments
 from .runtime.config import DeepSpeedConfig
 from .runtime.activation_checkpointing import checkpointing
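With this re-export, the new constant is reachable from the top-level package. A trivial check, assuming DeepSpeed is installed with this change applied:

import deepspeed

# DEEPSPEED_ADAM is defined as 'deepspeed_adam' in deepspeed/runtime/config.py
assert deepspeed.DEEPSPEED_ADAM == 'deepspeed_adam'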
21 changes: 13 additions & 8 deletions deepspeed/ops/adam/cpu_adam.py
@@ -24,6 +24,7 @@ def __init__(self,
                             weight_decay=weight_decay,
                             amsgrad=amsgrad)
         super(DeepSpeedCPUAdam, self).__init__(model_params, default_args)
+
         self.opt_id = DeepSpeedCPUAdam.optimizer_id
         DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1
 
@@ -43,33 +44,37 @@ def step(self, closure=None, fp16_param_groups=None):
             with torch.enable_grad():
                 loss = closure()
 
-        for i, group in enumerate(self.param_groups):
-            for gid, p in enumerate(group['params']):
+        for group_id, group in enumerate(self.param_groups):
+            for param_id, p in enumerate(group['params']):
 
                 if p.grad is None:
                     continue
 
-                grad = p.grad
+                grad = p.grad.data
                 state = self.state[p]
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
                     # gradient momentums
-                    state['exp_avg'] = torch.zeros_like(p, device='cpu')
+                    state['exp_avg'] = torch.zeros_like(p.data, device='cpu')
                     # gradient variances
-                    state['exp_avg_sq'] = torch.zeros_like(p, device='cpu')
+                    state['exp_avg_sq'] = torch.zeros_like(p.data, device='cpu')
 
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 state['step'] += 1
 
                 if fp16_param_groups is not None:
-                    p_fp16 = fp16_param_groups[i][gid]
+                    p_fp16 = fp16_param_groups[group_id][param_id]
                     ds_opt_adam.adam_update_copy(self.opt_id,
-                                                 p,
+                                                 p.data,
                                                  grad,
                                                  exp_avg,
                                                  exp_avg_sq,
                                                  p_fp16)
                 else:
-                    ds_opt_adam.adam_update(self.opt_id, p, grad, exp_avg, exp_avg_sq)
+                    ds_opt_adam.adam_update(self.opt_id,
+                                            p.data,
+                                            grad,
+                                            exp_avg,
+                                            exp_avg_sq)
         return loss
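For reviewers who want to exercise the op outside the engine, below is a minimal standalone sketch of DeepSpeedCPUAdam. It assumes DeepSpeed was built with the CPU Adam extension (so ds_opt_adam imports) and that the parameters live on CPU; the Linear layer and tensor sizes are illustrative only.

import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

model = torch.nn.Linear(64, 64)                  # fp32 parameters on CPU
optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 64)).sum()
loss.backward()
optimizer.step()        # fp16_param_groups=None, so the plain adam_update path runs
optimizer.zero_grad()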
3 changes: 2 additions & 1 deletion deepspeed/runtime/config.py
@@ -17,7 +17,8 @@
 TENSOR_CORE_ALIGN_SIZE = 8
 ADAM_OPTIMIZER = 'adam'
 LAMB_OPTIMIZER = 'lamb'
-DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER]
+DEEPSPEED_ADAM = 'deepspeed_adam'
+DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_ADAM]
 
 
 def get_amp_enabled(param_dict):
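Since DEEPSPEED_ADAM is now part of DEEPSPEED_OPTIMIZERS, the new optimizer should be selectable by name from the user config. The dict below is a sketch of such a config: the "optimizer"/"type"/"params" layout and the zero_optimization cpu_offload flag follow the usual DeepSpeed config schema, and the exact params accepted for deepspeed_adam are not shown in this diff.

# Hypothetical config selecting the new optimizer by name.
ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 2, "cpu_offload": True},
    "optimizer": {
        "type": "deepspeed_adam",
        "params": {"lr": 1e-3, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 0}
    }
}

The dict is then handed to deepspeed.initialize, either directly where a config dict is accepted or written to a JSON file referenced by --deepspeed_config.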
17 changes: 5 additions & 12 deletions deepspeed/runtime/engine.py
@@ -16,12 +16,11 @@
 from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
 from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
 from deepspeed.runtime.zero.utils import is_zero_supported_optimizer
-from deepspeed.ops.adam import DeepSpeedCPUAdam
 from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
 from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.runtime.config import DeepSpeedConfig, \
-    ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
+    ADAM_OPTIMIZER, DEEPSPEED_ADAM, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
 from deepspeed.runtime.dataloader import DeepSpeedDataLoader
 from deepspeed.runtime.constants import \
     ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
@@ -296,9 +295,6 @@ def zero_overlap_comm(self):
     def zero_cpu_offload(self):
         return self._config.zero_config.cpu_offload
 
-    def deepspeed_adam(self):
-        return self._config.zero_config.deepspeed_adam
-
     def zero_optimization_stage(self):
         return self._config.zero_optimization_stage
 
@@ -539,16 +535,13 @@ def _configure_basic_optimizer(self, model_parameters):
             )
         if self.optimizer_name() == ADAM_OPTIMIZER:
             if self.zero_cpu_offload():
-                if False: #self.deepspeed_adam():
-                    optimizer = DeepSpeedCPUAdam(model_parameters,
-                                                 **optimizer_parameters)
-                else:
-                    optimizer = torch.optim.Adam(model_parameters,
-                                                 **optimizer_parameters)
-
+                optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
             else:
                 from apex.optimizers.fused_adam import FusedAdam
                 optimizer = FusedAdam(model_parameters, **optimizer_parameters)
+        elif self.optimizer_name() == DEEPSPEED_ADAM:
+            from deepspeed.ops.adam import DeepSpeedCPUAdam
+            optimizer = DeepSpeedCPUAdam(model_parameters, **optimizer_parameters)
         elif self.optimizer_name() == LAMB_OPTIMIZER:
             optimizer = FusedLamb(model_parameters, **optimizer_parameters)
         else:
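To summarize the behavioral change in _configure_basic_optimizer: selection now keys off the optimizer name alone, and both the removed deepspeed_adam() accessor and the dead if False: branch are gone. The sketch below restates that routing for clarity; it is an illustration, not engine code, and assumes optimizer_name() returns the lower-cased "type" string from the config.

# Illustration only: optimizer routing after this change.
def route_basic_optimizer(optimizer_name: str, zero_cpu_offload: bool) -> str:
    if optimizer_name == 'adam':
        # 'adam' with ZeRO CPU offload now always uses torch.optim.Adam;
        # without offload it still uses apex FusedAdam.
        return 'torch.optim.Adam' if zero_cpu_offload else 'apex.optimizers.FusedAdam'
    if optimizer_name == 'deepspeed_adam':
        # New: the C++ CPU Adam kernel is an explicit opt-in via the optimizer name.
        return 'deepspeed.ops.adam.DeepSpeedCPUAdam'
    if optimizer_name == 'lamb':
        return 'FusedLamb'
    raise ValueError(f'Unsupported optimizer name: {optimizer_name}')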
6 changes: 4 additions & 2 deletions deepspeed/runtime/zero/stage2.py
@@ -1408,15 +1408,17 @@ def step(self, closure=None):
         self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
         #torch.set_num_threads(12)
         timers('optimizer_step').start()
-        self.optimizer.step() #fp16_param_groups=self.parallel_partitioned_fp16_groups)
+
+        #self.optimizer.step(fp16_param_groups=self.parallel_partitioned_fp16_groups)
+        self.optimizer.step()
         #get rid of the fp32 gradients. Not needed anymore
         if not self.cpu_offload:
             for group in self.single_partition_of_fp32_groups:
                 group.grad = None
 
         for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
             fp16_partitions[partition_id].data.copy_(fp32_partition.data)
 
         timers('optimizer_step').stop()
 
         if self.cpu_offload: