Skip to content

Commit

Permalink
【Auto Parallel】New local tensor (#38747)
Browse files Browse the repository at this point in the history
* update dist tensor

* add unitest

* update unitest

* refactor dist tensor

* update dist tensor and unitest
  • Loading branch information
Caozhou1995 authored Jan 11, 2022
1 parent fbb4028 commit d3ba189
Show file tree
Hide file tree
Showing 4 changed files with 523 additions and 3 deletions.
18 changes: 17 additions & 1 deletion python/paddle/distributed/auto_parallel/dist_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def __init__(self, program=None):
self._dist_op_context = DistributedOperatorContext()
self._process_meshes = []

# Distributed programs
self._dist_main_programs = {}
self._dist_startup_programs = {}

@property
def serial_program(self):
return self._serial_program
Expand All @@ -84,6 +88,14 @@ def process_meshes(self):
def dist_op_context(self):
return self._dist_op_context

@property
def dist_main_programs(self):
return self._dist_main_programs

@property
def dist_startup_programs(self):
return self._dist_startup_programs

def add_process_mesh(self, process_mesh):
assert isinstance(process_mesh, ProcessMesh), \
'The type of dim_mapping must be ProcessMesh.'
Expand Down Expand Up @@ -371,10 +383,14 @@ def __deepcopy__(self, memo):
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == "_serial_program" or k == "_serial_graph":
if k == "_serial_program" or k == "_serial_graph" or k == "_dist_main_programs" or k == "_dist_startup_programs":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))

# update dist tensor's dist_context
for key in result._dist_tensors_for_program.keys():
result._dist_tensors_for_program[key]._dist_context = result
return result


Expand Down
283 changes: 281 additions & 2 deletions python/paddle/distributed/auto_parallel/dist_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,155 @@
# limitations under the License

import copy
import inspect

import paddle
from paddle.fluid import core
from paddle.fluid.framework import Parameter, Block, Variable
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import get_tensor_dist_attr_field_keys
from .utils import _linear_idx2coordinate


class DistributedTensor:
def __init__(self, serial_tensor, dist_attr=None):
"""
DistributedTensor represents the distribution of tensor on the process group and
local tensors can be created by DistributedTensor.
Only support even sharding now and uneven sharding will be supported in the future.
Local tensor information can be obtained from the DistributedTensor instance object,
or obtained by the static methods provided by DistributedTensor,
including shard (i.e. the index in the serial tensor), offsets, and sizes.
"""

@staticmethod
def _validate_sizes_and_dist_attr(sizes,
dims_mapping,
topology,
processes,
rank=None,
shard_sizes=None):
if not (isinstance(sizes, (list, tuple)) and
all(map(lambda x: isinstance(x, int) and x > 0, sizes))):
raise ValueError(
"The sizes must be list or tuple and item in sizes must be non-negative integer, but got {}".
format(sizes))
if not (isinstance(dims_mapping, (list, tuple)) and all(
map(lambda x: isinstance(x, int) and x >= -1, dims_mapping))):
raise ValueError(
"The dims_mapping must be list or tuple and item in dims_mapping must >= -1, but got {}".
format(dims_mapping))
if not (isinstance(processes, (list, tuple)) and
all(map(lambda x: isinstance(x, int) and x >= 0, processes))):
raise ValueError(
"The processes must be list or tuple and item in processes must be integer, but got {}".
format(processes))
if not (isinstance(topology, (list, tuple)) and
all(map(lambda x: isinstance(x, int) and x > 0, topology))):
raise ValueError(
"The topology must be list or tuple and item in topology must be non-negative integer, but got {}".
format(topology))
if rank is not None and not (isinstance(rank, int) and rank >= 0):
raise ValueError("The rank must >= 0, but got {}".format(rank))

# NOTE: Only support even sharding now
if shard_sizes is not None:
raise ValueError("Only support even sharding now.")

@staticmethod
def get_local_sizes(global_sizes,
dims_mapping,
topology,
processes,
rank=None,
shard_sizes=None):
DistributedTensor._validate_sizes_and_dist_attr(
global_sizes, dims_mapping, topology, processes, rank, shard_sizes)

