From 64157382adf5faa7d32a5b39c9dc1ff5b62c20b8 Mon Sep 17 00:00:00 2001
From: jren73
Date: Mon, 17 Aug 2020 21:28:49 +0000
Subject: [PATCH] deleted:  deepspeed_cpu_adam.py
 modified: deepspeed_light.py
 modified: deepspeed_zero_optimizer.py
 ../../deepspeed_zero_optimizer_cpu_offload.py

---
 deepspeed/pt/deepspeed_cpu_adam.py       | 200 ------------
 deepspeed/pt/deepspeed_light.py          |  19 +-
 deepspeed/pt/deepspeed_zero_optimizer.py | 392 ++++++-----------------
 deepspeed/pt/deepspeed_zero_utils.py     |  32 ++
 4 files changed, 149 insertions(+), 494 deletions(-)
 delete mode 100755 deepspeed/pt/deepspeed_cpu_adam.py
 create mode 100755 deepspeed/pt/deepspeed_zero_utils.py

diff --git a/deepspeed/pt/deepspeed_cpu_adam.py b/deepspeed/pt/deepspeed_cpu_adam.py
deleted file mode 100755
index 9c910889d467..000000000000
--- a/deepspeed/pt/deepspeed_cpu_adam.py
+++ /dev/null
@@ -1,200 +0,0 @@
-import math
-import torch
-
-
-class CPUAdam(torch.optim.Optimizer):
-    r"""Implements Adam algorithm.
-
-    It has been proposed in `Adam: A Method for Stochastic Optimization`_.
-
-    Arguments:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups
-        lr (float, optional): learning rate (default: 1e-3)
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its square (default: (0.9, 0.999))
-        eps (float, optional): term added to the denominator to improve
-            numerical stability (default: 1e-8)
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
-            algorithm from the paper `On the Convergence of Adam and Beyond`_
-            (default: False)
-
-    .. _Adam\: A Method for Stochastic Optimization:
-        https://arxiv.org/abs/1412.6980
-    .. _On the Convergence of Adam and Beyond:
-        https://openreview.net/forum?id=ryQu7f-RZ
-    """
-    def __init__(self,
-                 params,
-                 lr=1e-3,
-                 betas=(0.9,
-                        0.999),
-                 eps=1e-8,
-                 weight_decay=0,
-                 amsgrad=False):
-        if not 0.0 <= lr:
-            raise ValueError("Invalid learning rate: {}".format(lr))
-        if not 0.0 <= eps:
-            raise ValueError("Invalid epsilon value: {}".format(eps))
-        if not 0.0 <= betas[0] < 1.0:
-            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
-        if not 0.0 <= betas[1] < 1.0:
-            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
-        defaults = dict(lr=lr,
-                        betas=betas,
-                        eps=eps,
-                        weight_decay=weight_decay,
-                        amsgrad=amsgrad)
-        super(CPUAdam, self).__init__(params, defaults)
-
-    def __setstate__(self, state):
-        super(CPUAdam, self).__setstate__(state)
-        for group in self.param_groups:
-            group.setdefault('amsgrad', False)
-
-    def step_with_cpuoffload(self,
-                             closure=None,
-                             fp32_params=None,
-                             fp32_params_grad=None,
-                             exp_avg=None,
-                             exp_avg_sq=None):
-        """Performs a single optimization step.
-
-        Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-            fp32_params: fp32 params on CPU
-            fp32_params_grad: the normolized gradients in fp32 groups
-            exp_avg: optimizer state
-            exp_avg_sq: optimizer state
-        """
-        loss = None
-        if closure is not None:
-            loss = closure()
-
-        if fp32_params is None:
-            raise RuntimeError('params is None')
-
-        index = 0
-
-        for group in self.param_groups:
-            group_size = sum([t.numel() for t in group['params']])
-            p = torch.zeros(group_size, device=torch.device('cpu'), requires_grad=True)
-            p = fp32_params[index:index + group_size].detach()
-            p_grad = torch.zeros(group_size, device=torch.device('cpu'))
-            p_grad = fp32_params_grad[index:index + group_size].detach()
-            p.grad = p_grad
-            if p.grad is None:
-                continue
-            grad = p.grad.data
-            if grad.is_sparse:
-                raise RuntimeError(
-                    'Adam does not support sparse gradients, please consider SparseAdam instead'
-                )
-            amsgrad = group['amsgrad']
-
-            state = self.state[p]
-
-            # State initialization
-            if len(state) == 0:
-                state['step'] = 0
-
-            beta1, beta2 = group['betas']
-
-            state['step'] += 1
-
-            if group['weight_decay'] != 0:
-                grad.add_(group['weight_decay'], p.data)
-
-            # Decay the first and second moment running average coefficient
-            exp_avg[index:index + group_size].mul_(beta1).add_(1 - beta1, grad)
-            exp_avg_sq[index:index + group_size].mul_(beta2).addcmul_(
-                1 - beta2,
-                grad,
-                grad)
-
-            denom = exp_avg_sq[index:index + group_size].sqrt().add_(group['eps'])
-
-            bias_correction1 = 1 - beta1**state['step']
-            bias_correction2 = 1 - beta2**state['step']
-            step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
-
-            p.data.addcdiv_(-step_size, exp_avg[index:index + group_size], denom)
-
-            index += group_size
-
-        return loss
-
-    def step(self, closure=None):
-        """Performs a single optimization step.
-
-        Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-        """
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
-        for group in self.param_groups:
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-                grad = p.grad
-                if grad.is_sparse:
-                    raise RuntimeError(
-                        'Adam does not support sparse gradients, please consider SparseAdam instead'
-                    )
-                amsgrad = group['amsgrad']
-
-                state = self.state[p]
-
-                # State initialization
-                if len(state) == 0:
-                    state['step'] = 0
-                    # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(
-                        p,
-                        memory_format=torch.preserve_format)
-                    # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(
-                        p,
-                        memory_format=torch.preserve_format)
-                    if amsgrad:
-                        # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_sq'] = torch.zeros_like(
-                            p,
-                            memory_format=torch.preserve_format)
-
-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-                if amsgrad:
-                    max_exp_avg_sq = state['max_exp_avg_sq']
-                beta1, beta2 = group['betas']
-
-                state['step'] += 1
-                bias_correction1 = 1 - beta1**state['step']
-                bias_correction2 = 1 - beta2**state['step']
-
-                if group['weight_decay'] != 0:
-                    grad = grad.add(p, alpha=group['weight_decay'])
-
-                # Decay the first and second moment running average coefficient
-                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
-                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-                if amsgrad:
-                    # Maintains the maximum of all 2nd moment running avg. till now
-                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
-                    # Use the max. for normalizing running avg. of gradient
-                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
-                        group['eps'])
-                else:
-                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
-                        group['eps'])
-
-                step_size = group['lr'] / bias_correction1
-
-                p.addcdiv_(exp_avg, denom, value=-step_size)
-
-        return loss
diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index 6539deef3c49..e367b9fc6ff3 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -21,7 +21,6 @@
 from deepspeed.pt.fp16_optimizer import FP16_Optimizer
 from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.pt.deepspeed_fused_lamb import FusedLamb
-from deepspeed.pt.deepspeed_cpu_adam import CPUAdam
 from deepspeed.pt.deepspeed_config import DeepSpeedConfig, \
     ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
 
@@ -170,6 +169,8 @@ def __init__(self,
         self.optimizer = None
         self.lr_scheduler = None
         if model_parameters or optimizer:
+            if "torch.optim" in self.optimizer_name():
+                self.zero_set_cpu_offload()
             self._configure_optimizer(optimizer, model_parameters)
             self._configure_lr_scheduler(lr_scheduler)
             self._report_progress(0)
@@ -263,10 +264,13 @@ def sparse_gradients_enabled(self):
     def train_batch_size(self):
         return self._config.train_batch_size
 
+
     def train_micro_batch_size_per_gpu(self):
         return self._config.train_micro_batch_size_per_gpu
 
     def optimizer_name(self):
+        print("===========================")
+        print(self._config.optimizer_name)
         return self._config.optimizer_name
 
     def optimizer_params(self):
@@ -296,6 +300,9 @@ def zero_overlap_comm(self):
     def zero_cpu_offload(self):
         return self._config.zero_config.cpu_offload
 
+    def zero_set_cpu_offload(self):
+        self._config.zero_config.cpu_offload = True
+
     def zero_optimization_stage(self):
         return self._config.zero_optimization_stage
 
@@ -498,9 +505,9 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
         #jie:
         if self.zero_cpu_offload():
             optimizer_parameters = self.optimizer_params()
-            basic_optimizer = CPUAdam(client_optimizer.param_groups,
+            basic_optimizer = torch.optim.Adam(client_optimizer.param_groups,
                                       **optimizer_parameters)
-            logger.info('Using CPU Optimizer as basic optimizer')
+            logger.info('Using CPU Optimizer as basic optimizer')
         elif client_optimizer is not None:
             basic_optimizer = client_optimizer
             logger.info('Using client Optimizer as basic optimizer')
@@ -542,9 +549,6 @@ def _configure_basic_optimizer(self, model_parameters):
                 "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
             )
         if self.optimizer_name() == ADAM_OPTIMIZER:
-            if self.zero_cpu_offload():
-                optimizer = CPUAdam(model_parameters, **optimizer_parameters)
-            else:
                 from apex.optimizers.fused_adam import FusedAdam
                 optimizer = FusedAdam(model_parameters, **optimizer_parameters)
         elif self.optimizer_name() == LAMB_OPTIMIZER:
@@ -958,6 +962,9 @@ def _get_optimizer_param(self, param_name):
     def get_lr(self):
         return self._get_optimizer_param('lr')
 
+    def get_type(self):
+        return self._get_optimizer_param('type')
+
     def get_mom(self):
         return self._get_optimizer_param('betas')
 
diff --git a/deepspeed/pt/deepspeed_zero_optimizer.py b/deepspeed/pt/deepspeed_zero_optimizer.py
index d926e57085c3..7cdc61936568 100755
--- a/deepspeed/pt/deepspeed_zero_optimizer.py
+++ b/deepspeed/pt/deepspeed_zero_optimizer.py
@@ -14,7 +14,6 @@
 from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler
 from deepspeed.pt.deepspeed_utils import see_memory_usage, is_model_parallel_parameter
 from deepspeed.pt.deepspeed_zero_config import ZERO_OPTIMIZATION_GRADIENTS
-
 #Toggle this to true to enable correctness test
 #with gradient partitioning and without
 pg_correctness_test = False
 
@@ -96,15 +95,10 @@ def move_to_cpu(tensor_list):
         tensor.data = tensor.data.cpu()
 
 
-def move_to_cuda(tensor_list, cuda_device):
-    for tensor in tensor_list:
-        tensor.data = tensor.data.to(cuda_device)
-
-
 def print_rank_msg(msg):
     print(f"rank {dist.get_rank()} - {msg}")
 
-
+#jie:asyn move to target device
 def async_migrate_to(obj, dev, main_stream=None):
     if torch.is_tensor(obj):
         obj = Variable(obj)
@@ -134,7 +128,6 @@ def async_copy_to(obj, dev, main_stream=None):
     elif isinstance(obj, collections.Sequence):
        return [async_copy_to(o, dev, main_stream) for o in obj]

-
 class FP16_DeepSpeedZeroOptimizer(object):
     """
     DeepSpeedZeroOptimizer designed to reduce the memory footprint
@@ -188,10 +181,10 @@ def __init__(self,
 
         self.overlap_comm = overlap_comm
 
-        self.dp_process_group = dp_process_group
-
         self.cpu_offload = cpu_offload
 
+        self.dp_process_group = dp_process_group
+
         self.partition_count = dist.get_world_size(group=self.dp_process_group)
 
         if mpu is None:
@@ -208,7 +201,6 @@ def __init__(self,
         self.postscale_gradients = postscale_gradients
 
         if self.reduce_scatter:
-            assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled"
             assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-2 with reduce scatter enabled"
             assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled"
 
@@ -222,26 +214,11 @@ def __init__(self,
         #each of which will be updated by a different process
         self.parallel_partitioned_fp16_groups = []
 
-        #jie:
         #a single 32-bit partition of the parallel partitioned parameters
         #that this process will update
-
-        self.total_params = 0
-        self.single_partition_of_fp32_groups_index = 0
-
-        if self.cpu_offload:
-            logger.info(f"cpu_offload enabled")
-            for i, param_group in enumerate(self.optimizer.param_groups):
-                self.total_params += sum([t.numel() for t in param_group['params']])
-            logger.info("Total # of params is {}".format(self.total_params))
-            self.single_partition_of_fp32_groups = torch.zeros(
-                [self.total_params + self.total_params % self.partition_count],
-                device=torch.device('cpu'),
-                requires_grad=True)
-
-        else:
-            logger.info(f"cpu_offload disabled")
-            self.single_partition_of_fp32_groups = []
+        self.single_partition_of_fp32_groups = []
+        if cpu_offload:
+            self.averaged_gradients_on_cpu = {}
 
         #param partition info
 
         #These are the parameters in each group that will not be updated by this process directly
 
@@ -264,6 +241,7 @@ def __init__(self,
 
         # padding on each partition for alignment purposes
         self.groups_padding = []
+        # loop to deal with groups
         for i, param_group in enumerate(self.optimizer.param_groups):
 
             # push this group to list before modify
@@ -308,27 +286,24 @@ def __init__(self,
                 self.fp16_groups_flat[i])
             self.parallel_partitioned_fp16_groups.append(data_parallel_partitions)
 
-            #that this process will update
+            # a partition of the fp32 master weights that will be updated by this process
             if self.cpu_offload:
-                start = self.single_partition_of_fp32_groups_index
-                self.single_partition_of_fp32_groups_index += self.parallel_partitioned_fp16_groups[
-                    i][partition_id].numel()
-                end = self.single_partition_of_fp32_groups_index
-                #self.single_partition_of_fp32_groups[start:end] = torch.tensor(self.parallel_partitioned_fp16_groups[i][partition_id].detach(),
-                #                                                               dtype=self.single_partition_of_fp32_groups[i].dtype)
-                self.single_partition_of_fp32_groups[
-                    start:end] = self.parallel_partitioned_fp16_groups[i][
-                        partition_id].detach()
-
+                self.single_partition_of_fp32_groups.append(
+                    async_copy_to(self.parallel_partitioned_fp16_groups[i]
+                                  [partition_id], 'cpu').float()
+                )
+                self.averaged_gradients_on_cpu[i] = [torch.empty_like(self.parallel_partitioned_fp16_groups[i][partition_id],
+                                                                      device='cpu')]
             else:
-                # a partition of the fp32 master weights that will be updated by this process
                 self.single_partition_of_fp32_groups.append(
                     self.parallel_partitioned_fp16_groups[i]
-                    [partition_id].clone().float())
-            # modify optimizer of have flat master weight
-            self.single_partition_of_fp32_groups[i].requires_grad_()
-            # keep this in case internal optimizer uses it
-            param_group['params'] = [self.single_partition_of_fp32_groups[i]]
+                    [partition_id].clone().float().detach())
+
+            # modify optimizer of have flat master weight
+            self.single_partition_of_fp32_groups[
+                i].requires_grad = True  # keep this in case internal optimizer uses it
+
+            param_group['params'] = [self.single_partition_of_fp32_groups[i]]
 
             partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(
                 group=self.dp_process_group)
@@ -344,13 +319,9 @@ def __init__(self,
         self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
         self.reduction_stream = torch.cuda.Stream()
-        self.callback_queued = False
-
-        self.cpu_computation_event = torch.cuda.Event(enable_timing=False,
-                                                      blocking=False)
         self.cpu_computation_stream = torch.cuda.Stream()
-        self.migration_stream = torch.cuda.Stream()
+        self.callback_queued = False
 
         self.param_dict = {}
 
@@ -435,15 +406,9 @@ def __init__(self,
         self.loss_scaler = LossScaler(scale=static_loss_scale)
         self.cur_iter = 0
 
-        #jie:
-        if self.cpu_offload:
-            see_memory_usage("Before initializing optimizer states")
-            self.initialize_optimizer_states_on_cpu()
-            see_memory_usage("After initializing optimizer states")
-        else:
-            see_memory_usage("Before initializing optimizer states")
-            self.initialize_optimizer_states()
-            see_memory_usage("After initializing optimizer states")
+        see_memory_usage("Before initializing optimizer states")
+        self.initialize_optimizer_states()
+        see_memory_usage("After initializing optimizer states")
 
         if dist.get_rank() == 0:
             logger.info(f"optimizer state initialized")
@@ -451,8 +416,6 @@ def __init__(self,
         if dist.get_rank(group=self.dp_process_group) == 0:
             see_memory_usage(f"After initializing ZeRO optimizer")
 
-        logger.info(f"=========end init============")
-
     def _release_ipg_buffers(self):
         if self.contiguous_gradients:
             self.ipg_buffer = None
@@ -462,10 +425,16 @@ def initialize_optimizer_states(self):
         for i, group in enumerate(self.fp16_groups):
-            single_grad_partition = torch.zeros(
-                int(self.partition_size[i]),
-                dtype=self.single_partition_of_fp32_groups[i].dtype,
-                device=torch.cuda.current_device())
+            if self.cpu_offload:
+                single_grad_partition = torch.zeros(
+                    int(self.partition_size[i]),
+                    dtype=self.single_partition_of_fp32_groups[i].dtype,
+                    device='cpu')
+            else:
+                single_grad_partition = torch.zeros(
+                    int(self.partition_size[i]),
+                    dtype=self.single_partition_of_fp32_groups[i].dtype,
+                    device=torch.cuda.current_device())
             self.single_partition_of_fp32_groups[i].grad = single_grad_partition
 
         self.optimizer.step()
 
@@ -475,36 +444,6 @@ def initialize_optimizer_states(self):
         return
 
-    #jie:
-    def initialize_optimizer_states_on_cpu(self):
-        # init fp32 gradients
-        single_grad_partition_cpu = torch.zeros(
-            [self.single_partition_of_fp32_groups.numel()],
-            dtype=self.single_partition_of_fp32_groups.dtype,
-            device='cpu')
-
-        self.cpu_fp32_exp_avg = torch.zeros(
-            [self.single_partition_of_fp32_groups.numel()],
-            dtype=self.single_partition_of_fp32_groups.dtype,
-            device='cpu')
-
-        self.cpu_fp32_exp_avg_sq = torch.zeros(
-            [self.single_partition_of_fp32_groups.numel()],
-            dtype=self.single_partition_of_fp32_groups.dtype,
-            device='cpu')
-
-        self.single_partition_of_fp32_groups.grad = single_grad_partition_cpu
-        stream = self.cpu_computation_stream
-        #with torch.cuda.stream(stream):
-        self.optimizer.step_with_cpuoffload(None,
-                                            self.single_partition_of_fp32_groups,
-                                            self.single_partition_of_fp32_groups.grad,
-                                            self.cpu_fp32_exp_avg,
-                                            self.cpu_fp32_exp_avg_sq)
-        #return
-
-        return
-
     #########################################################################
     #########################ZeRO Partition Gradients########################
     #########################################################################
@@ -557,14 +496,15 @@ def independent_gradient_partition_epilogue(self):
         if self.overlap_comm:
             torch.cuda.synchronize()
 
-        for i, _ in enumerate(self.fp16_groups):
-            self.averaged_gradients[i] = self.get_flat_partition(
-                self.params_in_partition[i],
-                self.first_offset[i],
-                self.partition_size[i],
-                dtype=torch.half,
-                device=torch.cuda.current_device(),
-                return_tensor_list=True)
+        if self.cpu_offload is False:
+            for i, _ in enumerate(self.fp16_groups):
+                self.averaged_gradients[i] = self.get_flat_partition(
+                    self.params_in_partition[i],
+                    self.first_offset[i],
+                    self.partition_size[i],
+                    dtype=torch.half,
+                    device=torch.cuda.current_device(),
+                    return_tensor_list=True)
 
         self._release_ipg_buffers()
 
@@ -673,6 +613,7 @@ def report_ipg_memory_usage(self, tag, param_elems):
     ###############Idependent Partition Gradient ########################
     def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
+
         if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
             self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads",
                                          param.numel())
@@ -697,7 +638,14 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
                 self.elements_in_ipg_bucket,
                 param.numel())
             new_grad_tensor.copy_(param.grad.view(-1))
-            param.grad.data = new_grad_tensor.data.view_as(param.grad)
+            #jie:
+            if self.cpu_offload:
+                with torch.cuda.stream(self.migration_stream):
+                    self.averaged_gradients_on_cpu[i][param_id] = async_copy_to(new_grad_tensor.data.view_as(param.grad),
+                                                                                'cpu',
+                                                                                self.cpu_computation_stream)
+            else:
+                param.grad.data = new_grad_tensor.data.view_as(param.grad)
 
         self.elements_in_ipg_bucket += param.numel()
         self.grads_in_ipg_bucket.append(param.grad)
@@ -809,9 +757,18 @@ def copy_grads_in_partition(self, param):
                 total_size += param_in_partition.numel()
 
             see_memory_usage(f"before copying {total_size} gradients into partition")
+            #jie:
+            '''
+            if self.cpu_offload:
+                self.grads_in_partition = torch.empty(int(total_size),
+                                              dtype=torch.half,
+                                              device=torch.cuda.current_device())
+            ''
+            else:
+            '''
             self.grads_in_partition = torch.empty(int(total_size),
-                                                  dtype=torch.half,
-                                                  device=torch.cuda.current_device())
+                                              dtype=torch.half,
+                                              device=torch.cuda.current_device())
             see_memory_usage(f"after copying {total_size} gradients into partition")
 
        #The allreduce buffer will be rewritted. Copy the gradients in partition to a new buffer
@@ -819,6 +776,11 @@ def copy_grads_in_partition(self, param):
             self.grads_in_partition_offset,
             param.numel())
         new_grad_tensor.copy_(param.grad.view(-1))
+        '''
+        if self.cpu_offload:
+            averaged_gradients_on_cpu[self.get_param_id()] = new_grad_tensor.data.view_as(param.grad)
+        else:
+        '''
         param.grad.data = new_grad_tensor.data.view_as(param.grad)
         self.grads_in_partition_offset += param.numel()
 
@@ -1199,160 +1161,10 @@ def free_grad_in_param_list(self, param_list):
         for p in param_list:
             p.grad = None
 
-    def get_partition_fp32_group_index(self, group_index):
-        start = self.first_offset[group_index]
-        end = start + self.partition_size[group_index]
-        return [int(start), int(end)]
-
-    #jie:
-    def step_with_cpuoffload(self, closure=None):
-        """
-        Not supporting closure.
-        """
-        see_memory_usage(f"In step before checking overflow")
-
-        # First compute norm for all group so we know if there is overflow
-        self.check_overflow()
-
-        timers = self.timers
-
-        prev_scale = self.loss_scale
-        self._update_scale(self.overflow)
-        if self.overflow:
-            see_memory_usage('After overflow before clearing gradients')
-            self.zero_grad()
-            see_memory_usage('After overflow after clearing gradients')
-
-            logger.info(
-                "[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
-                "reducing to {}".format(dist.get_rank(),
-                                        prev_scale,
-                                        self.loss_scale))
-            timers('optimizer_step').start()
-            timers('optimizer_step').stop()
-            timers('optimizer_allgather').start()
-            timers('optimizer_allgather').stop()
-            return
-
-        norm_groups = []
-        single_partition_grad_groups = []
-        skip = False
-        partition_id = dist.get_rank(group=self.dp_process_group)
-        for i, group in enumerate(self.fp16_groups):
-
-            norm_groups.append(
-                self.get_grad_norm_direct(self.averaged_gradients[i],
-                                          self.params_in_partition[i]))
-
-            #free gradients for all the prameters that are not updated by this process
-            self.free_grad_in_param_list(self.params_not_in_partition[i])
-
-            #create a flat gradients for parameters updated by this process
-            # If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors
-            if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
-                single_grad_partition = flatten_dense_tensors_aligned(
-                    self.averaged_gradients[i],
-                    int(self.partition_size[i])).to(
-                        self.single_partition_of_fp32_groups[i].dtype)
-            else:
-                single_grad_partition = _flatten_dense_tensors(
-                    self.averaged_gradients[i]).to(
-                        self.single_partition_of_fp32_groups[i].dtype)
-            assert single_grad_partition.numel() == self.partition_size[i], \
-                "averaged gradients have different number of elements that partition size {} {} {} {}".format(single_grad_partition.numel(), self.partition_size[i], i, partition_id)
-
-            #jie: transform the grad back to CPU
-            start, end = self.get_partition_fp32_group_index(i)
-            with torch.cuda.stream(self.migration_stream):
-                self.single_partition_of_fp32_groups[start:end].grad = async_copy_to(
-                    single_grad_partition,
-                    'cpu',
-                    self.migration_stream)
-            #self.single_partition_of_fp32_groups[start:end].grad = torch.tensor(single_grad_partition.detach(),
-            #                                                                    dtype=self.single_partition_of_fp32_groups[i].dtype)
-            #release all the gradient since we have already created a necessary copy in dp_grad_partition
-            self.free_grad_in_param_list(self.params_in_partition[i])
-
-            self.averaged_gradients[i] = None
-
-        #jie: is syncization necessary???
-        self.cpu_computation_stream.wait_stream(self.migration_stream)
-        timers('optimizer_step').start()
-        with torch.cuda.stream(self.cpu_computation_stream):
-            self.unscale_and_clip_grads_on_cpu(self.single_partition_of_fp32_groups.grad,
-                                               norm_groups)
-            self.optimizer.step_with_cpuoffload(
-                None,
-                self.single_partition_of_fp32_groups,
-                self.single_partition_of_fp32_groups.grad,
-                self.cpu_fp32_exp_avg,
-                self.cpu_fp32_exp_avg_sq)
-        #get rid of the fp32 gradients. Not needed anymore
-        #for group in self.single_partition_of_fp32_groups:
-        #    group.grad = None
-
-        # jie: updated weights back to GPU
-        with torch.cuda.stream(self.migration_stream):
-            #self.parallel_partitioned_fp16_groups,self.single_partition_of_fp32_groups[start:end])
-            for i, fp16_partitions in enumerate(self.parallel_partitioned_fp16_groups):
-                start, end = self.get_partition_fp32_group_index(i)
-                fp16_partitions[partition_id].data.copy_(
-                    self.single_partition_of_fp32_groups[start:end].data)
-        return
-
-        timers('optimizer_step').stop()
-
-        timers('optimizer_allgather').start()
-        #gather the updated weights from everyone
-        for group_id, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups):
-
-            #Sequential AllGather Best of both worlds
-            dp_world_size = dist.get_world_size(group=self.dp_process_group)
-            num_shards = max(
-                1,
-                partitioned_params[partition_id].numel() * dp_world_size //
-                self.allgather_bucket_size)
-
-            shard_size = partitioned_params[partition_id].numel() // num_shards
-            num_elements = shard_size
-
-            assert shard_size * num_shards <= partitioned_params[partition_id].numel()
-
-            for shard_id in range(num_shards):
-
-                if shard_id == (num_shards - 1):
-                    num_elements = partitioned_params[partition_id].numel(
-                    ) - shard_id * shard_size
-
-                shard_list = []
-                for dp_id in range(dp_world_size):
-                    curr_shard = partitioned_params[dp_id].narrow(
-                        0,
-                        shard_id * shard_size,
-                        num_elements).detach()
-                    shard_list.append(curr_shard)
-
-                dist.all_gather(shard_list,
-                                shard_list[partition_id],
-                                group=self.dp_process_group)
-        timers('optimizer_allgather').stop()
-
-        # TODO: we probably don't need this? just to be safe
-        for i in range(len(norm_groups)):
-            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
-                                                      self.fp16_groups[i])
-            for p, q in zip(self.fp16_groups[i], updated_params):
-                p.data = q.data
-
-        see_memory_usage('After zero_optimizer step')
-        return
-
     def step(self, closure=None):
         """
         Not supporting closure.
""" - if self.cpu_offload: - return self.step_with_cpuoffload() see_memory_usage(f"In step before checking overflow") # First compute norm for all group so we know if there is overflow @@ -1384,6 +1196,11 @@ def step(self, closure=None): partition_id = dist.get_rank(group=self.dp_process_group) for i, group in enumerate(self.fp16_groups): + if self.cpu_offload: + self.averaged_gradients[i] = self.averaged_gradients_on_cpu[i] + #norm_groups.append( + # self.get_grad_norm_direct(self.averaged_gradients[i], + # self.params_in_partition[i])) norm_groups.append( self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i])) @@ -1400,7 +1217,7 @@ def step(self, closure=None): self.single_partition_of_fp32_groups[i].dtype) else: single_grad_partition = _flatten_dense_tensors( - self.averaged_gradients[i]).to( + self.averaged_gradients_on_cpu[i]).to( self.single_partition_of_fp32_groups[i].dtype) assert single_grad_partition.numel() == self.partition_size[i], \ "averaged gradients have different number of elements that partition size {} {} {} {}".format(single_grad_partition.numel(), self.partition_size[i], i, partition_id) @@ -1409,20 +1226,26 @@ def step(self, closure=None): #release all the gradient since we have already created a necessary copy in dp_grad_partition self.free_grad_in_param_list(self.params_in_partition[i]) - self.averaged_gradients[i] = None + self.averaged_gradients[i] = notneeded single_partition_grad_groups.append(single_grad_partition) self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups) + self.cpu_computation_stream.wait_stream(self.migration_stream) timers('optimizer_step').start() - self.optimizer.step() - #get rid of the fp32 gradients. Not needed anymore - for group in self.single_partition_of_fp32_groups: - group.grad = None + with torch.cuda.stream(self.cpu_computation_stream): + self.optimizer.step() - for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups): - fp16_partitions[partition_id].data.copy_(fp32_partition.data) + if self.cpu_offload: + with torch.cuda.stream(self.migration_stream): + for averaged_gradients_cpu, fp32_partition in zip(self.averaged_gradients_on_cpu, self.single_partition_of_fp32_groups): + averaged_gradients_cpu[0] = async_copy_to(fp32_partition, torch.cuda.current_divice(),self.migration_stream) + for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups): + fp16_partitions[partition_id] = async_copy_to(fp32_partition, torch.cuda.current_divice(),torch.cuda.main_stream()) + else: + for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups): + fp16_partitions[partition_id].data.copy_(fp32_partition.data) timers('optimizer_step').stop() timers('optimizer_allgather').start() @@ -1492,21 +1315,6 @@ def unscale_and_clip_grads(self, grad_groups_flat, norm_groups): else: grad.data.mul_(1. / combined_scale) - def unscale_and_clip_grads_on_cpu(self, grad_groups_flat, norm_groups): - total_norm = 0.0 - for norm in norm_groups: - total_norm += norm**2.0 - total_norm = math.sqrt(total_norm) - - # compute combined scale factor for this group - combined_scale = self.loss_scale - if self.clip_grad > 0.: - # norm is in fact norm*scale - clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad - if clip > 1: - combined_scale = clip * self.loss_scale - grad_groups_flat.data.mul_(1. 
-
     def _check_overflow(self, partition_gradients=True):
         self.overflow = self.has_overflow(partition_gradients)
 
@@ -1520,6 +1328,8 @@ def has_overflow_serial(self, params, is_grad_list=False):
 
     def has_overflow_partitioned_grads_serial(self):
         for i in range(len(self.fp16_groups)):
+            if self.cpu_offload:
+                self.averaged_gradients[i] = self.averaged_gradients_on_cpu[i]
             for j, grad in enumerate(self.averaged_gradients[i]):
                 if grad is not None and self._has_inf_or_nan(grad.data, j):
                     return True
@@ -1646,8 +1456,12 @@ def _get_groups_without_padding(self, groups_with_padding):
     def _get_state_without_padding(self, state_with_padding, padding):
         lean_state = {}
         for key, value in state_with_padding.items():
-            lean_length = value.numel() - padding
-            lean_state[key] = value[:lean_length]
+            #jie: torch.optim.Adam() has "step" as a key in state_dict
+            if key == "step":
+                lean_state[key] = value
+            else:
+                lean_length = value.numel() - padding
+                lean_state[key] = value[:lean_length]
 
         return lean_state
 
@@ -1685,14 +1499,16 @@ def state_dict(self):
         state_dict['partition_count'] = self.partition_count
 
         # Remove paddings for DP alignment to enable loading for other alignment values
-        if self.cpu_offload:
-            fp32_groups_without_padding = self.single_partition_of_fp32_groups[
-                0:self.total_params]
-        else:
-            fp32_groups_without_padding = self._get_groups_without_padding(
-                self.single_partition_of_fp32_groups)
+        fp32_groups_without_padding = self._get_groups_without_padding(
+            self.single_partition_of_fp32_groups)
         state_dict['single_partition_of_fp32_groups'] = fp32_groups_without_padding
+
+        if self.cpu_offload:
+            state_dict_tmp = async_copy_to(state_dict, 'cpu', torch.cuda.current_stream())
+            state_dict = None
+            state_dict = state_dict_tmp
+
         return state_dict
 
     # Restore base optimizer fp32 weights from checkpoint by:
diff --git a/deepspeed/pt/deepspeed_zero_utils.py b/deepspeed/pt/deepspeed_zero_utils.py
new file mode 100755
index 000000000000..fd43605d3de1
--- /dev/null
+++ b/deepspeed/pt/deepspeed_zero_utils.py
@@ -0,0 +1,32 @@
+import torch
+from torch.autograd import Variable
+import collections
+
+def async_migrate_to(obj, dev, main_stream=None):
+    if torch.is_tensor(obj):
+        obj = Variable(obj)
+    if isinstance(obj, Variable):
+        v = obj.cuda(dev, non_blocking=True)
+        if main_stream is not None:
+            v.data.record_stream(main_stream)
+        return v
+    elif isinstance(obj, collections.Mapping):
+        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
+    elif isinstance(obj, collections.Sequence):
+        return [async_copy_to(o, dev, main_stream) for o in obj]
+    else:
+        return obj
+
+
+def async_copy_to(obj, dev, main_stream=None):
+    if torch.is_tensor(obj):
+        obj = Variable(obj)
+    if isinstance(obj, Variable):
+        target = torch.empty_like(obj, device=dev).copy_(obj)
+        if main_stream is not None:
+            target.data.record_stream(main_stream)
+        return target
+    elif isinstance(obj, collections.Mapping):
+        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
+    elif isinstance(obj, collections.Sequence):
+        return [async_copy_to(o, dev, main_stream) for o in obj]
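
Note (not part of the patch): a minimal usage sketch of the async_copy_to helper added in deepspeed/pt/deepspeed_zero_utils.py, which the offload path uses to stage partitions between GPU and CPU. The import path assumes this patch is applied, and the tensors, sizes, and stream names below are illustrative only; this is a sketch, not the optimizer's actual call sequence.

    import torch
    from deepspeed.pt.deepspeed_zero_utils import async_copy_to  # assumes this patch is applied

    if torch.cuda.is_available():
        migration_stream = torch.cuda.Stream()

        # Stage an fp16 gradient partition from GPU to CPU (mirrors the intent of the
        # cpu_offload branch in reduce_independent_p_g_buckets_and_remove_grads).
        gpu_grads = torch.randn(1024, device='cuda', dtype=torch.half)
        cpu_grads = async_copy_to(gpu_grads, 'cpu')  # allocates a like-shaped CPU tensor and copies into it

        # Copy an updated fp32 partition back to the GPU on a side stream; record_stream
        # ties the CUDA destination's lifetime to that stream.
        fp32_partition = torch.randn(1024, dtype=torch.float32)
        with torch.cuda.stream(migration_stream):
            gpu_params = async_copy_to(fp32_partition, torch.cuda.current_device(), migration_stream)
        torch.cuda.current_stream().wait_stream(migration_stream)

        assert cpu_grads.device.type == 'cpu' and gpu_params.is_cuda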