[Hybrid Parallel] Support dp & mp in dygraph (#32323)
* support dp & mp
Showing 14 changed files with 572 additions and 57 deletions.
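Below is a minimal sketch of how the pieces added in this commit might be wired together in a dygraph training script. The fleet.init call, fleet.get_hybrid_communicate_group(), and the toy model and data are assumptions for illustration only and are not part of this diff.

# Hypothetical usage sketch (not part of this commit), assuming fleet exposes
# the hybrid communicate group (hcg) consumed by the new classes.
import paddle
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.meta_parallel import ModelParallel
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
    HybridParallelOptimizer)

strategy = fleet.DistributedStrategy()
fleet.init(is_collective=True, strategy=strategy)
hcg = fleet.get_hybrid_communicate_group()  # assumed accessor, see note above

model = paddle.nn.Linear(16, 16)  # stand-in for a real mp-sharded network
optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                 parameters=model.parameters())

# Broadcast mp/dp parameters once and broadcast inputs before every forward.
model = ModelParallel(model, hcg, strategy=strategy)
# All-reduce dp gradients inside step() before the inner optimizer update.
optimizer = HybridParallelOptimizer(optimizer, hcg, strategy)

for _ in range(10):
    x = paddle.randn([4, 16])
    loss = model(x).mean()
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()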
13 changes: 13 additions & 0 deletions
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from .hybrid_parallel_optimizer import HybridParallelOptimizer
58 changes: 58 additions & 0 deletions
...n/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
@@ -0,0 +1,58 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.optimizer import Optimizer
from ...utils.hybrid_parallel_util import fused_allreduce_gradients
from ...base.topology import ParallelMode
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid import framework
from paddle.fluid.framework import Variable


class HybridParallelOptimizer:
    def __init__(self, optimizer, hcg, strategy):
        self._inner_opt = optimizer
        self._strategy = strategy
        self._hcg = hcg
        self._is_mp = (
            self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL)
        self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)

    @imperative_base.no_grad
    @framework.dygraph_only
    def step(self):
        # Under mp + dp, data-parallel gradients are not synchronized by a
        # reducer, so all-reduce them here before the inner optimizer update.
        if self._is_mp and self._need_dp:
            fused_allreduce_gradients(
                list(self._inner_opt._parameter_list), self._hcg)
        self._inner_opt.step()

    @imperative_base.no_grad
    def minimize(self,
                 loss,
                 startup_program=None,
                 parameters=None,
                 no_grad_set=None):
        assert isinstance(loss, Variable), "The loss should be a Tensor."

        parameter_list = parameters if parameters \
            else self._parameter_list

        if self._is_mp and self._need_dp:
            fused_allreduce_gradients(list(parameter_list), self._hcg)

        return self._inner_opt.minimize(loss, startup_program, parameters,
                                        no_grad_set)

    def __getattr__(self, item):
        # Anything not defined on the wrapper resolves against the inner optimizer.
        return getattr(self._inner_opt, item)
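The __getattr__ fall-through is what keeps the wrapper thin: any attribute or method it does not define (learning-rate getters, clear_grad, state dict, and so on) resolves against the wrapped optimizer. A self-contained sketch of that pattern, independent of Paddle:

# Minimal illustration of the delegation pattern used above: lookups that fail
# on the wrapper fall through to the wrapped object via __getattr__.
class InnerOptimizer:
    def __init__(self):
        self.lr = 0.01

    def step(self):
        print("inner: applying the parameter update")

    def clear_grad(self):
        print("inner: clearing gradients")


class Wrapper:
    def __init__(self, inner):
        self._inner = inner

    def step(self):
        # Extra work first, then delegate, mirroring HybridParallelOptimizer.step.
        print("wrapper: all-reduce gradients first")
        self._inner.step()

    def __getattr__(self, item):
        # Called only when normal attribute lookup fails on the wrapper.
        return getattr(self._inner, item)


w = Wrapper(InnerOptimizer())
w.step()        # wrapper work, then delegation to InnerOptimizer.step
w.clear_grad()  # not defined on the wrapper -> delegated via __getattr__
print(w.lr)     # delegated attribute access, prints 0.01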
@@ -13,3 +13,4 @@
# limitations under the License.

from .mp_utils import *
from .model_parallel import ModelParallel
43 changes: 43 additions & 0 deletions
python/paddle/distributed/fleet/meta_parallel/meta_parallel_base.py
@@ -0,0 +1,43 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.dygraph.layers import Layer
import logging


class MetaParallelBase(Layer):
    def __init__(self, layers, hcg, strategy):
        super(MetaParallelBase,
              self).__init__(layers.full_name() + "_meta_parallel_base")
        self._layers = layers
        self._hcg = hcg
        self._prepare_for_model()

    def _prepare_for_model(self):
        pass

    def _pre_forward(self, *inputs, **kwargs):
        pass

    def forward(self, *inputs, **kwargs):
        self._pre_forward(*inputs, **kwargs)

        output = self._layers(*inputs, **kwargs)

        self._post_forward(output)

        return output

    def _post_forward(self, output):
        pass
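MetaParallelBase fixes the hook order for its subclasses: _prepare_for_model runs once at construction time, and every forward() call runs _pre_forward, then the wrapped layers, then _post_forward. A hedged sketch of a custom subclass that plugs into those hooks (the class below is illustrative, not part of this commit):

# Hypothetical subclass showing where the MetaParallelBase hooks fire.
from paddle.distributed.fleet.meta_parallel.meta_parallel_base import (
    MetaParallelBase)


class LoggingParallel(MetaParallelBase):
    def _prepare_for_model(self):
        # Runs once from __init__, e.g. to broadcast parameters.
        print("preparing", self._layers.full_name())

    def _pre_forward(self, *inputs, **kwargs):
        # Runs before every forward pass, e.g. to synchronize inputs.
        print("pre-forward with {} positional inputs".format(len(inputs)))

    def _post_forward(self, output):
        # Runs on the raw output after every forward pass.
        print("post-forward")

Note that forward() discards the value returned by _pre_forward, so hooks such as the _pre_forward override in model_parallel.py below rely on in-place effects (the collective broadcast mutates the input tensors) rather than on returning replacement inputs.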
29 changes: 29 additions & 0 deletions
python/paddle/distributed/fleet/meta_parallel/model_parallel.py
@@ -0,0 +1,29 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.dygraph.layers import Layer
from .meta_parallel_base import MetaParallelBase
from ..utils.hybrid_parallel_util import *


class ModelParallel(MetaParallelBase):
    def __init__(self, layers, hcg, **kwargs):
        super(ModelParallel, self).__init__(layers, hcg, **kwargs)

    def _prepare_for_model(self):
        broadcast_mp_parameters(self._layers, self._hcg)
        broadcast_dp_parameters(self._layers, self._hcg)

    def _pre_forward(self, *inputs, **kwargs):
        return broadcast_input_data(self._hcg, *inputs, **kwargs)
96 changes: 96 additions & 0 deletions
python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -0,0 +1,96 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
import numpy as np
import warnings

from paddle import framework
import paddle
from paddle.fluid import core
from paddle.fluid.dygraph.parallel import _split_tensors, sync_params_buffers, construct_groups
from collections import OrderedDict


def _apply_collective_grads(parameters, comm_group):
    grad_var_set = set()
    grad_vars = []
    sparse_grad_vars = []

    for param in parameters:
        if param.trainable and (param._grad_ivar() is not None):
            g_var = param._grad_ivar()
            assert not g_var._is_sparse(
            ), "Now, it doesn't support sparse parameters"
            grad_vars.append(g_var)
            assert g_var not in grad_var_set
            grad_var_set.add(g_var)

    # Fuse gradients into large groups so each group needs only one all-reduce.
    coalesced_grads_and_vars = construct_groups(grad_vars, 128 * 1024 * 1024)

    for coalesced_grad, _, _ in coalesced_grads_and_vars:
        # need to div nranks
        coalesced_grad = coalesced_grad / comm_group.nranks
        paddle.distributed.all_reduce(coalesced_grad, group=comm_group)

    _split_tensors(coalesced_grads_and_vars)


def broadcast_input_data(hcg, *inputs, **kwargs):
    model_parallel_group = hcg.get_model_parallel_group()
    src_rank = hcg.get_model_parallel_group_src_rank()

    # Broadcast tensor inputs from the model-parallel group's source rank so
    # that every rank in the group runs the forward pass on the same batch.
    for input_ in inputs:
        if isinstance(input_, core.VarBase):
            with framework.no_grad():
                paddle.distributed.broadcast(
                    input_,
                    src=src_rank,
                    group=model_parallel_group,
                    use_calc_stream=True)
        else:
            print("it doesn't support data type {}".format(type(input_)))

    for k, v in kwargs.items():
        if isinstance(v, core.VarBase):
            with framework.no_grad():
                paddle.distributed.broadcast(
                    v,
                    src=src_rank,
                    group=model_parallel_group,
                    use_calc_stream=True)
            kwargs[k] = v
        else:
            print("it doesn't support data type {}".format(type(v)))
    return inputs, kwargs


def broadcast_mp_parameters(model, hcg):
    model_parallel_group = hcg.get_model_parallel_group()
    src_rank = hcg.get_model_parallel_group_src_rank()
    sync_params_buffers(
        model, model_parallel_group, src_rank, is_model_parallel=True)


def broadcast_dp_parameters(model, hcg):
    data_parallel_group = hcg.get_data_parallel_group()
    src_rank = hcg.get_data_parallel_group_src_rank()
    sync_params_buffers(
        model, data_parallel_group, src_rank, is_model_parallel=False)


def fused_allreduce_gradients(parameter_list, hcg):
    data_parallel_group = hcg.get_data_parallel_group()
    with framework.no_grad():
        _apply_collective_grads(parameter_list, data_parallel_group)
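A closing note on broadcast_input_data: every rank in a model-parallel group must run the forward pass on an identical batch, and broadcasting from the group's source rank enforces that even when each rank owns its own data loader. A small hedged sketch of the same idea using only the public collective API (the function and variable names below are illustrative, not part of this commit):

# Hedged sketch: keep all ranks of a model-parallel group on the same batch.
import paddle
import paddle.distributed as dist


def sync_batch_across_mp_group(local_batch, mp_group, src_rank):
    # Overwrite local_batch on every rank with src_rank's copy (in place).
    with paddle.no_grad():
        dist.broadcast(local_batch, src=src_rank, group=mp_group)
    return local_batch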