local_sizes = []
# for even sharding, the local sizes of every rank are equal
for idx, item in enumerate(global_sizes):
if dims_mapping[idx] == -1:
local_sizes.append(item)
else:
local_sizes.append(item // topology[dims_mapping[idx]])

return local_sizes

@staticmethod
def get_local_offsets(global_sizes,
dims_mapping,
topology,
processes,
rank,
shard_sizes=None):
local_sizes = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes, rank, shard_sizes)
local_offsets = []
rank_relatvie = processes.index(rank)
coordinate = _linear_idx2coordinate(topology, rank_relatvie)

for i in range(len(global_sizes)):
if dims_mapping[i] == -1:
local_offsets.append(0)
else:
local_offsets.append(coordinate[dims_mapping[i]] *
local_sizes[i])
return local_offsets

@staticmethod
def get_global_sizes(local_sizes,
dims_mapping,
topology,
processes,
rank=None,
shard_sizes=None):
DistributedTensor._validate_sizes_and_dist_attr(
local_sizes, dims_mapping, topology, processes, rank, shard_sizes)
global_sizes = []
for idx, item in enumerate(local_sizes):
if dims_mapping[idx] == -1:
global_sizes.append(item)
else:
global_sizes.append(item * topology[dims_mapping[idx]])
return global_sizes

@staticmethod
def get_local_shard(global_sizes,
dims_mapping,
topology,
processes,
rank,
shard_sizes=None):
local_offsets = DistributedTensor.get_local_offsets(
global_sizes, dims_mapping, topology, processes, rank, shard_sizes)
local_sizes = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes, rank, shard_sizes)
assert len(local_sizes) == len(
local_offsets
), "The length of local_sizes must be equal to local_offsets, but got {} and {}.".format(
len(local_sizes), len(local_offsets))

local_end_offsets = list(
map(lambda x: x[0] + x[1], zip(local_offsets, local_sizes)))
local_shard = list(zip(local_offsets, local_end_offsets))
return local_shard

def __init__(self, serial_tensor, dist_attr=None, dist_context=None):
self._serial_tensor = serial_tensor
self._dist_attr = None
self._batch_dim = 0
# Reuse the dist_attr setter to initialize _dist_attr
self.dist_attr = dist_attr
self._local_sizes_map = {}
self._local_offsets_map = {}
self._local_shard_map = {}
self._local_tensor_map = {}

from .dist_context import get_default_distributed_context
self._dist_context = dist_context if dist_context is not None else get_default_distributed_context(
)
# TODO: Add Automatically to dist_context after initialized and it will be adapted in the future.
# self._dist_context.add_dist_tensor_for_program(self)

@property
def serial_tensor(self):
Expand All @@ -34,6 +171,10 @@ def serial_tensor(self):
def dist_attr(self):
return self._dist_attr

@property
def dist_context(self):
return self._dist_context

@dist_attr.setter
def dist_attr(self, dist_attr):
if self._dist_attr is None:
Expand Down Expand Up @@ -66,12 +207,150 @@ def validate_dist_attr(self):
return False
return True

def local_sizes(self, rank=None):
rank = paddle.distributed.get_rank() if rank is None else rank
local_sizes = None
if rank in self._local_sizes_map.keys():
local_sizes = self._local_sizes_map[rank]
else:
global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes
processes = self.dist_attr.process_mesh.processes
topology = self.dist_attr.process_mesh.topology
local_sizes = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes, rank,
shard_sizes)
self._local_sizes_map[rank] = local_sizes

return local_sizes

def local_offsets(self, rank=None):
rank = paddle.distributed.get_rank() if rank is None else rank
local_offsets = None
if rank in self._local_offsets_map.keys():
local_offsets = self._local_offsets_map[rank]
else:
global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes
processes = self.dist_attr.process_mesh.processes
topology = self.dist_attr.process_mesh.topology
local_offsets = DistributedTensor.get_local_offsets(
global_sizes, dims_mapping, topology, processes, rank,
shard_sizes)
self._local_offsets_map[rank] = local_offsets

return local_offsets

def global_sizes(self):
return self.serial_tensor.shape

def local_shard(self, rank=None):
rank = paddle.distributed.get_rank() if rank is None else rank
local_shard = None
if rank in self._local_shard_map.keys():
local_shard = self._local_shard_map[rank]
else:
global_sizes = self.serial_tensor.shape
dims_mapping = self.dist_attr.dims_mapping
shard_sizes = self.dist_attr.shard_sizes
processes = self.dist_attr.process_mesh.processes
topology = self.dist_attr.process_mesh.topology
local_shard = DistributedTensor.get_local_shard(
global_sizes, dims_mapping, topology, processes, rank,
shard_sizes)
self._local_shard_map[rank] = local_shard

