add_sharding_api (#40129)

Baibaifan authored Mar 9, 2022
1 parent 1defc8f commit f40ed5f
Showing 12 changed files with 437 additions and 24 deletions.
1 change: 1 addition & 0 deletions python/paddle/distributed/__init__.py
@@ -55,6 +55,7 @@
from . import cloud_utils # noqa: F401
from . import utils # noqa: F401

+ from .sharding import * # noqa: F401

__all__ = [ # noqa
"spawn",
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py
@@ -40,8 +40,6 @@
Type.fp32.value: 4,
}

- __all__ = ["ShardingOptimizerStage2"]


class ShardingOptimizerStage2(Optimizer):
"""
@@ -136,7 +134,7 @@ def __init__(self,
# Update optimizer parameters and adjust parameter storage and use according to rank.
self._update_opt_status()

- @paddle.no_grad()
+ @paddle.autograd.no_grad()
def _sync_params_and_buffers(self):
"""
Sync all model states for all ranks
@@ -392,7 +390,7 @@ def _clear_cache(self):
self._dtype_rank_params.clear()
self._param2rank.clear()

- @fluid.dygraph.no_grad
+ @paddle.autograd.no_grad()
def _broadcast_params(self):
"""Broadcast the parameters of the current rank to each rank"""

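Note: the recurring change across these files replaces the @paddle.no_grad() and @fluid.dygraph.no_grad decorators with @paddle.autograd.no_grad(). A minimal sketch of the decorator form, with illustrative function and tensor names that are not part of the commit:

    import paddle

    @paddle.autograd.no_grad()
    def sync_like_step(x, w):
        # Everything inside runs without recording a backward graph,
        # which is what the sharding sync/broadcast helpers rely on.
        return paddle.matmul(x, w)

    x = paddle.randn([4, 8])
    w = paddle.randn([8, 2])
    w.stop_gradient = False
    y = sync_like_step(x, w)
    print(y.stop_gradient)  # True: no gradient flows back through y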
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
@@ -63,8 +63,7 @@ def __init__(
sync_buffers=False,
buffer_max_size=2**23, #8MB
auto_refresh_trainable=True,
device="gpu",
use_grad_storage=True):
device="gpu"):
super().__init__()

# training options
@@ -102,9 +101,10 @@ def __init__(
# Set grad storage size & Display param sizes and model sizes
model_size = sum(
[np.prod(p.shape) for p in self._layer.parameters()]).item()
+ assert buffer_max_size >= 0, "buffer_max_size must be greater than or equal to 0."
self._buffer_max_size = self._rank_buffer_size(buffer_max_size,
model_size)
- self._use_grad_storage = use_grad_storage
+ self._use_grad_storage = buffer_max_size > 0
self._grad_storages = {} # {dtype: {rank: GradStorage}}
self._has_grad_storage = []
self._grad_storage_list = []
@@ -255,7 +255,7 @@ def _fresh_trainable(self):
# wait next func hook support
self._setup_backward_hooks()

- @paddle.no_grad()
+ @paddle.autograd.no_grad()
def __sync_buffers(self):
"""
Sync all the param buffers from all ranks (exp: batch norm statistics).
@@ -277,7 +277,7 @@ def __getattr__(self, name):
except AttributeError:
return getattr(self._layer, name)

- @paddle.no_grad()
+ @paddle.autograd.no_grad()
def _clear_counters(self):
"""Reset all the grad reduce and call counters."""
if self.training:
@@ -290,13 +290,13 @@ def _clear_counters(self):
def _get_reduce_fn(self, index, param, dst_rank):
"""
There are two ways to reduce gradient.
- - 1. Do not use use_grad_storage or exceeded buffer_max_size will be reduced separately.
+ - 1. Do not use self._use_grad_storage or exceeded buffer_max_size will be reduced separately.
- 2. Use grad_storage Reduce the storage to get the full gradient from different ranks.
"""

if not self._use_grad_storage or not self._has_grad_storage[index]:
# Direct reduction
- @paddle.no_grad()
+ @paddle.autograd.no_grad()
def reduce(*_):
# Skip gradient reduction, do not change status information
if self._grad_reduced[index]:
@@ -336,7 +336,7 @@ def cleanup():

else:
# Buffer reduction
- @paddle.no_grad()
+ @paddle.autograd.no_grad()
def reduce(*_):
# Skip gradient reduction, do not change status information
if self._grad_reduced[index]:
@@ -421,9 +421,6 @@ def _setup_use_grad_storage(self):
Integrate the parameters gradient into a continuous memory according to rank, and support the update of training parameters.
"""

- if not self._use_grad_storage:
-     return
-
# According to parameters's numel sort, allocate memory of parameter gradient to continuous memory according to rank
self._grad_storages = {}
self._has_grad_storage = [False for _ in self._trainable_params]
python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
@@ -84,6 +84,7 @@ def __init__(self,
self._offload = offload
self._sync_comm = sync_comm
# segmentation size
+ assert segment_size >= 0, "segment_size must be greater than or equal to 0."
self._segment_size = segment_size

global DEV
@@ -158,7 +159,7 @@ def __init__(self,
self._redefine_opt_step()
self._redefine_opt_clear()

- @paddle.no_grad()
+ @paddle.autograd.no_grad()
def _sync_params_and_buffers(self):
"""
Sync all model states for all ranks
@@ -408,7 +409,7 @@ def _forward_post_hook(layer, inputs, outputs):
# register post forward hooks
sub_layer.register_forward_post_hook(_forward_post_hook)

- @paddle.no_grad()
+ @paddle.autograd.no_grad()
def _sync_buffers(self):
"""
Sync all the param buffers from all ranks (exp: batch norm statistics).
@@ -521,7 +522,7 @@ def _register_backward_hooks(self):
param._register_backward_hook(allreduce_function)

def _get_allreduce_fn(self, param):
- @paddle.no_grad()
+ @paddle.autograd.no_grad()
def reduce(*_):
if param.name in self._task_flow.full_grad.keys():
full_grad = self._task_flow.full_grad[param.name]
@@ -840,7 +841,7 @@ def _allgather_buffer(trainable_params,
return task_flow


- @paddle.no_grad()
+ @paddle.autograd.no_grad()
def _create_params_grad(trainable_params, param2buffer_size, task_flow):
for param in trainable_params:
if param.name in task_flow.full_grad.keys():
17 changes: 17 additions & 0 deletions python/paddle/distributed/sharding/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .group_sharded import group_sharded_parallel, save_group_sharded_model # noqa: F401

__all__ = ['group_sharded_parallel', 'save_group_sharded_model']
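With this new package, both helpers become importable directly from paddle.distributed.sharding. For reference, a minimal import check plus their signatures as introduced in group_sharded.py below:

    from paddle.distributed.sharding import group_sharded_parallel, save_group_sharded_model

    # group_sharded_parallel(model, optimizer, level, scaler=None, group=None,
    #                        offload=False, sync_buffers=False,
    #                        buffer_max_size=2**23, segment_size=2**20,
    #                        sync_comm=False)
    # save_group_sharded_model(model, output, optimizer=None)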
211 changes: 211 additions & 0 deletions python/paddle/distributed/sharding/group_sharded.py
@@ -0,0 +1,211 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import logging
from enum import Enum

import paddle

from paddle.optimizer import Optimizer
from paddle.distributed.utils import get_logger
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler

logger_ = get_logger(logging.INFO)


def group_sharded_parallel(model,
optimizer,
level,
scaler=None,
group=None,
offload=False,
sync_buffers=False,
buffer_max_size=2**23,
segment_size=2**20,
sync_comm=False):
"""
Use this API to configure and wrap the model, optimizer and GradScaler for group sharded training.
Args:
model (Layer): The layer to be wrapped with group_sharded_parallel.
optimizer (Optimizer): The optimizer to be wrapped with group_sharded_parallel.
level (str): The group sharded level, one of `os`, `os_g` or `p_g_os`.
scaler (GradScaler, optional): The scaler to be wrapped with group_sharded_parallel. Defaults to None.
group (Group, optional): The communication group instance. Defaults to None.
offload (bool, optional): Whether to offload the optimizer state and gradients to CPU. Defaults to False.
sync_buffers (bool, optional): Whether to broadcast model buffers. Defaults to False.
buffer_max_size (int, optional): The max size of the buffer used to integrate gradient in `os_g`. Defaults to 2**23.
segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20.
sync_comm (bool, optional): Whether to use synchronous communication; only used with `p_g_os`. Defaults to False.
Returns:
model: The given model wrapped for group sharding.
optimizer: The given optimizer wrapped for group sharding.
scaler: The given scaler wrapped for group sharding.
Examples:
.. code-block:: python
# required: distributed
import paddle
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel
fleet.init(is_collective=True)
group = paddle.distributed.new_group([0, 1])
model = Linear(1000, 1000)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters(), weight_decay=0.00001, grad_clip=clip)
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
# wrap sharding model, optimizer and scaler
model, optimizer, scaler = group_sharded_parallel(model, optimizer, "p_g_os", scaler=scaler)
img, label = data
label.stop_gradient = True
img.stop_gradient = True
out = model(img)
loss = paddle.nn.functional.cross_entropy(input=out, label=label)
loss.backward()
optimizer.step()
optimizer.clear_grad()
"""
# check option types
assert isinstance(
model,
paddle.nn.Layer), "The model must be the instance of paddle.nn.Layer."
assert isinstance(
optimizer, Optimizer
), "The optimizer must be the instance of paddle.optimizer.Optimizer."
assert level in ['os', 'os_g', 'p_g_os'
], "The level must be os, os_g or p_g_os."

def check_dtype(param):
return param.dtype == paddle.float16

params_fp16 = list(filter(check_dtype, model.parameters()))
if scaler is None and len(params_fp16) > 0:
raise ValueError("Please enter the correct scaler.")
# convert model/optimizer/scaler
if level in ['os', 'os_g']:
logger_.info("*" * 30)
logger_.info("Sharded level os uses sharded level os_g achieved now.")
logger_.info("*" * 30)
optimizer = ShardingOptimizerStage2(
params=model.parameters(),
optim=optimizer,
group=group,
offload=offload)
model = ShardingStage2(
model,
optimizer,
group=group,
sync_buffers=sync_buffers,
buffer_max_size=buffer_max_size)
elif level == 'p_g_os':
model = ShardingStage3(
model,
optimizer=optimizer,
group=group,
sync_buffers=sync_buffers,
segment_size=segment_size,
offload=offload,
sync_comm=sync_comm)
else:
raise ValueError("Please enter the correct level.")
if params_fp16 and isinstance(scaler, paddle.amp.GradScaler):
scaler = ShardingScaler(scaler)
logger_.info("*" * 30)
logger_.info(
"If there is a communication hang using group sharded, please check whether the communication operations of each process are unified."
)
logger_.info("*" * 30)

return model, optimizer, scaler


def save_group_sharded_model(model, output, optimizer=None):
"""
Save the state of the model and optimizer wrapped by group_sharded_parallel.
Args:
model (Layer): The model wrapped by group_sharded_parallel.
output (str): Save directory.
optimizer (Optimizer, optional): The optimizer wrapped by group_sharded_parallel. Defaults to None.
Examples:
.. code-block:: python
# required: distributed
import paddle
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel, save_group_sharded_model
fleet.init(is_collective=True)
group = paddle.distributed.new_group([0, 1])
model = Linear(1000, 1000)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters(), weight_decay=0.00001, grad_clip=clip)
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
# wrap sharding model, optimizer and scaler
model, optimizer, scaler = group_sharded_parallel(model, optimizer, "p_g_os", scaler=scaler)
img, label = data
label.stop_gradient = True
img.stop_gradient = True
out = model(img)
loss = paddle.nn.functional.cross_entropy(input=out, label=label)
loss.backward()
optimizer.step()
optimizer.clear_grad()
# save model and optimizer state_dict
save_group_sharded_model(model, output_dir, optimizer=optimizer)
"""
logger_.info(
"==========Begin to save group sharded model and optimizer==========")
assert not os.path.isfile(
output
), "Saving directory ({}) should be a directory, not a file".format(output)
os.makedirs(output, exist_ok=True)
output_model = os.path.join(output, "model.pdmodel")
if isinstance(model, ShardingStage2):
paddle.save(model._layer.state_dict(), output_model)
elif isinstance(model, ShardingStage3):
convert2cpu = True if model._offload else False
model.get_all_parameters(convert2cpu=convert2cpu)
paddle.save(model._layer.state_dict(), output_model)
else:
raise ValueError(
"Please use the layer which is wrapped with group_sharded_parallel.")

if optimizer is not None:
assert hasattr(
optimizer, "_optim"
), "Please use the optimizer which is wrapped with group_sharded_parallel."
output_opt = os.path.join(output, "model.pdopt")
paddle.save(optimizer._optim.state_dict(), output_opt)
logger_.info(
"==========End to save group sharded model and optimizer==========")