[Auto Parallel] Logical Partition & Dist Op (#35117)
* support shard reader

* support shard reader

* add parallel mode

* update process mesh

* add method to compute comm_group

* implement dist_embedding forward func

* implement dist matmul forward func

* implement dist reshape forward func

* add transpiler framework

* add transpiler forward

* implement transpiler forward

* implement transpiler backward & update

* add process

* add unit test

* chmod

* chmod

* chmod

* update unit test

* add unit test for gpt

* remove unused print

* rename transpiler --> partitioner

* rename transpiler --> partitioner

* chmod

* chmod

* bug fixed

* remove amp function

* update case for dp mode

* update case for dp mode
JZ-LIANG authored Sep 2, 2021
1 parent 280d742 commit a622b70
Showing 14 changed files with 3,515 additions and 3 deletions.
5 changes: 5 additions & 0 deletions python/paddle/distributed/auto_parallel/attribute.py
@@ -14,6 +14,7 @@

import copy
from collections import defaultdict
from paddle.fluid import core


class TensorDistributedAttribute:
@@ -77,6 +78,8 @@ def mark_as_parameter(self):
self._is_parameter = True

def is_valid(self):
if self.get_owner_tensor().type == core.VarDesc.VarType.READER:
return True
tensor_shape = self.get_owner_tensor().desc.shape()
if len(tensor_shape) != len(self.get_dims_mapping()):
return False
@@ -222,6 +225,8 @@ def mark_as_parameter(self, name):
self._is_parameters[name] = True

def is_valid(self):
if "read" in self.get_owner_op().type:
return True
for name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(name)
shape = self.get_input_shape(name)
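For context on the dims_mapping convention these is_valid checks enforce, here is a small illustrative sketch (hypothetical values, not part of the diff): each tensor dimension is mapped either to a process-mesh axis or to -1 for "replicated", and the mapping needs one entry per tensor dimension.

    # Hypothetical example of the shape/dims_mapping consistency that is_valid checks.
    process_mesh_topology = [2, 4]   # 2-way data parallel x 4-way model parallel
    tensor_shape = [64, 1024]        # e.g. [batch, hidden]
    dims_mapping = [0, -1]           # dim 0 sharded along mesh axis 0, dim 1 replicated

    assert len(dims_mapping) == len(tensor_shape)
    assert all(m == -1 or 0 <= m < len(process_mesh_topology) for m in dims_mapping)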
37 changes: 37 additions & 0 deletions python/paddle/distributed/auto_parallel/context.py
@@ -15,9 +15,11 @@
import copy
from collections import defaultdict
from paddle.fluid import framework
from paddle.fluid import core
from .attribute import TensorDistributedAttribute
from .attribute import OperatorDistributedAttribute
from .utils import append_distributed_attr_suffix
from .interface import _g_process_mesh_map

# There always exists a default context for the user, and the user can set it to another one.
DEFAULT_DISTRIBUTED_CONTEXT = None
@@ -49,6 +51,20 @@ def __init__(self):
self._op_distributed_attr_map_for_program = {}
self._tensor_distributed_attr_map_for_graph = {}
self._op_distributed_attr_map_for_graph = {}
# The following is hard-coded and will be removed in the future
self._data_parallel_axis = None
self._model_parallel_axis = None
self._process_mesh = _g_process_mesh_map.get(0, None)
if self._process_mesh is not None:
if self._process_mesh.ndim == 1:
self._data_parallel_axis = 0
self._model_parallel_axis = 0
else:
self._data_parallel_axis = 0
self._model_parallel_axis = 1
else:
self._data_parallel_axis = -1
self._model_parallel_axis = -1

def is_initialized_for_program(self):
return self._is_initialized_for_program
@@ -99,6 +115,19 @@ def set_op_distributed_attr_for_graph(self, op_node, op_dist_attr):
op_node_id = op_node.id()
self._op_distributed_attr_map_for_graph[op_node_id] = op_dist_attr

def set_process_mesh(self, process_mesh):
self._process_mesh = process_mesh
if self._process_mesh is not None:
if self._process_mesh.ndim == 1:
self._data_parallel_axis = 0
self._model_parallel_axis = 0
else:
self._data_parallel_axis = 0
self._model_parallel_axis = 1
else:
self._data_parallel_axis = -1
self._model_parallel_axis = -1

def initialize_distributed_attr_for_program(self, program):
if self._is_initialized_for_program:
return
@@ -377,3 +406,11 @@ def amend_distributed_attr_for_program(self):
if dims_mapping[i] != -1 and process_mesh_shape[
dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1

def _get_data_parallel_info(self):
# This function is hard-coded and will be removed in the future
return self._data_parallel_axis, self._process_mesh

def _get_model_parallel_info(self):
# This function is hard-coded and will be removed in the future
return self._model_parallel_axis, self._process_mesh
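The data-parallel and model-parallel axes recorded above are later turned into communication groups by the dist ops (via _get_comm_group, used in dist_embedding.py below). A rough standalone sketch of that idea, assuming a group is simply the set of ranks that share every mesh coordinate except the chosen axis:

    import numpy as np

    def comm_group_along_axis(process_group, topology, axis, rank_id):
        # Ranks that share every mesh coordinate with rank_id except along `axis`.
        mesh = np.array(process_group).reshape(topology)
        coord = np.argwhere(mesh == rank_id)[0]
        index = tuple(
            slice(None) if d == axis else int(c) for d, c in enumerate(coord))
        return mesh[index].tolist()

    # On a [2, 4] mesh with ranks 0..7: the model-parallel (axis 1) group of rank 5
    # is [4, 5, 6, 7], and its data-parallel (axis 0) group is [1, 5].
    assert comm_group_along_axis(list(range(8)), [2, 4], 1, 5) == [4, 5, 6, 7]
    assert comm_group_along_axis(list(range(8)), [2, 4], 0, 5) == [1, 5]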
14 changes: 14 additions & 0 deletions python/paddle/distributed/auto_parallel/interface.py
@@ -184,6 +184,13 @@ def parent(self):
"parent with id %d does not exist." % self._parent_id)
return _g_process_mesh_map[self._parent_id]

@property
def ndim(self):
r"""
Get the number of dimensions of the ProcessMesh.
"""
return len(self._topology)

def set_placement(self, order):
"""
Set the map from logical processes to physical ones using the
@@ -229,6 +236,13 @@ def set_placement(self, order):
for idx, l_id in enumerate(logical_order):
_user_defined_physical_map[l_id] = order[idx]

def _reset_global_process_mesh_map(self):
    """
    Remove all process meshes from _g_process_mesh_map, making it empty.
    """
    # The map is module-level; without the global declaration this assignment
    # would only bind a local name and leave the global map untouched.
    global _g_process_mesh_map
    _g_process_mesh_map = dict()

def __eq__(self, other):
assert other and isinstance(other, ProcessMesh)
if self.topology != other.topology or self.process_group != other.process_group:
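A minimal usage sketch for the new ndim property (the constructor call below is illustrative, assuming a ProcessMesh is built from a nested list of ranks whose nesting depth gives the topology):

    # mesh = ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])   # assumed topology [2, 4]
    # mesh.ndim  # -> 2; the context then picks axis 0 for data parallel, axis 1 for model parallel
    # ndim is simply the length of the mesh topology:
    topology = [2, 4]
    assert len(topology) == 2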
2 changes: 2 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/common.py
@@ -33,6 +33,8 @@ def get_impls(self):
class DistributedOperatorImpl:
def __init__(self):
self._name = None
self._forward_implemented = False
self._backward_implemented = False

def forward(self, dist_ctx, *args, **kwargs):
raise NotImplementedError("Please Implement this method in Subclass.")
112 changes: 112 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/dist_embedding.py
@@ -22,6 +22,12 @@
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from ..process import new_process_group
from ..utils import _get_comm_group


class DistributedEmbedding(DistributedOperator):
@@ -39,6 +45,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedEmbeddingImpl, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = False

def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
@@ -92,6 +100,110 @@ def update_dims_mapping(self, op_dist_attr):

return changed

def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
assert len(
    input_name_mapping
) == 2, "row_parallel_embedding takes 2 input variables but got {}".format(
    input_name_mapping)
assert len(
    output_name_mapping
) == 1, "row_parallel_embedding takes 1 output variable but got {}".format(
    output_name_mapping)
assert len(
    input_name_mapping['Ids']
) == 1, "row_parallel_embedding input Ids takes 1 variable but got {}".format(
    input_name_mapping['Ids'])
assert len(
    input_name_mapping['W']
) == 1, "row_parallel_embedding input W takes 1 variable but got {}".format(
    input_name_mapping['W'])
assert len(
    output_name_mapping['Out']
) == 1, "row_parallel_embedding output Out takes 1 variable but got {}".format(
    output_name_mapping['Out'])

Ids_var = dst_block.var(input_name_mapping['Ids'][0])
Weight_var = dst_block.var(input_name_mapping['W'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])

# get dist attribute info
embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[0]
process_mesh_shape = op_dist_attr.get_process_mesh().topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group

# calculate embedding offset
# TODO: generalize here, using a cartesian product to allow arbitrary-dimensional mesh shapes
mesh_shape = len(process_mesh_shape)
assert mesh_shape <= 2, "row_parallel_embedding only supports a 1- or 2-dimensional process mesh, but got {}".format(
    process_mesh_shape)
num_partition = process_mesh_shape[embedding_row_dim_mapping]
# TODO generalize here, support any mesh group
if mesh_shape == 1:
relative_idx = process_mesh_group.index(rank_id)
else:
relative_idx = rank_id % num_partition

per_part_size = Weight_var.shape[0]
relative_idx = relative_idx * per_part_size

# TODO: calculate ring id
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, rank_id)
group = new_process_group(group_ranks)

# append op
check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'],
'c_embedding')

intermediate_var_0 = dst_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_embedding", 'tmp'])),
dtype=Weight_var.dtype,
shape=Out_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_var.stop_gradient)

check_variable_and_dtype(
Out_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'c_allreduce_sum')

dst_block.append_op(
type='c_embedding',
inputs={'Ids': [Ids_var],
'W': [Weight_var]},
outputs={'Out': [intermediate_var_0]},
attrs={"start_index": relative_idx})

# use_model_parallel
dst_block.append_op(
type='c_allreduce_sum',
inputs={'X': [intermediate_var_0]},
outputs={'Out': [Out_var]},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
})

if in_dygraph_mode():
raise NotImplementedError(
"Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
"matmul", 0))
else:
return static_handle


register_distributed_operator_impl("lookup_table_v2",
DistributedEmbeddingImpl("row_parallel"))
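A plain-NumPy sketch of what the c_embedding + c_allreduce_sum pair above computes, assuming each rank owns per_part_size consecutive rows of the embedding table and ids outside its range contribute zeros (names and shapes here are hypothetical):

    import numpy as np

    def local_row_parallel_embedding(ids, local_weight, start_index):
        # Look up only the rows this rank owns; other ids contribute zero rows.
        per_part_size, hidden = local_weight.shape
        out = np.zeros(ids.shape + (hidden,), dtype=local_weight.dtype)
        owned = (ids >= start_index) & (ids < start_index + per_part_size)
        out[owned] = local_weight[ids[owned] - start_index]
        return out

    # Two ranks splitting an 8-row vocab into rows [0..3] and [4..7]; summing the two
    # partial results (the c_allreduce_sum step) recovers the full lookup.
    full = np.random.rand(8, 16)
    ids = np.array([1, 5, 7])
    partial = (local_row_parallel_embedding(ids, full[0:4], 0) +
               local_row_parallel_embedding(ids, full[4:8], 4))
    assert np.allclose(partial, full[ids])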