return local_shard

def new_local_tensor(self, block=None, rank=None, name=None):
"""
Create a new local tensor of serial tensor corresponding to rank.
Args:
block (Block): The block contains the new tensor. Default value is recommend and it will be created in the block of dist main program corresponding to the serial tensor block id. Default: None.
rank (int): The rank id. Default value is recommend and it will be the current rank. Default: None.
"""

def _copy_kwargs(serial_tensor):
kwargs = {}
no_need_copy_args = ["self", "block", "shape", "name"]
arg_spec = inspect.getargspec(Variable.__init__)

for key in arg_spec.args:
# TODO: Check the copied attribute from serial tensor whether valid
if key in no_need_copy_args:
continue
elif key not in kwargs:
if key == "type":
kwargs[key] = serial_tensor.desc.type()
elif key == "dtype":
kwargs[key] = serial_tensor.desc.dtype()
elif key == "lod_level":
kwargs[key] = serial_tensor.desc.lod_level()
elif key == "persistable":
kwargs[key] = serial_tensor.desc.persistable()
elif key == "stop_gradient":
kwargs[key] = serial_tensor.desc.stop_gradient()
elif key == "need_check_feed":
kwargs[key] = serial_tensor.desc.need_check_feed()
# TODO: Get capacity by framework
elif key == "capacity":
continue
else:
kwargs[key] = self.serial_tensor.__dict__[key]

if isinstance(serial_tensor, Parameter):
kwargs["trainable"] = serial_tensor.trainable
kwargs["optimize_attr"] = serial_tensor.trainable
kwargs["regularizer"] = serial_tensor.regularizer
kwargs["do_model_average"] = serial_tensor.do_model_average
kwargs["need_clip"] = serial_tensor.need_clip
kwargs["is_distributed"] = serial_tensor.is_distributed
kwargs["is_parameter"] = serial_tensor.is_parameter

return kwargs

if rank is not None and not (isinstance(rank, int) and rank >= 0):
raise ValueError("The rank must >= 0, but got {}".format(rank))
if block is not None and not isinstance(block, Block):
raise TypeError("The block must be Block, but got {}.".format(
type(block)))
rank = paddle.distributed.get_rank() if rank is None else rank

if block is None:
block_id = self.serial_tensor.block.idx
block = self.dist_context.dist_main_programs[rank].block(block_id)

# copy serial tensor attribute
kwargs = _copy_kwargs(self.serial_tensor)
kwargs["name"] = name
kwargs["shape"] = self.local_sizes(rank)

if isinstance(self.serial_tensor, Parameter):
kwargs.pop("persistable")
local_tensor = Parameter(block=block, **kwargs)
else:
local_tensor = block.create_var(**kwargs)

# TODO: Set original id when set original_id is approved
local_tensor.desc.set_original_id(self.serial_tensor.desc.id())
self._local_tensor_map[rank] = local_tensor
return local_tensor

def local_tensor(self, rank=None):
rank = paddle.distributed.get_rank() if rank is None else rank
assert rank in self._local_tensor_map, "The rank {} local tensor has not been created.".format(
rank)
return self._local_tensor_map[rank]

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == "_serial_tensor":
if k == "_serial_tensor" or k == "_local_tensor_map":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner_gpt)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_searcher)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_dist_tensor)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_serial)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_mppp)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_dpmppp)
Expand Down Expand Up @@ -262,6 +263,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_searcher)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_dist_tensor)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_serial)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_mppp)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_dpmppp)
Expand Down Expand Up @@ -649,6 +651,7 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_auto_parallel_partitioner_gpt MODULES test_auto_parallel_partitioner_gpt ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_searcher MODULES test_auto_parallel_searcher ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard MODULES test_auto_parallel_reshard ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_dist_tensor MODULES test_auto_parallel_dist_tensor ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard_serial MODULES test_auto_parallel_reshard_serial ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard_mppp MODULES test_auto_parallel_reshard_mppp ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard_dpmppp MODULES test_auto_parallel_reshard_dpmppp ENVS ${dist_ENVS})
Expand Down
Loading

0 comments on commit d3ba189

Please sign in to comment.