Merged
Commits
28 commits
5f63fc8  cpu-offload (jren73, Aug 4, 2020)
e01238b  update (jren73, Aug 4, 2020)
73b956b  updte (jren73, Aug 4, 2020)
98deb70  deleted: deepspeed/pt/deepspeed_zero_optimizer_cpuoffload.py (jren73, Aug 6, 2020)
e3b2a42  modified: deepspeed/pt/deepspeed_zero_optimizer.py (jren73, Aug 7, 2020)
f832a2e  update (jren73, Aug 10, 2020)
004884b  modified: deepspeed/pt/deepspeed_cpu_adam.py (jren73, Aug 10, 2020)
0effd77  deleted: install_output.txt (jren73, Aug 10, 2020)
af3b834  modified: deepspeed/pt/fp16_unfused_optimizer.py (jren73, Aug 10, 2020)
e2d936d  Merge pull request #2 from jren73/ZeRO-2-cpu_offload (jren73, Aug 10, 2020)
ef5c785  modified: deepspeed/pt/deepspeed_cpu_adam.py (jren73, Aug 11, 2020)
7f0a856  Merge pull request #3 from jren73/ZeRO-2-cpu_offload (jren73, Aug 11, 2020)
6e45e8b  modified: deepspeed/pt/deepspeed_zero_optimizer.py (jren73, Aug 11, 2020)
e930604  Merge pull request #4 from jren73/ZeRO-2-cpu_offload (jren73, Aug 11, 2020)
d2cc800  Merge branch 'master' into master (jeffra, Aug 11, 2020)
f8812b9  modified: deepspeed/pt/deepspeed_cpu_adam.py (jren73, Aug 11, 2020)
d1a435c  Merge pull request #5 from jren73/ZeRO-2-cpu_offload (jren73, Aug 11, 2020)
6415738  deleted: deepspeed_cpu_adam.py (jren73, Aug 17, 2020)
7eb6041  Merge pull request #6 from jren73/ZeRO-2-cpu_offload (jren73, Aug 17, 2020)
fbd79c6  modified: deepspeed/pt/deepspeed_light.py (jren73, Aug 17, 2020)
f1a180f  Merge pull request #7 from jren73/ZeRO-2-cpu_offload (jren73, Aug 18, 2020)
5181c60  modified: deepspeed/pt/deepspeed_light.py (jren73, Aug 18, 2020)
a2b7433  Merge pull request #8 from jren73/ZeRO-2-cpu_offload (jren73, Aug 18, 2020)
41f18d1  modified: deepspeed/pt/deepspeed_config.py (jren73, Aug 24, 2020)
3835b22  Merge pull request #9 from jren73/ZeRO-2-cpu_offload (jren73, Aug 24, 2020)
ffad985  modified: deepspeed/pt/deepspeed_checkpointing.py (jren73, Aug 28, 2020)
1e46c91  Merge pull request #12 from jren73/ZeRO-2-cpu_offload (jren73, Aug 28, 2020)
646f709  update DSE to ZeRO-Offload commit (jren73, Aug 28, 2020)
2 changes: 1 addition & 1 deletion DeepSpeedExamples
10 changes: 5 additions & 5 deletions deepspeed/pt/deepspeed_checkpointing.py
@@ -602,11 +602,11 @@ def reset():
size_offsets = []


def _configure_using_config_file(deepspeed_config):
def _configure_using_config_file(deepspeed_config, mpu=None):
global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME

config = DeepSpeedConfig(deepspeed_config).activation_checkpointing_config
config = DeepSpeedConfig(deepspeed_config, mpu=mpu).activation_checkpointing_config
logger.info(config.repr())
PARTITION_ACTIVATIONS = config.partition_activations
CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization
@@ -684,12 +684,12 @@ def configure(

_configure_defaults()

if deepspeed_config is not None:
_configure_using_config_file(deepspeed_config)

if mpu_ is not None:
mpu = mpu_

if deepspeed_config is not None:
_configure_using_config_file(deepspeed_config, mpu=mpu)

if partition_activations is not None:
PARTITION_ACTIVATIONS = partition_activations

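
The net effect of the deepspeed_checkpointing.py change is that the model-parallel unit now reaches DeepSpeedConfig when activation checkpointing is configured from a JSON file. A minimal usage sketch, assuming a Megatron-style mpu object and a config path of your own (both placeholders, not part of this PR):

# Sketch only: `mpu` and 'ds_config.json' are placeholders.
from deepspeed.pt.deepspeed_checkpointing import configure

mpu = None  # e.g. a Megatron-style model-parallel unit; None is also accepted
configure(mpu_=mpu, deepspeed_config='ds_config.json')
# configure() now forwards mpu into _configure_using_config_file(), which builds
# DeepSpeedConfig(deepspeed_config, mpu=mpu) before reading the
# activation_checkpointing_config section.
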
11 changes: 9 additions & 2 deletions deepspeed/pt/deepspeed_config.py
@@ -16,7 +16,8 @@
TENSOR_CORE_ALIGN_SIZE = 8
ADAM_OPTIMIZER = 'adam'
LAMB_OPTIMIZER = 'lamb'
DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER]
TORCH_ADAM_OPTIMIZER = 'torch_adam'
DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER, TORCH_ADAM_OPTIMIZER]


def get_amp_enabled(param_dict):
@@ -457,12 +458,18 @@ def _do_error_check(self):
if self.zero_enabled:
assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
if self.zero_config.cpu_offload is True:
assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu-offload supported ZeRO stage is {}".format(ZERO_OPTIMIZATION_GRADIENTS)

assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)

assert self.gradient_accumulation_steps, 'DeepSpeedConfig: {} is not defined'.format(
assert self.gradient_accumulation_steps, "DeepSpeedConfig: {} is not defined".format(
GRADIENT_ACCUMULATION_STEPS)

if self.optimizer_name == TORCH_ADAM_OPTIMIZER:
assert self.zero_enabled, "ZeRO is not enabled with using TORCH_ADAM_OPTIMIZER"
assert self.zero_config.cpu_offload, " cpu_offload is not enabled with using TORCH_ADAM_OPTIMIZER"

def _do_warning_check(self):
fp16_enabled = self.fp16_enabled or self.zero_enabled

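
For reference, here is a configuration sketch that satisfies the new checks in _do_error_check: torch_adam requires ZeRO with cpu_offload, and cpu_offload requires the gradient-partitioning stage (stage 2). The batch-size and learning-rate values below are illustrative, not from the PR.

# Hedged example config for the new validation rules; values are arbitrary.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "fp16": {"enabled": True},        # ZeRO is only supported with fp16 enabled
    "optimizer": {
        "type": "torch_adam",         # TORCH_ADAM_OPTIMIZER added in this PR
        "params": {"lr": 1e-4}
    },
    "zero_optimization": {
        "stage": 2,                   # cpu_offload is only supported at this stage
        "cpu_offload": True
    }
}
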
28 changes: 22 additions & 6 deletions deepspeed/pt/deepspeed_light.py
@@ -22,7 +22,7 @@
from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.pt.deepspeed_fused_lamb import FusedLamb
from deepspeed.pt.deepspeed_config import DeepSpeedConfig, \
ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
ADAM_OPTIMIZER, LAMB_OPTIMIZER, TORCH_ADAM_OPTIMIZER, DEEPSPEED_OPTIMIZERS

from deepspeed.pt.deepspeed_dataloader import DeepSpeedDataLoader
from deepspeed.pt.deepspeed_constants import \
@@ -107,7 +107,6 @@ def __init__(self,
collate_fn=None,
config_params=None):
super(DeepSpeedLight, self).__init__()

self.client_optimizer = optimizer
self.client_model_parameters = model_parameters
self.client_lr_scheduler = lr_scheduler
@@ -293,6 +292,9 @@ def zero_reduce_scatter(self):
def zero_overlap_comm(self):
return self._config.zero_config.overlap_comm

def zero_cpu_offload(self):
return self._config.zero_config.cpu_offload

def zero_optimization_stage(self):
return self._config.zero_optimization_stage

@@ -492,6 +494,7 @@ def _configure_distributed_model(self, model):

# Configure optimizer
def _configure_optimizer(self, client_optimizer, model_parameters):

if client_optimizer is not None:
basic_optimizer = client_optimizer
logger.info('Using client Optimizer as basic optimizer')
@@ -505,13 +508,21 @@ def _configure_optimizer(self, client_optimizer, model_parameters):

if self.zero_optimization():
assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
if self.optimizer_name() != ADAM_OPTIMIZER:
if self.optimizer_name() not in [ADAM_OPTIMIZER, TORCH_ADAM_OPTIMIZER]:
assert self.zero_allow_untested_optimizer(), \
'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'

logger.warning(
"**** You are using ZeRO with an untested optimizer, proceed with caution *****"
)
if self.zero_cpu_offload():
if self.optimizer_name() != TORCH_ADAM_OPTIMIZER:
assert self.zero_allow_untested_optimizer(), \
'You are using ZeRO-Offload with an untested Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'

logger.warning(
"**** You are using ZeRO-Offload with an untested optimizer, proceed with caution *****"
)
self.optimizer = self._configure_zero_optimizer(basic_optimizer)
elif self.amp_enabled():
assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode"
@@ -523,8 +534,8 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
else:
self.optimizer = basic_optimizer

# logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))
logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer))
logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))

def _configure_basic_optimizer(self, model_parameters):
optimizer_parameters = self.optimizer_params()
@@ -537,6 +548,8 @@ def _configure_basic_optimizer(self, model_parameters):
optimizer = FusedAdam(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == LAMB_OPTIMIZER:
optimizer = FusedLamb(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == TORCH_ADAM_OPTIMIZER:
optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
else:
torch_optimizer = getattr(torch.optim, self.optimizer_name())
optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
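
In isolation, the new branch in _configure_basic_optimizer maps "torch_adam" to the stock PyTorch optimizer, presumably because torch.optim.Adam can step CPU-resident FP32 parameters while the fused CUDA optimizers cannot. A standalone sketch of that mapping (the surrounding engine plumbing is omitted; model_parameters and optimizer_parameters normally come from the config's "optimizer.params" section):

import torch

def build_torch_adam(model_parameters, optimizer_parameters):
    # "torch_adam" selects the unfused, CPU-capable torch.optim.Adam.
    return torch.optim.Adam(model_parameters, **optimizer_parameters)

# Example:
model = torch.nn.Linear(8, 8)
opt = build_torch_adam(model.parameters(), {"lr": 1e-4})
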
@@ -613,6 +626,7 @@ def _configure_zero_optimizer(self, optimizer):
dp_process_group=self.data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
cpu_offload=self.zero_cpu_offload(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor())
@@ -843,7 +857,6 @@ def step(self):
master_params = amp.master_params(self.optimizer)
torch.nn.utils.clip_grad_norm_(parameters=master_params,
max_norm=self.gradient_clipping())

self.optimizer.step()

#zero grad in basic optimizer could be unreliable and may not exhibit
@@ -946,6 +959,9 @@ def _get_optimizer_param(self, param_name):
def get_lr(self):
return self._get_optimizer_param('lr')

def get_type(self):
return self._get_optimizer_param('type')

def get_mom(self):
return self._get_optimizer_param('betas')

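
Read together, the deepspeed_light.py changes gate which optimizers ZeRO and ZeRO-Offload will accept without the zero_allow_untested_optimizer escape hatch. A condensed sketch of the two guards above (a paraphrased helper, not the engine's literal code):

# Condensed sketch of the guards in _configure_optimizer; names and messages paraphrased.
def check_zero_optimizer(name, zero_enabled, cpu_offload, allow_untested):
    if not zero_enabled:
        return
    if name not in ('adam', 'torch_adam'):
        assert allow_untested, \
            'untested ZeRO optimizer; set "zero_allow_untested_optimizer": true'
    if cpu_offload and name != 'torch_adam':
        assert allow_untested, \
            'untested ZeRO-Offload optimizer; set "zero_allow_untested_optimizer": true'

check_zero_optimizer('torch_adam', zero_enabled=True, cpu_offload=True,
                     allow_untested=False)  # passes without warnings
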
22 changes: 15 additions & 7 deletions deepspeed/pt/deepspeed_zero_config.py
@@ -24,6 +24,7 @@
"overlap_comm": [true|false],
"reduce_bucket_size": 500000000
"load_from_fp32_weights": [true|false]
"cpu_offload": [true|false]
}
}
'''
@@ -63,21 +64,22 @@
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS = 'load_from_fp32_weights'
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True

ZERO_OPTIMIZATION_CPU_OFFLOAD = 'cpu_offload'
ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT = False

ZERO_OPTIMIZATION_DEFAULT = {
ZERO_OPTIMIZATION_STAGE:
ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER:
ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE:
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT,
ZERO_OPTIMIZATION_CPU_OFFLOAD: ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT
}


@@ -93,6 +95,7 @@ def __init__(self, param_dict):
self.allgather_bucket_size = None
self.overlap_comm = None
self.load_from_fp32_weights = None
self.cpu_offload = None

if ZERO_OPTIMIZATION in param_dict.keys():
zero_config_dict = param_dict[ZERO_OPTIMIZATION]
@@ -157,7 +160,12 @@ def _initialize(self, zero_config_dict):
zero_config_dict,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)

self.load_from_fp32_weights = get_scalar_param(
zero_config_dict,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS,
ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT)

self.cpu_offload = get_scalar_param(zero_config_dict,
ZERO_OPTIMIZATION_CPU_OFFLOAD,
ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT)
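
The parsing side is symmetric with the other ZeRO options: cpu_offload is read with get_scalar_param and defaults to False. A sketch of that behaviour, assuming get_scalar_param is a plain dict lookup with a fallback (its implementation is not shown in this diff):

# Assumption: get_scalar_param(d, key, default) behaves like d.get(key, default).
def get_scalar_param(param_dict, param_name, param_default_value):
    return param_dict.get(param_name, param_default_value)

ZERO_OPTIMIZATION_CPU_OFFLOAD = 'cpu_offload'
ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT = False

zero_config_dict = {"stage": 2, "cpu_offload": True}
cpu_offload = get_scalar_param(zero_config_dict,
                               ZERO_OPTIMIZATION_CPU_OFFLOAD,
                               ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT)
print(cpu_offload)  # True; False whenever the key is omitted
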