2 changes: 1 addition & 1 deletion DeepSpeedExamples
10 changes: 5 additions & 5 deletions deepspeed/pt/deepspeed_checkpointing.py
@@ -602,11 +602,11 @@ def reset():
     size_offsets = []


-def _configure_using_config_file(deepspeed_config):
+def _configure_using_config_file(deepspeed_config, mpu=None):
     global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
         PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME

-    config = DeepSpeedConfig(deepspeed_config).activation_checkpointing_config
+    config = DeepSpeedConfig(deepspeed_config, mpu=mpu).activation_checkpointing_config
     logger.info(config.repr())
     PARTITION_ACTIVATIONS = config.partition_activations
     CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization
@@ -684,12 +684,12 @@ def configure(

     _configure_defaults()

-    if deepspeed_config is not None:
-        _configure_using_config_file(deepspeed_config)
-
     if mpu_ is not None:
         mpu = mpu_

+    if deepspeed_config is not None:
+        _configure_using_config_file(deepspeed_config, mpu=mpu)
+
     if partition_activations is not None:
         PARTITION_ACTIVATIONS = partition_activations

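Both checkpointing hunks thread the model-parallel unit (`mpu`) into `DeepSpeedConfig`, and `configure()` now captures `mpu` before the config file is parsed. The presumed reason is that once model parallelism is in play, batch-size fields should be resolved against the data-parallel world size rather than the global world size. A minimal sketch of that distinction, with an illustrative helper name (not part of the PR):

import torch.distributed as dist

def effective_dp_world_size(mpu=None):
    # Sketch only: without an mpu every rank is data-parallel; with one,
    # model-parallel replicas of the same weights must not be double-counted.
    if mpu is None:
        return dist.get_world_size()
    return mpu.get_data_parallel_world_size()
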
4 changes: 3 additions & 1 deletion deepspeed/pt/deepspeed_config.py
@@ -457,10 +457,12 @@ def _do_error_check(self):
         if self.zero_enabled:
             assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
             assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
+            if self.zero_config.cpu_offload is True:
+                assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu-offload supported ZeRO stage is {}".format(ZERO_OPTIMIZATION_GRADIENTS)

         assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)

-        assert self.gradient_accumulation_steps, 'DeepSpeedConfig: {} is not defined'.format(
+        assert self.gradient_accumulation_steps, "DeepSpeedConfig: {} is not defined".format(
             GRADIENT_ACCUMULATION_STEPS)

     def _do_warning_check(self):
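The new check means `cpu_offload` is only accepted when ZeRO stage 2 (gradient partitioning) and fp16 are both enabled. A minimal config sketch that passes `_do_error_check` under that reading; the batch-size values are placeholders, not recommendations:

import json

ds_config = {
    "train_micro_batch_size_per_gpu": 4,       # placeholder value
    "gradient_accumulation_steps": 1,          # placeholder value
    "fp16": {"enabled": True},                 # ZeRO requires fp16
    "zero_optimization": {
        "stage": 2,                            # ZERO_OPTIMIZATION_GRADIENTS
        "cpu_offload": True                    # flag introduced in this PR
    }
}

# Written out as the JSON file handed to DeepSpeed as deepspeed_config.
with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)
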
24 changes: 17 additions & 7 deletions deepspeed/pt/deepspeed_light.py
@@ -107,7 +107,6 @@ def __init__(self,
                  collate_fn=None,
                  config_params=None):
         super(DeepSpeedLight, self).__init__()
-
         self.client_optimizer = optimizer
         self.client_model_parameters = model_parameters
         self.client_lr_scheduler = lr_scheduler
@@ -293,6 +292,9 @@ def zero_reduce_scatter(self):
     def zero_overlap_comm(self):
         return self._config.zero_config.overlap_comm

+    def zero_cpu_offload(self):
+        return self._config.zero_config.cpu_offload
+
     def zero_optimization_stage(self):
         return self._config.zero_optimization_stage
@@ -492,6 +494,7 @@ def _configure_distributed_model(self, model):

     # Configure optimizer
     def _configure_optimizer(self, client_optimizer, model_parameters):
+
         if client_optimizer is not None:
             basic_optimizer = client_optimizer
             logger.info('Using client Optimizer as basic optimizer')
@@ -505,13 +508,14 @@ def _configure_optimizer(self, client_optimizer, model_parameters):

         if self.zero_optimization():
             assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
-            if self.optimizer_name() != ADAM_OPTIMIZER:
+            if self.optimizer_name() not in [ADAM_OPTIMIZER]:
                 assert self.zero_allow_untested_optimizer(), \
                     'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'

                 logger.warning(
                     "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                 )
+
             self.optimizer = self._configure_zero_optimizer(basic_optimizer)
         elif self.amp_enabled():
             assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode"
@@ -523,8 +527,8 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
             self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
         else:
             self.optimizer = basic_optimizer

-        # logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))
-        logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer))
+        logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))
+
     def _configure_basic_optimizer(self, model_parameters):
         optimizer_parameters = self.optimizer_params()
@@ -533,8 +537,11 @@ def _configure_basic_optimizer(self, model_parameters):
                 "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
             )
         if self.optimizer_name() == ADAM_OPTIMIZER:
-            from apex.optimizers.fused_adam import FusedAdam
-            optimizer = FusedAdam(model_parameters, **optimizer_parameters)
+            if self.zero_cpu_offload():
+                optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
+            else:
+                from apex.optimizers.fused_adam import FusedAdam
+                optimizer = FusedAdam(model_parameters, **optimizer_parameters)
         elif self.optimizer_name() == LAMB_OPTIMIZER:
             optimizer = FusedLamb(model_parameters, **optimizer_parameters)
         else:
@@ -613,6 +620,7 @@ def _configure_zero_optimizer(self, optimizer):
                 dp_process_group=self.data_parallel_group,
                 reduce_scatter=self.zero_reduce_scatter(),
                 overlap_comm=self.zero_overlap_comm(),
+                cpu_offload=self.zero_cpu_offload(),
                 mpu=self.mpu,
                 postscale_gradients=self.postscale_gradients(),
                 gradient_predivide_factor=self.gradient_predivide_factor())
@@ -843,7 +851,6 @@ def step(self):
                 master_params = amp.master_params(self.optimizer)
                 torch.nn.utils.clip_grad_norm_(parameters=master_params,
                                                max_norm=self.gradient_clipping())
-
             self.optimizer.step()

             #zero grad in basic optimizer could be unreliable and may not exhibit
@@ -946,6 +953,9 @@ def _get_optimizer_param(self, param_name):
     def get_lr(self):
         return self._get_optimizer_param('lr')

+    def get_type(self):
+        return self._get_optimizer_param('type')
+
     def get_mom(self):
         return self._get_optimizer_param('betas')
22 changes: 15 additions & 7 deletions deepspeed/pt/deepspeed_zero_config.py
@@ -24,6 +24,7 @@
     "overlap_comm": [true|false],
     "reduce_bucket_size": 500000000
     "load_from_fp32_weights": [true|false]
+    "cpu_offload": [true|false]
     }
   }
 '''
@@ -63,21 +64,22 @@
 ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS = 'load_from_fp32_weights'
 ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True

+ZERO_OPTIMIZATION_CPU_OFFLOAD = 'cpu_offload'
+ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT = False
+
 ZERO_OPTIMIZATION_DEFAULT = {
-    ZERO_OPTIMIZATION_STAGE:
-    ZERO_OPTIMIZATION_STAGE_DEFAULT,
+    ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
     ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
     ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
-    ZERO_OPTIMIZATION_REDUCE_SCATTER:
-    ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
-    ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE:
-    ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
+    ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
+    ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
     ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
     ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
     ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
     ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
     ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
-    ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT
+    ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT,
+    ZERO_OPTIMIZATION_CPU_OFFLOAD: ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT
 }


@@ -93,6 +95,7 @@ def __init__(self, param_dict):
         self.allgather_bucket_size = None
         self.overlap_comm = None
         self.load_from_fp32_weights = None
+        self.cpu_offload = None

         if ZERO_OPTIMIZATION in param_dict.keys():
             zero_config_dict = param_dict[ZERO_OPTIMIZATION]
@@ -157,7 +160,12 @@ def _initialize(self, zero_config_dict):
             zero_config_dict,
             ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
             ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)
+
         self.load_from_fp32_weights = get_scalar_param(
             zero_config_dict,
             ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS,
             ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT)
+
+        self.cpu_offload = get_scalar_param(zero_config_dict,
+                                            ZERO_OPTIMIZATION_CPU_OFFLOAD,
+                                            ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT)
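For completeness, the new key resolves like the existing scalar options, assuming `get_scalar_param(d, key, default)` returns `d[key]` when present and the default otherwise (a stand-in is defined below so the snippet is self-contained):

ZERO_OPTIMIZATION_CPU_OFFLOAD = 'cpu_offload'
ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT = False

def get_scalar_param(param_dict, key, default):
    # Stand-in for DeepSpeed's helper; assumed dict-lookup-with-default behavior.
    return param_dict.get(key, default)

print(get_scalar_param({"stage": 2, "cpu_offload": True},
                       ZERO_OPTIMIZATION_CPU_OFFLOAD,
                       ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT))   # True
print(get_scalar_param({"stage": 2},
                       ZERO_OPTIMIZATION_CPU_OFFLOAD,
                       ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT))   # False (default)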