Cherry pick for sharding #47061

Merged 3 commits on Oct 18, 2022

@@ -24,6 +24,8 @@

import copy
import logging
import warnings

import numpy as np
from collections import OrderedDict

@@ -86,6 +88,11 @@ def __init__(self,
# Default information
self._optim = optim

# Sharding stage 2 comm overlap flag
self._reduce_overlap = False
# Record the last task used for comm overlap in sharding stage 2
self._comm_task = None

assert hasattr(self._optim, "_master_weights"
), "Must use optimizer with _master_weights attribute"

@@ -103,6 +110,17 @@ def __init__(self,
filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
self._local_params))) > 0

self._broadcast_overlap = False
self._forward_pre_hook_remove_helper = []
try:
# The fp32 params such as layer_norm_0.w_0 will be at the end of param_list.
# Sort the params so that they are in forward-use order.
self._broadcast_order_params = sorted(
self.local_params,
key=lambda x: int(x.name.split('.')[0].split('_')[-1]))
except ValueError:
self._broadcast_order_params = None
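# Illustration of the sort key above, using hypothetical parameter names: the
# trailing integer before the '.' reflects creation order, which is assumed to
# match forward-use order.
#   sorted(["linear_3.w_0", "linear_1.b_0", "layer_norm_0.w_0"],
#          key=lambda n: int(n.split('.')[0].split('_')[-1]))
#   -> ['layer_norm_0.w_0', 'linear_1.b_0', 'linear_3.w_0']
# A name without a numeric suffix (e.g. "embedding.w_0") raises ValueError,
# which is why _broadcast_order_params falls back to None here.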

self._group = new_group(
_get_global_group().ranks) if group is None else group

@@ -157,6 +175,60 @@ def _sync_params_and_buffers(self):
group=self._group,
sync_op=True)

def _update_task(self, task):
if self._reduce_overlap:
assert task is not None
# Only keep track of the last reduce task.
# Since all tasks are issued on the same stream, only the last one needs to be waited on.
# Once the last reduce task has finished, all earlier reduce tasks are guaranteed to have finished as well.
self._comm_task = task
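# Illustration of the ordering reasoning above, with hypothetical tensors g1, g2:
#   t1 = collective.reduce(tensor=g1, dst=0, group=group, sync_op=False)
#   t2 = collective.reduce(tensor=g2, dst=0, group=group, sync_op=False)
#   t2.wait()   # t1 has already finished, since both were queued on the same comm stream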

def _set_reduce_overlap(self, reduce_overlap):
# Enable overlapping gradient reduce ops with the backward calculation.
self._reduce_overlap = reduce_overlap

def _set_broadcast_overlap(self,
broadcast_overlap,
layers=None,
num_groups=None):
# Enable overlapping the post-optimizer parameter broadcasts with the next batch's forward calculation.
self._broadcast_overlap = broadcast_overlap
if self._broadcast_overlap:
assert layers is not None, \
"To enable broadcast overlap forward, please pass the module to the function."
self._layers = layers
warnings.warn(
"Setting overlap broadcast means the `paddle.device.cuda.synchronize()` "
"must be called manually before calling `paddle.save()` and before and inference."
)
if self._broadcast_order_params is None:
# Param names should follow a pattern like column_linear_32.w_0 to get the best performance.
warnings.warn(
"The param name passed to the optimizer doesn't follow .+_[0-9]+\..+ patter, "
"overlap broadcast may harm the performance.")
self._broadcast_order_params = self._local_params

if num_groups is None or num_groups > len(self._broadcast_order_params):
warnings.warn(
"The num_groups for broadcast is larger than the number of params to be broadcast. "
"It will set to default value: 1 (use the default sharding group)."
)
num_groups = 1

assert isinstance(
num_groups,
int) and num_groups > 0, "num_groups should be a positive integer"

self._number_of_broadcast_groups = num_groups
self._broadcast_groups = [
None for _ in range(self._number_of_broadcast_groups)
]
self._broadcast_groups[0] = self._group

ranks = self._group.ranks
for i in range(1, self._number_of_broadcast_groups):
self._broadcast_groups[i] = new_group(ranks)
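# How num_groups is used: _broadcast_params_overlap_forward below walks the params
# in forward order and assigns their broadcasts to these groups round-robin, e.g.
# with num_groups=2 and hypothetical params p0..p3:
#   p0 -> groups[0], p1 -> groups[1], p2 -> groups[0], p3 -> groups[1]
# Each extra group is a separate communicator over the same ranks, so broadcasts
# issued to different groups can presumably progress concurrently instead of being
# serialized on a single communication stream.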

def _generate_master_params(self, trainable_params):
if self.offload:
for param in trainable_params:
@@ -364,6 +436,13 @@ def step(self):
"""
A wrapper for Optimizer's step function to finish the update operation of the optimizer.
"""
# This method is not called directly by opt.step().
# _redefine_opt_step() in class GroupShardedStage2 wraps this function.
if self._broadcast_overlap:
# Clear the forward pre-hooks in the optimizer step.
for hook_remove in self._forward_pre_hook_remove_helper:
hook_remove.remove()
self._forward_pre_hook_remove_helper = []

if self.offload:
params_list = [self.offload_params.buffer]
@@ -408,9 +487,52 @@ def _broadcast_params(self):
"""Broadcast the parameters of the current rank to each rank"""

# Exchange all the shards with the other ranks
for dtype_per_rank in self.param_storages.values():
for dst_rank, internal_storage in dtype_per_rank.items():
broadcast(tensor=internal_storage.buffer,
src=self._group.ranks[dst_rank],
group=self._group,
sync_op=True)
if self._broadcast_overlap:
self._broadcast_params_overlap_forward()
else:
for dtype_per_rank in self.param_storages.values():
for dst_rank, internal_storage in dtype_per_rank.items():
broadcast(tensor=internal_storage.buffer,
src=self._group.ranks[dst_rank],
group=self._group,
sync_op=True)

def _forward_pre_hook_function(self, tasks):
# Since layers call the pre-hook as `forward_pre_hook(self, inputs)`,
# the helper function needs two arguments (x, y) to receive those params.
def __impl__(x, y):
for task in tasks:
# Wait for the broadcast task to finish before using its result.
task.wait()

return __impl__

@paddle.autograd.no_grad()
def _broadcast_params_overlap_forward(self):
# Exchange all the shards with the other ranks,
# but overlap the broadcasts with the next batch's forward calculation.
group_idx = 0

param2task = {}
for x in self._broadcast_order_params:
if x.trainable:
group = self._broadcast_groups[group_idx]
group_idx = (group_idx + 1) % self._number_of_broadcast_groups
task = broadcast(tensor=x,
src=group.ranks[self._param2rank[x.name]],
group=group,
sync_op=False)
assert x.name not in param2task
param2task[x.name] = task

for layer in self._layers.sublayers():
if len(layer.sublayers()) == 0:
# Register the forward pre-hook for leaf layers. This gives the best performance.
tasks = []
for param in layer.parameters():
if param.trainable:
if param.name in param2task:
tasks.append(param2task[param.name])
self._forward_pre_hook_remove_helper.append(
layer.register_forward_pre_hook(
self._forward_pre_hook_function(tasks)))
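
The hook mechanism above can be exercised in isolation. Below is a minimal single-process sketch, with no distributed setup: DummyTask stands in for the broadcast task object, and the layer and shapes are arbitrary.

    import paddle

    class DummyTask:
        def wait(self):
            print("broadcast finished")

    def make_pre_hook(tasks):
        # Paddle calls forward pre-hooks as hook(layer, inputs), hence the two arguments.
        def __impl__(layer, inputs):
            for task in tasks:
                task.wait()
        return __impl__

    linear = paddle.nn.Linear(4, 4)
    helper = linear.register_forward_pre_hook(make_pre_hook([DummyTask()]))
    out = linear(paddle.ones([2, 4]))  # prints "broadcast finished" before the matmul runs
    helper.remove()                    # mirrors the _forward_pre_hook_remove_helper cleanup in step()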
@@ -100,6 +100,9 @@ def __init__(
for optim in self._sharding_optimizers:
self._all_params.extend(list(optim.local_params))

# Sharding stage 2 comm overlap flag
self._reduce_overlap = False

self._trainable_params = []
self._grad_reduced = []
self._trainable_param2rank = {}
@@ -306,6 +309,18 @@ def _clear_counters(self):
for grad_storage in self._grad_storage_list:
grad_storage.reset_checked_in()

def _set_reduce_overlap(self, reduce_overlap):
# Hacky way to avoid adding an extra parameter to the `group_sharded_parallel` function.
# Users should enable it like this:
#     model, optimizer, scaler = group_sharded_parallel(...)
#     model._set_reduce_overlap(True)
self._reduce_overlap = reduce_overlap
if self._reduce_overlap:
assert len(
self._sharding_optimizers
) == 1, "The comm overlap strategy is only supported with a single optimizer"
self._sharding_optimizers[0]._set_reduce_overlap(reduce_overlap)

def _get_reduce_fn(self, index, param, dst_rank):
"""
There are two ways to reduce gradient.
@@ -337,11 +352,12 @@ def cleanup():
del tmp_grad
param.clear_gradient(False)

# Synchronize the reduce parameter gradient
collective.reduce(tensor=param.grad,
dst=self._group.ranks[dst_rank],
group=self._group)
# TODO (Baibaifan) Asynchronous the reduce parameter gradient
# Reduce the parameter gradient, asynchronously when reduce overlap is enabled
self._sharding_optimizers[0]._update_task(
collective.reduce(tensor=param.grad,
dst=self._group.ranks[dst_rank],
group=self._group,
sync_op=not self._reduce_overlap))

# Clear the task flow and trigger callback to clear the redundant gradient
# self._clear_task_flow()
@@ -385,12 +401,13 @@ def cleanup():

# Reduce the bucket
grad_storage.sent = True
# Synchronize the reduce parameter gradient
collective.reduce(
tensor=grad_storage.buffer,
dst=self._group.ranks[grad_storage.destination],
group=self._group)
# TODO (Baibaifan) Asynchronous the reduce parameter gradient
# Reduce the parameter gradient, asynchronously when reduce overlap is enabled
self._sharding_optimizers[0]._update_task(
collective.reduce(
tensor=grad_storage.buffer,
dst=self._group.ranks[grad_storage.destination],
group=self._group,
sync_op=not self._reduce_overlap))

cleanup()

@@ -528,6 +545,10 @@ def _redefine_opt_step(self):
opt_step = opt.step

def _opt_step(self):
if self._reduce_overlap:
# Wait for the last reduce task. This wait must happen before the grad scale function.
assert self._comm_task is not None
self._comm_task.wait()
grad_func()
opt_step()
