From 1ebcd6c50af7bc5c41507c60961f9031c729ec8c Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Wed, 2 Sep 2020 22:05:04 -0700 Subject: [PATCH 1/4] Update test_sparse_attention.py --- tests/unit/test_sparse_attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/test_sparse_attention.py b/tests/unit/test_sparse_attention.py index 3f2078946297..f113e3b72ecd 100644 --- a/tests/unit/test_sparse_attention.py +++ b/tests/unit/test_sparse_attention.py @@ -232,6 +232,7 @@ def init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, dense_x=True, layo def _skip_on_cuda_compatability(): + pytest.skip("Skip these tests for now until we get our docker image fixed.") if torch.cuda.get_device_capability()[0] != 7: pytest.skip("needs compute capability 7; v100") cuda_major = int(torch.version.cuda.split('.')[0]) * 10 From 2228febda4505c8f7b8e5f7fe92066db47a75d5b Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Thu, 3 Sep 2020 19:34:28 +0000 Subject: [PATCH 2/4] jren changes --- deepspeed/runtime/engine.py | 85 +++++++++++++++++--------------- deepspeed/runtime/zero/stage2.py | 77 +++++++++-------------------- 2 files changed, 70 insertions(+), 92 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 96c591bd3c2a..fa8ee6cb36e7 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2,36 +2,36 @@ Copyright 2019 The Microsoft DeepSpeed Team ''' -import os import torch +import os import warnings import torch.distributed as dist - -from apex import amp from torch.nn.modules import Module from torch.distributed.distributed_c10d import _get_global_rank +from apex import amp + from tensorboardX import SummaryWriter -from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer -from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1 -from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing -from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer -from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer -from deepspeed.runtime.config import DeepSpeedConfig, \ - ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS -from deepspeed.runtime.dataloader import DeepSpeedDataLoader -from deepspeed.runtime.constants import \ +from deepspeed.pt.deepspeed_timer import ThroughputTimer, SynchronizedWallClockTimer +from deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer +from deepspeed.pt.zero_optimizer_stage1 import FP16_DeepSpeedZeroOptimizer_Stage1 +from deepspeed.pt.log_utils import logger +import deepspeed.pt.deepspeed_checkpointing as deepspeed_activation_checkpointing + +from deepspeed.pt.fp16_optimizer import FP16_Optimizer +from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer +from deepspeed.pt.deepspeed_fused_lamb import FusedLamb +from deepspeed.pt.deepspeed_config import DeepSpeedConfig, \ + ADAM_OPTIMIZER, LAMB_OPTIMIZER, TORCH_ADAM_OPTIMIZER, DEEPSPEED_OPTIMIZERS + +from deepspeed.pt.deepspeed_dataloader import DeepSpeedDataLoader +from deepspeed.pt.deepspeed_constants import \ ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \ - TORCH_DISTRIBUTED_DEFAULT_PORT -from deepspeed.runtime.zero.constants import \ + TORCH_DISTRIBUTED_DEFAULT_PORT, \ ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS -from deepspeed.runtime.csr_tensor import CSRTensor -import deepspeed.runtime.lr_schedules as lr_schedules -from deepspeed.ops.lamb import FusedLamb - -from deepspeed.utils import logger -from 
deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer +import deepspeed.pt.deepspeed_lr_schedules as lr_schedules +from deepspeed.pt.deepspeed_csr_tensor import CSRTensor MEMORY_OPT_ALLREDUCE_SIZE = 500000000 SUMMARY_WRITER_DIR_NAME = "JobId" @@ -92,7 +92,7 @@ def print_configuration(args, name): logger.info(' {} {} {}'.format(arg, dots, getattr(args, arg))) -class DeepSpeedEngine(Module): +class DeepSpeedLight(Module): r"""DeepSpeed engine for training. """ def __init__(self, @@ -106,7 +106,7 @@ def __init__(self, dist_init_required=None, collate_fn=None, config_params=None): - super(DeepSpeedEngine, self).__init__() + super(DeepSpeedLight, self).__init__() self.client_optimizer = optimizer self.client_model_parameters = model_parameters self.client_lr_scheduler = lr_scheduler @@ -313,6 +313,9 @@ def zero_contiguous_gradients(self): def zero_load_from_fp32_weights(self): return self._config.zero_config.load_from_fp32_weights + def allgather_size(self): + return self._config.allgather_size + def fp16_enabled(self): return self._config.fp16_enabled @@ -505,14 +508,21 @@ def _configure_optimizer(self, client_optimizer, model_parameters): if self.zero_optimization(): assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2" - if self.optimizer_name() not in [ADAM_OPTIMIZER]: + if self.optimizer_name() not in [ADAM_OPTIMIZER, TORCH_ADAM_OPTIMIZER]: assert self.zero_allow_untested_optimizer(), \ 'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.' logger.warning( "**** You are using ZeRO with an untested optimizer, proceed with caution *****" ) + if self.zero_cpu_offload(): + if self.optimizer_name() != TORCH_ADAM_OPTIMIZER: + assert self.zero_allow_untested_optimizer(), \ + 'You are using ZeRO-Offload with an untested Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.' 
+ logger.warning( + "**** You are using ZeRO-Offload with an untested optimizer, proceed with caution *****" + ) self.optimizer = self._configure_zero_optimizer(basic_optimizer) elif self.amp_enabled(): assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode" @@ -534,13 +544,12 @@ def _configure_basic_optimizer(self, model_parameters): "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details" ) if self.optimizer_name() == ADAM_OPTIMIZER: - if self.zero_cpu_offload(): - optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters) - else: - from apex.optimizers.fused_adam import FusedAdam - optimizer = FusedAdam(model_parameters, **optimizer_parameters) + from apex.optimizers.fused_adam import FusedAdam + optimizer = FusedAdam(model_parameters, **optimizer_parameters) elif self.optimizer_name() == LAMB_OPTIMIZER: optimizer = FusedLamb(model_parameters, **optimizer_parameters) + elif self.optimizer_name() == TORCH_ADAM_OPTIMIZER: + optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters) else: torch_optimizer = getattr(torch.optim, self.optimizer_name()) optimizer = torch_optimizer(model_parameters, **optimizer_parameters) @@ -602,6 +611,8 @@ def _configure_zero_optimizer(self, optimizer): dp_process_group=self.data_parallel_group, mpu=self.mpu) elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS: + assert self.gradient_accumulation_steps( + ) == 1, "ZeRO stage 2 does not support gradient accumulation, if you need gradient accumulation please use stage 1" optimizer = FP16_DeepSpeedZeroOptimizer( optimizer, timers=self.timers, @@ -724,19 +735,15 @@ def forward(self, *inputs, **kwargs): return loss def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): - - #Zero stage 2 communicates during non gradient accumulation boundaries as well - if self.zero_optimization_partition_gradients(): - self.optimizer.overlapping_partition_gradients_reduce_epilogue() - - #Communicate only at gradient accumulation boundaries - elif self.is_gradient_accumulation_boundary(): + if self.is_gradient_accumulation_boundary(): if self.zero_optimization_stage() == ZERO_OPTIMIZATION_OPTIMIZER_STATES: assert self.zero_reduce_scatter() self.optimizer.reduce_scatter_gradients( postscale_gradients=self.postscale_gradients(), gradient_predivide_factor=self.gradient_predivide_factor(), gradient_average=self.gradient_average) + elif self.zero_optimization_partition_gradients(): + self.optimizer.overlapping_partition_gradients_reduce_epilogue() else: self.buffered_allreduce_fallback(elements_per_buffer=bucket_size) @@ -1022,10 +1029,10 @@ def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000) # rank is reducing the same size. In some cases it may make # sense in the future to support the ability to average not # w.r.t. world size but with a different value. 
- param.grad = torch.zeros(param.size(), - dtype=param.dtype, - device=param.device) - grads.append(param.grad.data) + grads.append( + torch.zeros(param.size(), + dtype=param.dtype, + device=param.device)) else: grad_data = param.grad.data if self.sparse_gradients_enabled( diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py index edfd15be3221..600394cc7802 100755 --- a/deepspeed/runtime/zero/stage2.py +++ b/deepspeed/runtime/zero/stage2.py @@ -11,15 +11,15 @@ from torch.autograd import Variable import collections -from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler -from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter -from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS -from deepspeed.utils import logger - +from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler +from deepspeed.pt.deepspeed_utils import see_memory_usage, is_model_parallel_parameter +from deepspeed.pt.deepspeed_zero_config import ZERO_OPTIMIZATION_GRADIENTS #Toggle this to true to enable correctness test #with gradient partitioning and without pg_correctness_test = False +from deepspeed.pt.log_utils import logger + try: from apex_C import flatten from apex_C import unflatten @@ -128,8 +128,6 @@ def async_copy_to(obj, dev, main_stream=None): return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} elif isinstance(obj, collections.Sequence): return [async_copy_to(o, dev, main_stream) for o in obj] - else: - return obj class FP16_DeepSpeedZeroOptimizer(object): @@ -166,7 +164,6 @@ def __init__(self, if dist.get_rank() == 0: logger.info(f"Reduce bucket size {reduce_bucket_size}") logger.info(f"Allgather bucket size {allgather_bucket_size}") - logger.info(f"CPU Offload: {cpu_offload}") # The fused optimizer does all the work. We need this layer for two reason: # 1. maintain same user API from apex.fp16_utils # 2. keep common stuff here in case we need to add ne552w fused optimizer later @@ -505,34 +502,16 @@ def independent_gradient_partition_epilogue(self): if self.cpu_offload is False: for i, _ in enumerate(self.fp16_groups): - if not i in self.averaged_gradients or self.averaged_gradients[i] is None: - self.averaged_gradients[i] = self.get_flat_partition( - self.params_in_partition[i], - self.first_offset[i], - self.partition_size[i], - dtype=torch.half, - device=torch.cuda.current_device(), - return_tensor_list=True) - else: - #When gradient accumulation is greater that 1 - #This code path will be triggered and will add - #to the accumulated averaged gradients - avg_new = self.get_flat_partition(self.params_in_partition[i], - self.first_offset[i], - self.partition_size[i], - dtype=torch.half, - device=torch.cuda.current_device(), - return_tensor_list=True) - - for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i],avg_new): - accumulated_grad.add_(new_avg_grad) + self.averaged_gradients[i] = self.get_flat_partition( + self.params_in_partition[i], + self.first_offset[i], + self.partition_size[i], + dtype=torch.half, + device=torch.cuda.current_device(), + return_tensor_list=True) self._release_ipg_buffers() - # No need to keep the gradients anymore. 
- # All gradients required by the step - # are in self.averaged_gradients - self.zero_grad() see_memory_usage(f"End ipg_epilogue") # resets all partition to no reduced @@ -1203,9 +1182,6 @@ def step(self, closure=None): if self.overflow: see_memory_usage('After overflow before clearing gradients') self.zero_grad() - for key in self.averaged_gradients: - self.averaged_gradients[key] = None - see_memory_usage('After overflow after clearing gradients') logger.info( @@ -1490,11 +1466,12 @@ def _get_groups_without_padding(self, groups_with_padding): def _get_state_without_padding(self, state_with_padding, padding): lean_state = {} for key, value in state_with_padding.items(): - if torch.is_tensor(value): + #jie: torch.optim.Adam() has "step" has a key in state_dict + if key == "step": + lean_state[key] = value + else: lean_length = value.numel() - padding lean_state[key] = value[:lean_length] - else: - lean_state[key] = value return lean_state @@ -1540,6 +1517,7 @@ def state_dict(self): state_dict_tmp = async_copy_to(state_dict, 'cpu', torch.cuda.current_stream()) + state_dict = None state_dict = state_dict_tmp return state_dict @@ -1578,16 +1556,11 @@ def refresh_fp32_params(self): def _partition_base_optimizer_state(self, state_key, all_partition_states): partition_id = dist.get_rank(group=self.dp_process_group) alignment = dist.get_world_size(group=self.dp_process_group) - - if torch.is_tensor(all_partition_states[0]): - flat_merged_partitions = flatten_dense_tensors_aligned( - all_partition_states, - alignment) - dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions) - return dp_partitions[partition_id] - else: - # Assume non-tensor states are not partitioned and equal across ranks, so return first one - return all_partition_states[0] + flat_merged_partitions = flatten_dense_tensors_aligned( + all_partition_states, + alignment) + dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions) + return dp_partitions[partition_id] # Restore base optimizer state from checkpoint by # 1) Merging optimizer state from checkpoints of all partitions @@ -1612,10 +1585,8 @@ def _restore_base_optimizer_state(self, all_state_dict): for i, group in enumerate(self.optimizer.param_groups): p = group['params'][0] for key, saved in base_optimizer_group_states[i].items(): - if torch.is_tensor(self.optimizer.state[p][key]): - self.optimizer.state[p][key].data.copy_(saved.data) - else: - self.optimizer.state[p][key] = saved + current = self.optimizer.state[p][key] + current.data.copy_(saved.data) def load_state_dict(self, state_dict_list, From c8b1666d75c8e247853879d97ccdb35ecf69246a Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Thu, 3 Sep 2020 20:42:51 +0000 Subject: [PATCH 3/4] Merge with correctness/perf fixes --- deepspeed/runtime/engine.py | 87 +++--- deepspeed/runtime/zero/stage2.py | 464 ++++++++++++++++++++----------- 2 files changed, 338 insertions(+), 213 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index fa8ee6cb36e7..c8f836ead882 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2,36 +2,34 @@ Copyright 2019 The Microsoft DeepSpeed Team ''' -import torch import os +import torch import warnings import torch.distributed as dist +from apex import amp from torch.nn.modules import Module from torch.distributed.distributed_c10d import _get_global_rank -from apex import amp from tensorboardX import SummaryWriter -from deepspeed.pt.deepspeed_timer import ThroughputTimer, SynchronizedWallClockTimer -from 
deepspeed.pt.deepspeed_zero_optimizer import FP16_DeepSpeedZeroOptimizer -from deepspeed.pt.zero_optimizer_stage1 import FP16_DeepSpeedZeroOptimizer_Stage1 -from deepspeed.pt.log_utils import logger -import deepspeed.pt.deepspeed_checkpointing as deepspeed_activation_checkpointing - -from deepspeed.pt.fp16_optimizer import FP16_Optimizer -from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer -from deepspeed.pt.deepspeed_fused_lamb import FusedLamb -from deepspeed.pt.deepspeed_config import DeepSpeedConfig, \ - ADAM_OPTIMIZER, LAMB_OPTIMIZER, TORCH_ADAM_OPTIMIZER, DEEPSPEED_OPTIMIZERS - -from deepspeed.pt.deepspeed_dataloader import DeepSpeedDataLoader -from deepspeed.pt.deepspeed_constants import \ +from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer +from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1 +from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing +from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer +from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer +from deepspeed.runtime.config import DeepSpeedConfig, \ + ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS +from deepspeed.runtime.dataloader import DeepSpeedDataLoader +from deepspeed.runtime.constants import \ ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \ - TORCH_DISTRIBUTED_DEFAULT_PORT, \ + TORCH_DISTRIBUTED_DEFAULT_PORT +from deepspeed.runtime.zero.constants import \ ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS - -import deepspeed.pt.deepspeed_lr_schedules as lr_schedules -from deepspeed.pt.deepspeed_csr_tensor import CSRTensor +from deepspeed.runtime.csr_tensor import CSRTensor +import deepspeed.runtime.lr_schedules as lr_schedules +from deepspeed.ops.lamb import FusedLamb +from deepspeed.utils import logger +from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer MEMORY_OPT_ALLREDUCE_SIZE = 500000000 SUMMARY_WRITER_DIR_NAME = "JobId" @@ -92,7 +90,7 @@ def print_configuration(args, name): logger.info(' {} {} {}'.format(arg, dots, getattr(args, arg))) -class DeepSpeedLight(Module): +class DeepSpeedEngine(Module): r"""DeepSpeed engine for training. """ def __init__(self, @@ -106,7 +104,7 @@ def __init__(self, dist_init_required=None, collate_fn=None, config_params=None): - super(DeepSpeedLight, self).__init__() + super(DeepSpeedEngine, self).__init__() self.client_optimizer = optimizer self.client_model_parameters = model_parameters self.client_lr_scheduler = lr_scheduler @@ -313,8 +311,6 @@ def zero_contiguous_gradients(self): def zero_load_from_fp32_weights(self): return self._config.zero_config.load_from_fp32_weights - def allgather_size(self): - return self._config.allgather_size def fp16_enabled(self): return self._config.fp16_enabled @@ -508,21 +504,14 @@ def _configure_optimizer(self, client_optimizer, model_parameters): if self.zero_optimization(): assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2" - if self.optimizer_name() not in [ADAM_OPTIMIZER, TORCH_ADAM_OPTIMIZER]: + if self.optimizer_name() not in [ADAM_OPTIMIZER]: assert self.zero_allow_untested_optimizer(), \ 'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.' 
logger.warning( "**** You are using ZeRO with an untested optimizer, proceed with caution *****" ) - if self.zero_cpu_offload(): - if self.optimizer_name() != TORCH_ADAM_OPTIMIZER: - assert self.zero_allow_untested_optimizer(), \ - 'You are using ZeRO-Offload with an untested Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.' - logger.warning( - "**** You are using ZeRO-Offload with an untested optimizer, proceed with caution *****" - ) self.optimizer = self._configure_zero_optimizer(basic_optimizer) elif self.amp_enabled(): assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode" @@ -544,12 +533,13 @@ def _configure_basic_optimizer(self, model_parameters): "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details" ) if self.optimizer_name() == ADAM_OPTIMIZER: - from apex.optimizers.fused_adam import FusedAdam - optimizer = FusedAdam(model_parameters, **optimizer_parameters) + if self.zero_cpu_offload(): + optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters) + else: + from apex.optimizers.fused_adam import FusedAdam + optimizer = FusedAdam(model_parameters, **optimizer_parameters) elif self.optimizer_name() == LAMB_OPTIMIZER: optimizer = FusedLamb(model_parameters, **optimizer_parameters) - elif self.optimizer_name() == TORCH_ADAM_OPTIMIZER: - optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters) else: torch_optimizer = getattr(torch.optim, self.optimizer_name()) optimizer = torch_optimizer(model_parameters, **optimizer_parameters) @@ -611,8 +601,8 @@ def _configure_zero_optimizer(self, optimizer): dp_process_group=self.data_parallel_group, mpu=self.mpu) elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS: - assert self.gradient_accumulation_steps( - ) == 1, "ZeRO stage 2 does not support gradient accumulation, if you need gradient accumulation please use stage 1" + #assert self.gradient_accumulation_steps( + #) == 1, "ZeRO stage 2 does not support gradient accumulation, if you need gradient accumulation please use stage 1" optimizer = FP16_DeepSpeedZeroOptimizer( optimizer, timers=self.timers, @@ -629,7 +619,8 @@ def _configure_zero_optimizer(self, optimizer): cpu_offload=self.zero_cpu_offload(), mpu=self.mpu, postscale_gradients=self.postscale_gradients(), - gradient_predivide_factor=self.gradient_predivide_factor()) + gradient_predivide_factor=self.gradient_predivide_factor(), + gradient_accumulation_steps=self.gradient_accumulation_steps()) else: raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage)) @@ -735,15 +726,18 @@ def forward(self, *inputs, **kwargs): return loss def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): - if self.is_gradient_accumulation_boundary(): + #Zero stage 2 communicates during non gradient accumulation boundaries as well + if self.zero_optimization_partition_gradients(): + self.optimizer.overlapping_partition_gradients_reduce_epilogue() + + #Communicate only at gradient accumulation boundaries + elif self.is_gradient_accumulation_boundary(): if self.zero_optimization_stage() == ZERO_OPTIMIZATION_OPTIMIZER_STATES: assert self.zero_reduce_scatter() self.optimizer.reduce_scatter_gradients( postscale_gradients=self.postscale_gradients(), gradient_predivide_factor=self.gradient_predivide_factor(), gradient_average=self.gradient_average) - elif 
self.zero_optimization_partition_gradients(): - self.optimizer.overlapping_partition_gradients_reduce_epilogue() else: self.buffered_allreduce_fallback(elements_per_buffer=bucket_size) @@ -787,6 +781,7 @@ def backward(self, loss, allreduce_gradients=True): self.timers('backward_inner').start() if self.zero_optimization(): + self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary() self.optimizer.backward(loss) elif self.amp_enabled(): # AMP requires delaying unscale when inside gradient accumulation boundaries @@ -1029,10 +1024,10 @@ def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000) # rank is reducing the same size. In some cases it may make # sense in the future to support the ability to average not # w.r.t. world size but with a different value. - grads.append( - torch.zeros(param.size(), - dtype=param.dtype, - device=param.device)) + param.grad = torch.zeros(param.size(), + dtype=param.dtype, + device=param.device) + grads.append(param.grad.data) else: grad_data = param.grad.data if self.sparse_gradients_enabled( diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py index 600394cc7802..d4fe4b436cdf 100755 --- a/deepspeed/runtime/zero/stage2.py +++ b/deepspeed/runtime/zero/stage2.py @@ -9,16 +9,17 @@ import math from torch._six import inf from torch.autograd import Variable + import collections -from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler -from deepspeed.pt.deepspeed_utils import see_memory_usage, is_model_parallel_parameter -from deepspeed.pt.deepspeed_zero_config import ZERO_OPTIMIZATION_GRADIENTS +from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler +from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter +from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS +from deepspeed.utils import logger #Toggle this to true to enable correctness test #with gradient partitioning and without pg_correctness_test = False -from deepspeed.pt.log_utils import logger try: from apex_C import flatten @@ -98,38 +99,6 @@ def move_to_cpu(tensor_list): def print_rank_msg(msg): print(f"rank {dist.get_rank()} - {msg}") - -#jie:asyn move to target device -def async_migrate_to(obj, dev, main_stream=None): - if torch.is_tensor(obj): - obj = Variable(obj) - if isinstance(obj, Variable): - v = obj.cuda(dev, async=True) - if main_stream is not None: - v.data.record_stream(main_stream) - return v - elif isinstance(obj, collections.Mapping): - return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} - elif isinstance(obj, collections.Sequence): - return [async_copy_to(o, dev, main_stream) for o in obj] - else: - return obj - - -def async_copy_to(obj, dev, main_stream=None): - if torch.is_tensor(obj): - obj = Variable(obj) - if isinstance(obj, Variable): - target = torch.empty_like(obj, device=dev).copy_(obj) - if main_stream is not None: - target.data.record_stream(main_stream) - return target - elif isinstance(obj, collections.Mapping): - return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} - elif isinstance(obj, collections.Sequence): - return [async_copy_to(o, dev, main_stream) for o in obj] - - class FP16_DeepSpeedZeroOptimizer(object): """ DeepSpeedZeroOptimizer designed to reduce the memory footprint @@ -159,11 +128,13 @@ def __init__(self, clip_grad=0.0, allreduce_always_fp32=False, postscale_gradients=True, - gradient_predivide_factor=1.0): + gradient_predivide_factor=1.0, + 
gradient_accumulation_steps=1): if dist.get_rank() == 0: logger.info(f"Reduce bucket size {reduce_bucket_size}") logger.info(f"Allgather bucket size {allgather_bucket_size}") + logger.info(f"CPU Offload: {cpu_offload}") # The fused optimizer does all the work. We need this layer for two reason: # 1. maintain same user API from apex.fp16_utils # 2. keep common stuff here in case we need to add ne552w fused optimizer later @@ -184,6 +155,7 @@ def __init__(self, self.overlap_comm = overlap_comm self.cpu_offload = cpu_offload + self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu' self.dp_process_group = dp_process_group @@ -201,6 +173,8 @@ def __init__(self, self.allreduce_always_fp32 = allreduce_always_fp32 self.gradient_predivide_factor = gradient_predivide_factor self.postscale_gradients = postscale_gradients + self.gradient_accumulation_steps = gradient_accumulation_steps + self.micro_step_id = 0 if self.reduce_scatter: assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled" @@ -219,8 +193,7 @@ def __init__(self, #a single 32-bit partition of the parallel partitioned parameters #that this process will update self.single_partition_of_fp32_groups = [] - if cpu_offload: - self.averaged_gradients_on_cpu = {} + #param partition info #These are the parameters in each group that will not be updated by this process directly @@ -243,7 +216,6 @@ def __init__(self, # padding on each partition for alignment purposes self.groups_padding = [] - # loop to deal with groups for i, param_group in enumerate(self.optimizer.param_groups): # push this group to list before modify @@ -289,24 +261,13 @@ def __init__(self, self.parallel_partitioned_fp16_groups.append(data_parallel_partitions) # a partition of the fp32 master weights that will be updated by this process - if self.cpu_offload: - self.single_partition_of_fp32_groups.append( - async_copy_to(self.parallel_partitioned_fp16_groups[i][partition_id], - 'cpu').float()) - self.averaged_gradients_on_cpu[i] = [ - torch.empty_like( - self.parallel_partitioned_fp16_groups[i][partition_id], - device='cpu') - ] - else: - self.single_partition_of_fp32_groups.append( - self.parallel_partitioned_fp16_groups[i] - [partition_id].clone().float().detach()) + self.single_partition_of_fp32_groups.append( + self.parallel_partitioned_fp16_groups[i] + [partition_id].clone().float().detach().to(self.device)) # modify optimizer of have flat master weight self.single_partition_of_fp32_groups[ i].requires_grad = True # keep this in case internal optimizer uses it - param_group['params'] = [self.single_partition_of_fp32_groups[i]] partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size( @@ -340,9 +301,13 @@ def __init__(self, self._release_ipg_buffers() self.previous_reduced_grads = None + + + #simplified param id self.param_id = {} + largest_param_numel = 0 count = 0 for i, params_group in enumerate(self.fp16_groups): for param in params_group: @@ -350,6 +315,8 @@ def __init__(self, self.param_id[unique_id] = count self.param_dict[count] = param self.params_already_reduced.append(False) + if param.numel() > largest_param_numel: + largest_param_numel = param.numel() count = count + 1 for param_group in self.params_in_partition: @@ -360,6 +327,18 @@ def __init__(self, for param in param_group: self.is_param_in_current_partition[self.get_param_id(param)] = False + if self.cpu_offload: + self.accumulated_grads_in_cpu={} + self.norm_for_param_grads={} + self.local_overflow=False + 
self.grad_position = {} + self.temp_grad_buffer_for_cpu_offload = torch.zeros(largest_param_numel).half().pin_memory() + self.temp_grad_buffer_for_gpu_offload = torch.zeros(largest_param_numel, device=torch.cuda.current_device()).half() + + for i, params_group in enumerate(self.fp16_groups): + self.get_grad_position(i,self.params_in_partition[i],self.first_offset[i], self.partition_size[i]) + + #mapping from parameter to partition that it belongs to self.param_to_partition_ids = {} @@ -429,22 +408,17 @@ def _release_ipg_buffers(self): def initialize_optimizer_states(self): for i, group in enumerate(self.fp16_groups): - if self.cpu_offload: - single_grad_partition = torch.zeros( - int(self.partition_size[i]), - dtype=self.single_partition_of_fp32_groups[i].dtype, - device='cpu') - else: - single_grad_partition = torch.zeros( - int(self.partition_size[i]), - dtype=self.single_partition_of_fp32_groups[i].dtype, - device=torch.cuda.current_device()) - self.single_partition_of_fp32_groups[i].grad = single_grad_partition + single_grad_partition = torch.zeros( + int(self.partition_size[i]), + dtype=self.single_partition_of_fp32_groups[i].dtype, + device=self.device) + self.single_partition_of_fp32_groups[i].grad = single_grad_partition.pin_memory() if self.cpu_offload else single_grad_partition self.optimizer.step() - for group in self.single_partition_of_fp32_groups: - group.grad = None + if not self.cpu_offload: + for group in self.single_partition_of_fp32_groups: + group.grad = None return @@ -502,16 +476,30 @@ def independent_gradient_partition_epilogue(self): if self.cpu_offload is False: for i, _ in enumerate(self.fp16_groups): - self.averaged_gradients[i] = self.get_flat_partition( - self.params_in_partition[i], - self.first_offset[i], - self.partition_size[i], - dtype=torch.half, - device=torch.cuda.current_device(), - return_tensor_list=True) + if not i in self.averaged_gradients or self.averaged_gradients[i] is None: + self.averaged_gradients[i] = self.get_flat_partition( + self.params_in_partition[i], + self.first_offset[i], + self.partition_size[i], + dtype=torch.half, + device=torch.cuda.current_device(), + return_tensor_list=True) + else: + avg_new = self.get_flat_partition(self.params_in_partition[i], + self.first_offset[i], + self.partition_size[i], + dtype=torch.half, + device=torch.cuda.current_device(), + return_tensor_list=True) + for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i],avg_new): + accumulated_grad.add_(new_avg_grad) self._release_ipg_buffers() + # No need to keep the gradients anymore. 
+ # All gradients required by the step + # are in self.averaged_gradients + self.zero_grad() see_memory_usage(f"End ipg_epilogue") # resets all partition to no reduced @@ -617,7 +605,6 @@ def report_ipg_memory_usage(self, tag, param_elems): ###############Idependent Partition Gradient ######################## def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): - if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size: self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel()) @@ -642,15 +629,7 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): self.elements_in_ipg_bucket, param.numel()) new_grad_tensor.copy_(param.grad.view(-1)) - #jie: - if self.cpu_offload: - with stream(self.migration_stream): - averaged_gradients_on_cpu[i][param_id] = async_copy_to( - new_grad_tensor.data.view_as(param.grad), - 'cpu', - self.cpu_computation_stream) - else: - param.grad.data = new_grad_tensor.data.view_as(param.grad) + param.grad.data = new_grad_tensor.data.view_as(param.grad) self.elements_in_ipg_bucket += param.numel() self.grads_in_ipg_bucket.append(param.grad) @@ -752,8 +731,161 @@ def average_tensor(self, tensor): for handle in async_handles: handle.wait() + ############################################################################## + ############################# CPU Offload Methods############################# + ############################################################################## + def get_grad_position(self, group_id, tensor_list, first_offset, partition_size): + current_offset = 0 + + for i, tensor in enumerate(tensor_list): + param_id = self.get_param_id(tensor) + param_start_offset = 0 + + num_elements = tensor.numel() + tensor_offset = 0 + + #we need to offset to get to the right element + if i == 0 and first_offset > 0: + tensor_offset = first_offset + num_elements = num_elements - tensor_offset + param_start_offset = first_offset + + #we dont need all elements of the tensor + if num_elements > (partition_size - current_offset): + num_elements = partition_size - current_offset + + self.grad_position[param_id] = [int(group_id), int(param_start_offset), int(current_offset), int(num_elements)] + current_offset += num_elements + + + def update_overflow_tracker_for_param_grad(self, param): + if param.grad is not None and self._has_inf_or_nan(param.grad.data): + self.local_overflow = True + + def async_accumulate_grad_in_cpu(self, param): + param_id = self.get_param_id(param) + + #copy to a preexisiting buffer to avoid memory allocation penalty + dest_buffer = self.temp_grad_buffer_for_cpu_offload.view(-1).narrow(0, 0, param.numel()) + dest_buffer.copy_(param.grad.view(-1), non_blocking=True) + + if param_id not in self.accumulated_grads_in_cpu: + self.accumulated_grads_in_cpu[param_id] = torch.zeros(param.numel(),dtype=param.dtype).pin_memory() + + self.accumulated_grads_in_cpu[param_id].add_(dest_buffer) + + def async_accumulate_grad_in_cpu_via_gpu(self, param): + param_id = self.get_param_id(param) + + #copy to a preexisiting buffer to avoid memory allocation penalty + dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(0, 0, param.numel()) + + if param_id not in self.accumulated_grads_in_cpu: + self.accumulated_grads_in_cpu[param_id] = torch.zeros(param.numel(),dtype=param.dtype).pin_memory() + + if self.micro_step_id > 0: + dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1), non_blocking=True) + param.grad.data.view(-1).add_(dest_buffer) + + #at the boundary we 
will send 32bit directly + if not self.is_gradient_accumulation_boundary: + self.accumulated_grads_in_cpu[param_id].data.copy_(param.grad.data.view(-1), non_blocking=True) + + + def set_norm_for_param_grad(self, param): + param_id = self.get_param_id(param) + accumulated_grad = self.accumulated_grads_in_cpu[param_id] if self.gradient_accumulation_steps > 1 else param.grad + + [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] + + start = source_offset + accumulated_grad = accumulated_grad.view(-1).narrow(0, start, num_elements) + + self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(2) + + def set_norm_for_param_grad_in_gpu(self, param): + param_id = self.get_param_id(param) + accumulated_grad = param.grad + + [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] + + start = source_offset + accumulated_grad = accumulated_grad.view(-1).narrow(0, start, num_elements) + + self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(2) + + def async_inplace_copy_grad_to_fp32_buffer(self, param): + param_id = self.get_param_id(param) + + [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] + + dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements) + if self.gradient_accumulation_steps > 1: + src_tensor = self.accumulated_grads_in_cpu[param_id].view(-1).narrow(0, source_offset, num_elements) + else: + src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements).float() + dest_tensor.copy_(src_tensor,non_blocking=True) + + def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param): + param_id = self.get_param_id(param) + + [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] + + dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements) + + src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements).float() + dest_tensor.copy_(src_tensor,non_blocking=True) + + def complete_grad_norm_calculation_for_cpu_offload(self, params): + total_norm = 0.0 + norm_type = 2.0 + for p in params: + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): + param_id = self.get_param_id(p) + param_norm = self.norm_for_param_grads[param_id] + total_norm += param_norm.item()**2 + + # Sum across all model parallel GPUs. + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_process_group) + + self._model_parallel_all_reduce(tensor=total_norm_cuda, + op=torch.distributed.ReduceOp.SUM) + + total_norm = total_norm_cuda[0].item()**(1. 
/ norm_type) + + if total_norm == float( + 'inf') or total_norm == -float('inf') or total_norm != total_norm: + total_norm = -1 + + return total_norm + + ############################################################################################ def copy_grads_in_partition(self, param): + if self.cpu_offload: + #print(f"GAS: {self.gradient_accumulation_steps}") + #print(f"GAS: {self.is_gradient_accumulation_boundary}") + #with torch.cuda.stream(torch.cuda.current_stream()): + + self.update_overflow_tracker_for_param_grad(param) + + if self.gradient_accumulation_steps > 1: + self.async_accumulate_grad_in_cpu_via_gpu(param) + + if self.is_gradient_accumulation_boundary: + self.set_norm_for_param_grad_in_gpu(param) + self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param) + + #new_grad_tensor = async_copy_to(param.grad.view(-1), + # 'cpu', + # self.cpu_computation_stream) + #param.grad.data = new_grad_tensor.data.view_as(param.grad) + return + if self.grads_in_partition is None: self.grads_in_partition_offset = 0 total_size = 0 @@ -762,30 +894,16 @@ def copy_grads_in_partition(self, param): total_size += param_in_partition.numel() see_memory_usage(f"before copying {total_size} gradients into partition") - #jie: - ''' - if self.cpu_offload: - self.grads_in_partition = torch.empty(int(total_size), - dtype=torch.half, - device=torch.cuda.current_device()) - '' - else: - ''' self.grads_in_partition = torch.empty(int(total_size), dtype=torch.half, device=torch.cuda.current_device()) see_memory_usage(f"after copying {total_size} gradients into partition") #The allreduce buffer will be rewritted. Copy the gradients in partition to a new buffer - new_grad_tensor = self.grads_in_partition.narrow(0, + new_grad_tensor = self.grads_in_partition.view(-1).narrow(0, self.grads_in_partition_offset, param.numel()) new_grad_tensor.copy_(param.grad.view(-1)) - ''' - if self.cpu_offload: - averaged_gradients_on_cpu[self.get_param_id()] = new_grad_tensor.data.view_as(param.grad) - else: - ''' param.grad.data = new_grad_tensor.data.view_as(param.grad) self.grads_in_partition_offset += param.numel() @@ -1166,10 +1284,21 @@ def free_grad_in_param_list(self, param_list): for p in param_list: p.grad = None + def reset_cpu_buffers(self): + self.norm_for_param_grads = {} + with torch.cuda.stream(self.migration_stream): + for key, value in self.accumulated_grads_in_cpu.items(): + value.mul_(0.0) + def step(self, closure=None): """ Not supporting closure. 
""" + self.micro_step_id = 0 + + if self.cpu_offload: + torch.cuda.current_stream().wait_stream(self.migration_stream) + see_memory_usage(f"In step before checking overflow") # First compute norm for all group so we know if there is overflow @@ -1182,6 +1311,9 @@ def step(self, closure=None): if self.overflow: see_memory_usage('After overflow before clearing gradients') self.zero_grad() + if self.cpu_offload: + self.reset_cpu_buffers() + see_memory_usage('After overflow after clearing gradients') logger.info( @@ -1200,64 +1332,56 @@ def step(self, closure=None): skip = False partition_id = dist.get_rank(group=self.dp_process_group) for i, group in enumerate(self.fp16_groups): - if self.cpu_offload: - self.averaged_gradients[i] = self.averaged_gradients_on_cpu[i] - #norm_groups.append( - # self.get_grad_norm_direct(self.averaged_gradients[i], - # self.params_in_partition[i])) - norm_groups.append( - self.get_grad_norm_direct(self.averaged_gradients[i], - self.params_in_partition[i])) - - #free gradients for all the prameters that are not updated by this process - self.free_grad_in_param_list(self.params_not_in_partition[i]) - - #create a flat gradients for parameters updated by this process - # If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors - if partition_id == dist.get_world_size(group=self.dp_process_group) - 1: - single_grad_partition = flatten_dense_tensors_aligned( - self.averaged_gradients[i], - int(self.partition_size[i])).to( - self.single_partition_of_fp32_groups[i].dtype) + norm_groups.append( + self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i])) + single_grad_partition = self.single_partition_of_fp32_groups[i].grad else: - single_grad_partition = _flatten_dense_tensors( - self.averaged_gradients[i]).to( - self.single_partition_of_fp32_groups[i].dtype) - assert single_grad_partition.numel() == self.partition_size[i], \ - "averaged gradients have different number of elements that partition size {} {} {} {}".format(single_grad_partition.numel(), self.partition_size[i], i, partition_id) + norm_groups.append( + self.get_grad_norm_direct(self.averaged_gradients[i], + self.params_in_partition[i])) + + #free gradients for all the prameters that are not updated by this process + self.free_grad_in_param_list(self.params_not_in_partition[i]) + + #create a flat gradients for parameters updated by this process + # If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors + if partition_id == dist.get_world_size(group=self.dp_process_group) - 1: + single_grad_partition = flatten_dense_tensors_aligned( + self.averaged_gradients[i], + int(self.partition_size[i])).to( + self.single_partition_of_fp32_groups[i].dtype) + else: + single_grad_partition = _flatten_dense_tensors( + self.averaged_gradients[i]).to( + self.single_partition_of_fp32_groups[i].dtype) + assert single_grad_partition.numel() == self.partition_size[i], \ + "averaged gradients have different number of elements that partition size {} {} {} {}".format(single_grad_partition.numel(), self.partition_size[i], i, partition_id) - self.single_partition_of_fp32_groups[i].grad = single_grad_partition - #release all the gradient since we have already created a necessary copy in dp_grad_partition - self.free_grad_in_param_list(self.params_in_partition[i]) + self.single_partition_of_fp32_groups[i].grad = single_grad_partition + #release all the gradient since we have already created a necessary copy in 
dp_grad_partition + self.free_grad_in_param_list(self.params_in_partition[i]) - if self.cpu_offload is False: self.averaged_gradients[i] = None single_partition_grad_groups.append(single_grad_partition) self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups) - - self.cpu_computation_stream.wait_stream(self.migration_stream) + #torch.set_num_threads(12) timers('optimizer_step').start() - with torch.cuda.stream(self.cpu_computation_stream): - self.optimizer.step() + self.optimizer.step() + #get rid of the fp32 gradients. Not needed anymore + if not self.cpu_offload: + for group in self.single_partition_of_fp32_groups: + group.grad = None - if self.cpu_offload: - stream = torch.cuda.current_stream() - with torch.cuda.stream(self.migration_stream): - for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups): - fp16_partitions[partition_id] = async_copy_to( - fp32_partition, - torch.cuda.current_device(), - stream) - #for averaged_gradients_cpu, fp32_partition in zip(self.averaged_gradients_on_cpu, self.single_partition_of_fp32_groups): - # averaged_gradients_cpu = [fp32_partition] - else: - for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups): - fp16_partitions[partition_id].data.copy_(fp32_partition.data) + for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups): + fp16_partitions[partition_id].data.copy_(fp32_partition.data) timers('optimizer_step').stop() + if self.cpu_offload: + self.reset_cpu_buffers() + timers('optimizer_allgather').start() #gather the updated weights from everyone for group_id, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups): @@ -1338,8 +1462,6 @@ def has_overflow_serial(self, params, is_grad_list=False): def has_overflow_partitioned_grads_serial(self): for i in range(len(self.fp16_groups)): - if self.cpu_offload: - self.averaged_gradients[i] = self.averaged_gradients_on_cpu[i] for j, grad in enumerate(self.averaged_gradients[i]): if grad is not None and self._has_inf_or_nan(grad.data, j): return True @@ -1347,7 +1469,7 @@ def has_overflow_partitioned_grads_serial(self): def has_overflow(self, partition_gradients=True): if partition_gradients: - overflow = self.has_overflow_partitioned_grads_serial() + overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial() overflow_gpu = torch.cuda.ByteTensor([overflow]) torch.distributed.all_reduce(overflow_gpu, op=torch.distributed.ReduceOp.MAX, @@ -1400,22 +1522,26 @@ def backward(self, loss, retain_graph=False): 2. scaled_loss = fp32_loss*loss_scale 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves """ + if self.cpu_offload: + torch.cuda.current_stream().wait_stream(self.migration_stream) + if self.contiguous_gradients: self.ipg_buffer = [] - buf_0 = torch.empty(self.reduce_bucket_size, + buf_0 = torch.empty(int(self.reduce_bucket_size*4.5), dtype=torch.half, device=torch.cuda.current_device()) self.ipg_buffer.append(buf_0) # Use double buffers to avoid data access conflict when overlap_comm is enabled. 
if self.overlap_comm: - buf_1 = torch.empty(self.reduce_bucket_size, + buf_1 = torch.empty(int(self.reduce_bucket_size*4.5), dtype=torch.half, device=torch.cuda.current_device()) self.ipg_buffer.append(buf_1) self.ipg_index = 0 - torch.cuda.empty_cache() + self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + self.micro_step_id += 1 def check_overflow(self, partition_gradients=True): self._check_overflow(partition_gradients) @@ -1466,12 +1592,11 @@ def _get_groups_without_padding(self, groups_with_padding): def _get_state_without_padding(self, state_with_padding, padding): lean_state = {} for key, value in state_with_padding.items(): - #jie: torch.optim.Adam() has "step" has a key in state_dict - if key == "step": - lean_state[key] = value - else: + if torch.is_tensor(value): lean_length = value.numel() - padding lean_state[key] = value[:lean_length] + else: + lean_state[key] = value return lean_state @@ -1513,12 +1638,11 @@ def state_dict(self): self.single_partition_of_fp32_groups) state_dict['single_partition_of_fp32_groups'] = fp32_groups_without_padding - if self.cpu_offload: - state_dict_tmp = async_copy_to(state_dict, - 'cpu', - torch.cuda.current_stream()) - state_dict = None - state_dict = state_dict_tmp +# if self.cpu_offload: +# state_dict_tmp = async_copy_to(state_dict, +# 'cpu', +# torch.cuda.current_stream()) +# state_dict = state_dict_tmp return state_dict @@ -1556,11 +1680,15 @@ def refresh_fp32_params(self): def _partition_base_optimizer_state(self, state_key, all_partition_states): partition_id = dist.get_rank(group=self.dp_process_group) alignment = dist.get_world_size(group=self.dp_process_group) - flat_merged_partitions = flatten_dense_tensors_aligned( - all_partition_states, - alignment) - dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions) - return dp_partitions[partition_id] + if torch.is_tensor(all_partition_states[0]): + flat_merged_partitions = flatten_dense_tensors_aligned( + all_partition_states, + alignment) + dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions) + return dp_partitions[partition_id] + else: + # Assume non-tensor states are not partitioned and equal across ranks, so return first one + return all_partition_states[0] # Restore base optimizer state from checkpoint by # 1) Merging optimizer state from checkpoints of all partitions @@ -1585,8 +1713,10 @@ def _restore_base_optimizer_state(self, all_state_dict): for i, group in enumerate(self.optimizer.param_groups): p = group['params'][0] for key, saved in base_optimizer_group_states[i].items(): - current = self.optimizer.state[p][key] - current.data.copy_(saved.data) + if torch.is_tensor(self.optimizer.state[p][key]): + self.optimizer.state[p][key].data.copy_(saved.data) + else: + self.optimizer.state[p][key] = saved def load_state_dict(self, state_dict_list, From 2c71e4995ab9a9817f0c894caccd15309df987d5 Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Thu, 3 Sep 2020 20:57:07 +0000 Subject: [PATCH 4/4] Formatting fixes --- deepspeed/runtime/engine.py | 9 +-- deepspeed/runtime/zero/stage2.py | 123 ++++++++++++++++++++----------- 2 files changed, 82 insertions(+), 50 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index c8f836ead882..f77fc64de7d2 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -9,7 +9,6 @@ from apex import amp from torch.nn.modules import Module from torch.distributed.distributed_c10d import _get_global_rank - from tensorboardX import SummaryWriter from 
deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer @@ -311,7 +310,6 @@ def zero_contiguous_gradients(self): def zero_load_from_fp32_weights(self): return self._config.zero_config.load_from_fp32_weights - def fp16_enabled(self): return self._config.fp16_enabled @@ -601,8 +599,6 @@ def _configure_zero_optimizer(self, optimizer): dp_process_group=self.data_parallel_group, mpu=self.mpu) elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS: - #assert self.gradient_accumulation_steps( - #) == 1, "ZeRO stage 2 does not support gradient accumulation, if you need gradient accumulation please use stage 1" optimizer = FP16_DeepSpeedZeroOptimizer( optimizer, timers=self.timers, @@ -729,7 +725,7 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): #Zero stage 2 communicates during non gradient accumulation boundaries as well if self.zero_optimization_partition_gradients(): self.optimizer.overlapping_partition_gradients_reduce_epilogue() - + #Communicate only at gradient accumulation boundaries elif self.is_gradient_accumulation_boundary(): if self.zero_optimization_stage() == ZERO_OPTIMIZATION_OPTIMIZER_STATES: @@ -781,7 +777,8 @@ def backward(self, loss, allreduce_gradients=True): self.timers('backward_inner').start() if self.zero_optimization(): - self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary() + self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary( + ) self.optimizer.backward(loss) elif self.amp_enabled(): # AMP requires delaying unscale when inside gradient accumulation boundaries diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py index d4fe4b436cdf..69cdcf34adf1 100755 --- a/deepspeed/runtime/zero/stage2.py +++ b/deepspeed/runtime/zero/stage2.py @@ -20,7 +20,6 @@ #with gradient partitioning and without pg_correctness_test = False - try: from apex_C import flatten from apex_C import unflatten @@ -99,6 +98,7 @@ def move_to_cpu(tensor_list): def print_rank_msg(msg): print(f"rank {dist.get_rank()} - {msg}") + class FP16_DeepSpeedZeroOptimizer(object): """ DeepSpeedZeroOptimizer designed to reduce the memory footprint @@ -301,9 +301,6 @@ def __init__(self, self._release_ipg_buffers() self.previous_reduced_grads = None - - - #simplified param id self.param_id = {} @@ -328,16 +325,21 @@ def __init__(self, self.is_param_in_current_partition[self.get_param_id(param)] = False if self.cpu_offload: - self.accumulated_grads_in_cpu={} - self.norm_for_param_grads={} - self.local_overflow=False + self.accumulated_grads_in_cpu = {} + self.norm_for_param_grads = {} + self.local_overflow = False self.grad_position = {} - self.temp_grad_buffer_for_cpu_offload = torch.zeros(largest_param_numel).half().pin_memory() - self.temp_grad_buffer_for_gpu_offload = torch.zeros(largest_param_numel, device=torch.cuda.current_device()).half() + self.temp_grad_buffer_for_cpu_offload = torch.zeros( + largest_param_numel).half().pin_memory() + self.temp_grad_buffer_for_gpu_offload = torch.zeros( + largest_param_numel, + device=torch.cuda.current_device()).half() for i, params_group in enumerate(self.fp16_groups): - self.get_grad_position(i,self.params_in_partition[i],self.first_offset[i], self.partition_size[i]) - + self.get_grad_position(i, + self.params_in_partition[i], + self.first_offset[i], + self.partition_size[i]) #mapping from parameter to partition that it belongs to self.param_to_partition_ids = {} @@ -412,7 +414,9 @@ def initialize_optimizer_states(self): int(self.partition_size[i]), 
dtype=self.single_partition_of_fp32_groups[i].dtype, device=self.device) - self.single_partition_of_fp32_groups[i].grad = single_grad_partition.pin_memory() if self.cpu_offload else single_grad_partition + self.single_partition_of_fp32_groups[ + i].grad = single_grad_partition.pin_memory( + ) if self.cpu_offload else single_grad_partition self.optimizer.step() @@ -731,6 +735,7 @@ def average_tensor(self, tensor): for handle in async_handles: handle.wait() + ############################################################################## ############################# CPU Offload Methods############################# ############################################################################## @@ -754,9 +759,13 @@ def get_grad_position(self, group_id, tensor_list, first_offset, partition_size) if num_elements > (partition_size - current_offset): num_elements = partition_size - current_offset - self.grad_position[param_id] = [int(group_id), int(param_start_offset), int(current_offset), int(num_elements)] - current_offset += num_elements - + self.grad_position[param_id] = [ + int(group_id), + int(param_start_offset), + int(current_offset), + int(num_elements) + ] + current_offset += num_elements def update_overflow_tracker_for_param_grad(self, param): if param.grad is not None and self._has_inf_or_nan(param.grad.data): @@ -766,11 +775,16 @@ def async_accumulate_grad_in_cpu(self, param): param_id = self.get_param_id(param) #copy to a preexisiting buffer to avoid memory allocation penalty - dest_buffer = self.temp_grad_buffer_for_cpu_offload.view(-1).narrow(0, 0, param.numel()) + dest_buffer = self.temp_grad_buffer_for_cpu_offload.view(-1).narrow( + 0, + 0, + param.numel()) dest_buffer.copy_(param.grad.view(-1), non_blocking=True) if param_id not in self.accumulated_grads_in_cpu: - self.accumulated_grads_in_cpu[param_id] = torch.zeros(param.numel(),dtype=param.dtype).pin_memory() + self.accumulated_grads_in_cpu[param_id] = torch.zeros( + param.numel(), + dtype=param.dtype).pin_memory() self.accumulated_grads_in_cpu[param_id].add_(dest_buffer) @@ -778,23 +792,30 @@ def async_accumulate_grad_in_cpu_via_gpu(self, param): param_id = self.get_param_id(param) #copy to a preexisiting buffer to avoid memory allocation penalty - dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(0, 0, param.numel()) + dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow( + 0, + 0, + param.numel()) if param_id not in self.accumulated_grads_in_cpu: - self.accumulated_grads_in_cpu[param_id] = torch.zeros(param.numel(),dtype=param.dtype).pin_memory() + self.accumulated_grads_in_cpu[param_id] = torch.zeros( + param.numel(), + dtype=param.dtype).pin_memory() if self.micro_step_id > 0: - dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1), non_blocking=True) + dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1), + non_blocking=True) param.grad.data.view(-1).add_(dest_buffer) #at the boundary we will send 32bit directly if not self.is_gradient_accumulation_boundary: - self.accumulated_grads_in_cpu[param_id].data.copy_(param.grad.data.view(-1), non_blocking=True) - + self.accumulated_grads_in_cpu[param_id].data.copy_(param.grad.data.view(-1), + non_blocking=True) def set_norm_for_param_grad(self, param): param_id = self.get_param_id(param) - accumulated_grad = self.accumulated_grads_in_cpu[param_id] if self.gradient_accumulation_steps > 1 else param.grad + accumulated_grad = self.accumulated_grads_in_cpu[ + param_id] if self.gradient_accumulation_steps > 1 else 
param.grad [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] @@ -819,22 +840,33 @@ def async_inplace_copy_grad_to_fp32_buffer(self, param): [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] - dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements) + dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow( + 0, + dest_offset, + num_elements) if self.gradient_accumulation_steps > 1: - src_tensor = self.accumulated_grads_in_cpu[param_id].view(-1).narrow(0, source_offset, num_elements) + src_tensor = self.accumulated_grads_in_cpu[param_id].view(-1).narrow( + 0, + source_offset, + num_elements) else: - src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements).float() - dest_tensor.copy_(src_tensor,non_blocking=True) + src_tensor = param.grad.view(-1).narrow(0, + source_offset, + num_elements).float() + dest_tensor.copy_(src_tensor, non_blocking=True) def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param): param_id = self.get_param_id(param) [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id] - dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements) + dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow( + 0, + dest_offset, + num_elements) src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements).float() - dest_tensor.copy_(src_tensor,non_blocking=True) + dest_tensor.copy_(src_tensor, non_blocking=True) def complete_grad_norm_calculation_for_cpu_offload(self, params): total_norm = 0.0 @@ -849,8 +881,8 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.SUM, - group=self.dp_process_group) + op=torch.distributed.ReduceOp.SUM, + group=self.dp_process_group) self._model_parallel_all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM) @@ -900,9 +932,10 @@ def copy_grads_in_partition(self, param): see_memory_usage(f"after copying {total_size} gradients into partition") #The allreduce buffer will be rewritted. 
Copy the gradients in partition to a new buffer - new_grad_tensor = self.grads_in_partition.view(-1).narrow(0, - self.grads_in_partition_offset, - param.numel()) + new_grad_tensor = self.grads_in_partition.view(-1).narrow( + 0, + self.grads_in_partition_offset, + param.numel()) new_grad_tensor.copy_(param.grad.view(-1)) param.grad.data = new_grad_tensor.data.view_as(param.grad) self.grads_in_partition_offset += param.numel() @@ -1334,12 +1367,13 @@ def step(self, closure=None): for i, group in enumerate(self.fp16_groups): if self.cpu_offload: norm_groups.append( - self.complete_grad_norm_calculation_for_cpu_offload(self.params_in_partition[i])) + self.complete_grad_norm_calculation_for_cpu_offload( + self.params_in_partition[i])) single_grad_partition = self.single_partition_of_fp32_groups[i].grad else: norm_groups.append( self.get_grad_norm_direct(self.averaged_gradients[i], - self.params_in_partition[i])) + self.params_in_partition[i])) #free gradients for all the prameters that are not updated by this process self.free_grad_in_param_list(self.params_not_in_partition[i]) @@ -1469,7 +1503,8 @@ def has_overflow_partitioned_grads_serial(self): def has_overflow(self, partition_gradients=True): if partition_gradients: - overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial() + overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial( + ) overflow_gpu = torch.cuda.ByteTensor([overflow]) torch.distributed.all_reduce(overflow_gpu, op=torch.distributed.ReduceOp.MAX, @@ -1527,14 +1562,14 @@ def backward(self, loss, retain_graph=False): if self.contiguous_gradients: self.ipg_buffer = [] - buf_0 = torch.empty(int(self.reduce_bucket_size*4.5), + buf_0 = torch.empty(int(self.reduce_bucket_size * 4.5), dtype=torch.half, device=torch.cuda.current_device()) self.ipg_buffer.append(buf_0) # Use double buffers to avoid data access conflict when overlap_comm is enabled. if self.overlap_comm: - buf_1 = torch.empty(int(self.reduce_bucket_size*4.5), + buf_1 = torch.empty(int(self.reduce_bucket_size * 4.5), dtype=torch.half, device=torch.cuda.current_device()) self.ipg_buffer.append(buf_1) @@ -1638,11 +1673,11 @@ def state_dict(self): self.single_partition_of_fp32_groups) state_dict['single_partition_of_fp32_groups'] = fp32_groups_without_padding -# if self.cpu_offload: -# state_dict_tmp = async_copy_to(state_dict, -# 'cpu', -# torch.cuda.current_stream()) -# state_dict = state_dict_tmp + # if self.cpu_offload: + # state_dict_tmp = async_copy_to(state_dict, + # 'cpu', + # torch.cuda.current_stream()) + # state_dict = state_dict_tmp return state_dict