diff --git a/deepspeed/moe/utils.py b/deepspeed/moe/utils.py
index 2b62a66eba09..717947226165 100644
--- a/deepspeed/moe/utils.py
+++ b/deepspeed/moe/utils.py
@@ -59,8 +59,9 @@ def split_params_grads_into_shared_and_expert_params(
     return shared_grads, expert_grads


-def split_params_into_different_moe_groups_for_optimizer(
-        param_groups: Tuple[Dict]) -> Tuple[Dict]:
+def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict],
+                                                         max_group_size=178956971
+                                                         ) -> Tuple[Dict]:
     """Split parameters into different MoE groups for optimizer

     Args:
@@ -112,8 +113,32 @@ def split_params_into_different_moe_groups_for_optimizer(
         param_group['params'] = new_params

     # Flatten the moe groups
-    for k, v in group_moe.items():
-        for k1, v1 in v.items():
-            param_groups.append(v1)
+    if max_group_size is not None:
+        for k, v in group_moe.items():
+            for k1, v1 in v.items():
+                cur_group = []
+                all_groups = []
+                size_of_cur_group = 0
+                for param in v1['params']:
+                    if size_of_cur_group + param.numel() <= max_group_size:
+                        cur_group.append(param)
+                        size_of_cur_group += param.numel()
+                    else:
+                        all_groups.append(cur_group)
+                        cur_group = [param]
+                        size_of_cur_group = param.numel()
+                if cur_group:
+                    all_groups.append(cur_group)
+                for group in all_groups:
+                    new_dict = {}
+                    for key, val in v1.items():
+                        if key != 'params':
+                            new_dict[key] = val
+                    new_dict['params'] = group
+                    param_groups.append(new_dict)
+    else:
+        for k, v in group_moe.items():
+            for k1, v1 in v.items():
+                param_groups.append(v1)

     return tuple(param_groups)
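
The new `max_group_size` argument (default 178956971 elements) caps how many parameter elements land in any single flattened MoE optimizer group. As a rough illustration of the greedy packing above, here is a minimal standalone sketch; the helper name `split_param_group_by_size` and the toy sizes are hypothetical, not part of DeepSpeed:

```python
import torch


def split_param_group_by_size(param_group, max_group_size):
    """Sketch of the packing logic in the diff above: greedily collect params into
    sub-groups whose total numel() stays within max_group_size, then rebuild one
    optimizer group per chunk, copying every non-'params' hyperparameter."""
    cur_group, all_groups, cur_size = [], [], 0
    for param in param_group['params']:
        if cur_size + param.numel() <= max_group_size:
            cur_group.append(param)
            cur_size += param.numel()
        else:
            all_groups.append(cur_group)
            cur_group, cur_size = [param], param.numel()
    if cur_group:
        all_groups.append(cur_group)
    return [{**{k: v for k, v in param_group.items() if k != 'params'}, 'params': g}
            for g in all_groups]


# Tensors of 6, 5 and 2 elements with an 8-element cap pack into [6] and [5, 2].
group = {'params': [torch.zeros(6), torch.zeros(5), torch.zeros(2)], 'lr': 1e-3}
print([sum(p.numel() for p in g['params'])
       for g in split_param_group_by_size(group, max_group_size=8)])  # -> [6, 7]
```
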
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 211e9dac7380..c019bd0e3647 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -1653,6 +1653,44 @@ def override_loss_scale(self, loss_scale):
         self.custom_loss_scaler = True
         self.external_loss_scale = loss_scale

+    def scaled_global_norm(self, norm_type=2):
+        assert norm_type == 2, "only L2 norm supported"
+        norm_groups = []
+        for i, group in enumerate(self.bit16_groups):
+            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
+            if self.cpu_offload:
+                norm_groups.append(
+                    self.complete_grad_norm_calculation_for_cpu_offload(
+                        self.params_in_partition[i]))
+                single_grad_partition = self.single_partition_of_fp32_groups[i].grad
+            else:
+                norm_groups.append(
+                    self.get_grad_norm_direct(self.averaged_gradients[i],
+                                              self.params_in_partition[i]))
+
+        if self.has_moe_layers:
+            self._average_expert_grad_norms(norm_groups)
+
+        # note that the get_global_norm function only supports l2 norm
+        return get_global_norm(norm_list=norm_groups)
+
+    def get_bit16_param_group(self, group_no):
+        bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]
+        partition_id = dist.get_rank(group=self.real_dp_process_group[group_no])
+        return [
+            bit16_partitions[dist.get_rank(group=self.real_dp_process_group[group_no])]
+        ]
+
+    def _optimizer_step(self, group_no):
+        original_param_groups = self.optimizer.param_groups
+        self.optimizer.param_groups = [original_param_groups[group_no]]
+        from deepspeed.ops.adam import DeepSpeedCPUAdam
+        if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
+            self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)])
+        else:
+            self.optimizer.step()
+        self.optimizer.param_groups = original_param_groups
+
     def step(self, closure=None):
         """
         Not supporting closure.
@@ -1671,7 +1709,6 @@ def step(self, closure=None):
         prev_scale = self.loss_scale
         self._update_scale(self.overflow)
         if self.overflow:
-
             if dist.get_rank() == 0:
                 logger.info(
                     "[deepspeed] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
@@ -1692,22 +1729,33 @@ def step(self, closure=None):
             self.stop_timers(timer_names)
             return

-        self.start_timers([OPTIMIZER_GRADIENTS])
-        norm_groups = []
-        single_partition_grad_groups = []
-        # skip = False
+        # Step 1:- Calculate gradient norm using fp-16 grads
+        see_memory_usage('Before norm calculation')
+        scaled_global_grad_norm = self.scaled_global_norm()
+        self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
+
+        see_memory_usage('After norm before optimizer')
+        # Step 2:- run optimizer and upscaling simultaneously
         for i, group in enumerate(self.bit16_groups):
+            self.start_timers([OPTIMIZER_GRADIENTS])
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
             if self.cpu_offload:
-                norm_groups.append(
-                    self.complete_grad_norm_calculation_for_cpu_offload(
-                        self.params_in_partition[i]))
                 single_grad_partition = self.single_partition_of_fp32_groups[i].grad
-            else:
-                norm_groups.append(
-                    self.get_grad_norm_direct(self.averaged_gradients[i],
-                                              self.params_in_partition[i]))
+                self.unscale_and_clip_grads([single_grad_partition],
+                                            scaled_global_grad_norm)
+                self.stop_timers([OPTIMIZER_GRADIENTS])
+                self.start_timers([OPTIMIZER_STEP])
+                self._optimizer_step(i)
+
+                from deepspeed.ops.adam import DeepSpeedCPUAdam
+                if not (type(self.optimizer) == DeepSpeedCPUAdam
+                        and self.dtype == torch.half):
+                    bit16_partitions = self.parallel_partitioned_bit16_groups[i]
+                    fp32_partition = self.single_partition_of_fp32_groups[i]
+                    bit16_partitions[partition_id].data.copy_(fp32_partition.data)
+                self.stop_timers([OPTIMIZER_STEP])
+            else:
                 # free gradients for all the parameters that are not updated by this process(ZeRO stage2)
                 self.free_grad_in_param_list(self.params_not_in_partition[i])

@@ -1732,53 +1780,22 @@ def step(self, closure=None):

                 self.averaged_gradients[i] = None

-            single_partition_grad_groups.append(single_grad_partition)
-
-        if self.has_moe_layers:
-            self._average_expert_grad_norms(norm_groups)
-
-        scaled_global_grad_norm = get_global_norm(norm_list=norm_groups)
-        self.unscale_and_clip_grads(single_partition_grad_groups,
-                                    scaled_global_grad_norm)
-
-        # Stash unscaled gradient norm
-        self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
-
-        self.stop_timers([OPTIMIZER_GRADIENTS])
-
-        self.start_timers([OPTIMIZER_STEP])
-        if self.deepspeed_adam_offload:
-            from deepspeed.ops.adam import DeepSpeedCPUAdam
-            if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
-                bit16_param_groups = [
-                    [
-                        bit16_partitions[dist.get_rank(
-                            group=self.real_dp_process_group[group_id])]
-                    ] for group_id,
-                    bit16_partitions in enumerate(self.parallel_partitioned_bit16_groups)
-                ]
-                self.optimizer.step(fp16_param_groups=bit16_param_groups)
-            else:
-                self.optimizer.step()
-                for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
-                    partition_id = dist.get_rank(
-                        group=self.real_dp_process_group[group_id])
-
-                    bit16_partitions[partition_id].data.copy_(fp32_partition.data)
-        else:
-            self.optimizer.step()
-
-            # get rid of the fp32 gradients. Not needed anymore
-            if not self.cpu_offload:
-                for group in self.single_partition_of_fp32_groups:
-                    group.grad = None  # in step
-
-            for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
-                partition_id = dist.get_rank(group=self.real_dp_process_group[group_id])
+                self.unscale_and_clip_grads([single_grad_partition],
+                                            scaled_global_grad_norm)
+                self.stop_timers([OPTIMIZER_GRADIENTS])
+
+                # Step 3:- run the optimizer if no offloading
+                self.start_timers([OPTIMIZER_STEP])
+                self._optimizer_step(i)
+                # Step 4:- get rid of the fp32 gradients. Not needed anymore
+                self.single_partition_of_fp32_groups[i].grad = None
+                del single_grad_partition
+                bit16_partitions = self.parallel_partitioned_bit16_groups[i]
+                fp32_partition = self.single_partition_of_fp32_groups[i]
                 bit16_partitions[partition_id].data.copy_(fp32_partition.data)
+                self.stop_timers([OPTIMIZER_STEP])

-        self.stop_timers([OPTIMIZER_STEP])
-
+        see_memory_usage('After optimizer before all-gather')
         if self.cpu_offload:
             self.reset_cpu_buffers()

@@ -1794,7 +1811,7 @@ def step(self, closure=None):
         self.stop_timers([OPTIMIZER_ALLGATHER])

         # TODO: we probably don't need this? just to be safe
-        for i in range(len(norm_groups)):
+        for i in range(len(self.bit16_groups)):
             self._update_model_bit16_weights(i)

         self.log_timers(timer_names)
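
The per-group stepping introduced above (`_optimizer_step`) works by temporarily exposing a single entry of `optimizer.param_groups` to the wrapped optimizer, so `step()` updates one partition group at a time and the fp32 master gradients and bit16 copies can be released group by group instead of all at once. A minimal sketch of that pattern in plain PyTorch follows; the `step_single_group` helper and the SGD toy example are illustrative, not DeepSpeed code:

```python
import torch


def step_single_group(optimizer, group_no):
    """Restrict optimizer.param_groups to one group so step() only touches that
    group, then restore the full list (a defensive variant of _optimizer_step)."""
    original_param_groups = optimizer.param_groups
    optimizer.param_groups = [original_param_groups[group_no]]
    try:
        optimizer.step()
    finally:
        optimizer.param_groups = original_param_groups  # always restore


# Two parameters in two groups; only group 0 is updated by the call below.
w0 = torch.nn.Parameter(torch.ones(4))
w1 = torch.nn.Parameter(torch.ones(4))
opt = torch.optim.SGD([{'params': [w0]}, {'params': [w1]}], lr=0.1)
(w0.sum() + w1.sum()).backward()
step_single_group(opt, 0)
print(w0[0].item(), w1[0].item())  # 0.9 1.0 -> only group 0 moved
```

Unlike this sketch, the diff restores `param_groups` unconditionally after the step rather than in a `finally` block, and for `DeepSpeedCPUAdam` with fp16 it passes the matching bit16 partition via `step(fp16_param_groups=...)`, which is why the explicit fp32-to-bit16 copy is skipped in that case.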