35 changes: 30 additions & 5 deletions deepspeed/moe/utils.py
@@ -59,8 +59,9 @@ def split_params_grads_into_shared_and_expert_params(
return shared_grads, expert_grads


def split_params_into_different_moe_groups_for_optimizer(
param_groups: Tuple[Dict]) -> Tuple[Dict]:
def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict],

siddharth9820 (Contributor, Author) commented on Jul 12, 2022:
@tjruwase, please let me know how we would want to offer the user a way to set the max_group_size

Contributor:
As discussed, let's add an moe section in ds_config. Perhaps, @awan-10 could help with the design.

siddharth9820 (Contributor, Author):
@awan-10 Perhaps the moe section can also contain a flag to toggle expert slicing too?

max_group_size=178956971
) -> Tuple[Dict]:
"""Split parameters into different MoE groups for optimizer

Args:
@@ -112,8 +113,32 @@ def split_params_into_different_moe_groups_for_optimizer(
param_group['params'] = new_params

# Flatten the moe groups
for k, v in group_moe.items():
for k1, v1 in v.items():
param_groups.append(v1)
if max_group_size is not None:
for k, v in group_moe.items():
for k1, v1 in v.items():
cur_group = []
all_groups = []
size_of_cur_group = 0
for param in v1['params']:
if size_of_cur_group + param.numel() <= max_group_size:
cur_group.append(param)
size_of_cur_group += param.numel()
else:
all_groups.append(cur_group)
cur_group = [param]
size_of_cur_group = param.numel()
if cur_group:
all_groups.append(cur_group)
for group in all_groups:
new_dict = {}
for key, val in v1.items():
if key != 'params':
new_dict[key] = val
new_dict['params'] = group
param_groups.append(new_dict)
else:
for k, v in group_moe.items():
for k1, v1 in v.items():
param_groups.append(v1)

return tuple(param_groups)
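
The review thread above discusses exposing max_group_size through an moe section of ds_config. The sketch below shows how the splitting could be driven from user code with the new max_group_size argument; the ds_config "moe" section and its keys are hypothetical placeholders for the design being discussed, not an existing DeepSpeed option.

import torch
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer

# Hypothetical ds_config fragment: the "moe" section and its keys are placeholders
# sketching the design discussed in the review thread, not an existing option.
ds_config = {
    "train_batch_size": 8,
    "moe": {
        "max_group_size": 178956971,      # cap on numel per expert optimizer group
        # "enable_expert_slicing": True,  # possible toggle raised in the thread
    },
}

# Build a named parameter group, split out expert parameters, then create the optimizer.
model = torch.nn.Linear(16, 16)  # stand-in; a real model would contain MoE layers
param_groups = [{"params": list(model.parameters()), "name": "parameters", "lr": 1e-3}]
param_groups = split_params_into_different_moe_groups_for_optimizer(
    param_groups, max_group_size=ds_config["moe"]["max_group_size"])
optimizer = torch.optim.AdamW(param_groups)

With no expert parameters the call is effectively a pass-through; with MoE layers, expert parameters are grouped per data-parallel group name and further chunked so that no group exceeds max_group_size elements.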
133 changes: 75 additions & 58 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -1653,6 +1653,44 @@ def override_loss_scale(self, loss_scale):
self.custom_loss_scaler = True
self.external_loss_scale = loss_scale

def scaled_global_norm(self, norm_type=2):
assert norm_type == 2, "only L2 norm supported"
norm_groups = []
for i, group in enumerate(self.bit16_groups):
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
if self.cpu_offload:
norm_groups.append(
self.complete_grad_norm_calculation_for_cpu_offload(
self.params_in_partition[i]))
single_grad_partition = self.single_partition_of_fp32_groups[i].grad
else:
norm_groups.append(
self.get_grad_norm_direct(self.averaged_gradients[i],
self.params_in_partition[i]))

if self.has_moe_layers:
self._average_expert_grad_norms(norm_groups)

# note that the get_global_norm function only supports l2 norm
return get_global_norm(norm_list=norm_groups)
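
As the comment above notes, get_global_norm only supports the L2 norm, so combining the per-group norms amounts to taking the square root of the sum of their squares. A tiny self-contained check of that identity in plain PyTorch (independent of the DeepSpeed helper):

import torch

# ||concat(g1, g2)||_2 == sqrt(||g1||_2**2 + ||g2||_2**2)
g1, g2 = torch.randn(10), torch.randn(7)
norm_groups = [g1.norm(2), g2.norm(2)]
global_norm = torch.sqrt(sum(n ** 2 for n in norm_groups))
assert torch.allclose(global_norm, torch.cat([g1, g2]).norm(2))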

def get_bit16_param_group(self, group_no):
bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]
partition_id = dist.get_rank(group=self.real_dp_process_group[group_no])
return [
bit16_partitions[dist.get_rank(group=self.real_dp_process_group[group_no])]
]

def _optimizer_step(self, group_no):
original_param_groups = self.optimizer.param_groups
self.optimizer.param_groups = [original_param_groups[group_no]]
from deepspeed.ops.adam import DeepSpeedCPUAdam
if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)])
else:
self.optimizer.step()
self.optimizer.param_groups = original_param_groups

def step(self, closure=None):
"""
Not supporting closure.
@@ -1671,7 +1709,6 @@ def step(self, closure=None):
prev_scale = self.loss_scale
self._update_scale(self.overflow)
if self.overflow:

if dist.get_rank() == 0:
logger.info(
"[deepspeed] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
@@ -1692,22 +1729,33 @@ def step(self, closure=None):
self.stop_timers(timer_names)
return

self.start_timers([OPTIMIZER_GRADIENTS])
norm_groups = []
single_partition_grad_groups = []
# skip = False
# Step 1:- Calculate gradient norm using fp-16 grads
see_memory_usage('Before norm calculation')
scaled_global_grad_norm = self.scaled_global_norm()
self._global_grad_norm = scaled_global_grad_norm / self.loss_scale

see_memory_usage('After norm before optimizer')
# Step 2:- run optimizer and upscaling simultaneously
for i, group in enumerate(self.bit16_groups):
self.start_timers([OPTIMIZER_GRADIENTS])
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
if self.cpu_offload:
norm_groups.append(
self.complete_grad_norm_calculation_for_cpu_offload(
self.params_in_partition[i]))
single_grad_partition = self.single_partition_of_fp32_groups[i].grad
else:
norm_groups.append(
self.get_grad_norm_direct(self.averaged_gradients[i],
self.params_in_partition[i]))
self.unscale_and_clip_grads([single_grad_partition],
scaled_global_grad_norm)
self.stop_timers([OPTIMIZER_GRADIENTS])
self.start_timers([OPTIMIZER_STEP])
self._optimizer_step(i)

from deepspeed.ops.adam import DeepSpeedCPUAdam
if not (type(self.optimizer) == DeepSpeedCPUAdam
and self.dtype == torch.half):
bit16_partitions = self.parallel_partitioned_bit16_groups[i]
fp32_partition = self.single_partition_of_fp32_groups[i]
bit16_partitions[partition_id].data.copy_(fp32_partition.data)

self.stop_timers([OPTIMIZER_STEP])
else:
# free gradients for all the parameters that are not updated by this process(ZeRO stage2)
self.free_grad_in_param_list(self.params_not_in_partition[i])

@@ -1732,53 +1780,22 @@

self.averaged_gradients[i] = None

single_partition_grad_groups.append(single_grad_partition)

if self.has_moe_layers:
self._average_expert_grad_norms(norm_groups)

scaled_global_grad_norm = get_global_norm(norm_list=norm_groups)
self.unscale_and_clip_grads(single_partition_grad_groups,
scaled_global_grad_norm)

# Stash unscaled gradient norm
self._global_grad_norm = scaled_global_grad_norm / self.loss_scale

self.stop_timers([OPTIMIZER_GRADIENTS])

self.start_timers([OPTIMIZER_STEP])
if self.deepspeed_adam_offload:
from deepspeed.ops.adam import DeepSpeedCPUAdam
if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
bit16_param_groups = [
[
bit16_partitions[dist.get_rank(
group=self.real_dp_process_group[group_id])]
] for group_id,
bit16_partitions in enumerate(self.parallel_partitioned_bit16_groups)
]
self.optimizer.step(fp16_param_groups=bit16_param_groups)
else:
self.optimizer.step()
for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
partition_id = dist.get_rank(
group=self.real_dp_process_group[group_id])

bit16_partitions[partition_id].data.copy_(fp32_partition.data)
else:
self.optimizer.step()

# get rid of the fp32 gradients. Not needed anymore
if not self.cpu_offload:
for group in self.single_partition_of_fp32_groups:
group.grad = None # in step

for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
partition_id = dist.get_rank(group=self.real_dp_process_group[group_id])
self.unscale_and_clip_grads([single_grad_partition],
scaled_global_grad_norm)
self.stop_timers([OPTIMIZER_GRADIENTS])

# Step 3:- run the optimizer if no offloading
self.start_timers([OPTIMIZER_STEP])
self._optimizer_step(i)
# Step 4:- get rid of the fp32 gradients. Not needed anymore
self.single_partition_of_fp32_groups[i].grad = None
del single_grad_partition
bit16_partitions = self.parallel_partitioned_bit16_groups[i]
fp32_partition = self.single_partition_of_fp32_groups[i]
bit16_partitions[partition_id].data.copy_(fp32_partition.data)
self.stop_timers([OPTIMIZER_STEP])

self.stop_timers([OPTIMIZER_STEP])

see_memory_usage('After optimizer before all-gather')
if self.cpu_offload:
self.reset_cpu_buffers()

@@ -1794,7 +1811,7 @@ def step(self, closure=None):
self.stop_timers([OPTIMIZER_ALLGATHER])

# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
for i in range(len(self.bit16_groups)):
self._update_model_bit16_weights(i)

self.log_timers(timer_names)
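
Taken together, the stage_1_and_2.py changes hoist the gradient-norm calculation out of the per-group loop (scaled_global_norm) and step the optimizer one parameter group at a time (_optimizer_step), so each group's fp32 gradient can be released before the next group is processed. A minimal self-contained sketch of that per-group stepping pattern, using a plain PyTorch optimizer rather than the DeepSpeed engine (names below are illustrative):

import torch

def step_single_group(optimizer: torch.optim.Optimizer, group_no: int) -> None:
    """Narrow optimizer.param_groups to one group, step it, then restore the full
    list -- mirroring the _optimizer_step helper introduced in the diff above."""
    original_param_groups = optimizer.param_groups
    try:
        optimizer.param_groups = [original_param_groups[group_no]]
        optimizer.step()
    finally:
        # Restore the full group list even if step() raises.
        optimizer.param_groups = original_param_groups

if __name__ == "__main__":
    w1 = torch.nn.Parameter(torch.randn(4))
    w2 = torch.nn.Parameter(torch.randn(4))
    opt = torch.optim.SGD([{"params": [w1]}, {"params": [w2]}], lr=0.1)
    (w1.sum() + w2.sum()).backward()
    step_single_group(opt, 0)  # updates only w1; w2 and its gradient are untouched
    step_single_group(opt, 1)  # now update w2

Stepping group-by-group is what lets Step 4 in the new code drop single_partition_of_fp32_groups[i].grad right after group i is updated, which should lower peak memory relative to the old path that ran unscale_and_clip_grads and a single optimizer.step() over all groups at once.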