diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py
index 12bf14fcce5bd..b194bcc3de6b5 100644
--- a/python/paddle/distributed/auto_parallel/dist_context.py
+++ b/python/paddle/distributed/auto_parallel/dist_context.py
@@ -62,6 +62,10 @@ def __init__(self, program=None):
         self._dist_op_context = DistributedOperatorContext()
         self._process_meshes = []
 
+        # Distributed programs
+        self._dist_main_programs = {}
+        self._dist_startup_programs = {}
+
     @property
     def serial_program(self):
         return self._serial_program
@@ -84,6 +88,14 @@ def process_meshes(self):
     def dist_op_context(self):
         return self._dist_op_context
 
+    @property
+    def dist_main_programs(self):
+        return self._dist_main_programs
+
+    @property
+    def dist_startup_programs(self):
+        return self._dist_startup_programs
+
     def add_process_mesh(self, process_mesh):
         assert isinstance(process_mesh, ProcessMesh), \
             'The type of dim_mapping must be ProcessMesh.'
@@ -371,10 +383,14 @@ def __deepcopy__(self, memo):
         result = cls.__new__(cls)
         memo[id(self)] = result
         for k, v in self.__dict__.items():
-            if k == "_serial_program" or k == "_serial_graph":
+            if k == "_serial_program" or k == "_serial_graph" or k == "_dist_main_programs" or k == "_dist_startup_programs":
                 setattr(result, k, v)
             else:
                 setattr(result, k, copy.deepcopy(v, memo))
+
+        # update the dist_context of every dist tensor
+        for key in result._dist_tensors_for_program.keys():
+            result._dist_tensors_for_program[key]._dist_context = result
         return result
diff --git a/python/paddle/distributed/auto_parallel/dist_tensor.py b/python/paddle/distributed/auto_parallel/dist_tensor.py
index f46c6e86d6870..5e3c852699ab6 100644
--- a/python/paddle/distributed/auto_parallel/dist_tensor.py
+++ b/python/paddle/distributed/auto_parallel/dist_tensor.py
@@ -13,18 +13,155 @@
 # limitations under the License
 
 import copy
+import inspect
+
+import paddle
 from paddle.fluid import core
+from paddle.fluid.framework import Parameter, Block, Variable
 from .dist_attribute import TensorDistributedAttribute
 from .dist_attribute import get_tensor_dist_attr_field_keys
+from .utils import _linear_idx2coordinate
 
 
 class DistributedTensor:
-    def __init__(self, serial_tensor, dist_attr=None):
+    """
+    DistributedTensor represents the distribution of a tensor over a process group,
+    and local tensors can be created from it.
+    Only even sharding is supported now; uneven sharding will be supported in the future.
+    Local tensor information, including the shard (i.e. the index range in the serial tensor),
+    offsets, and sizes, can be obtained from a DistributedTensor instance
+    or from the static methods provided by DistributedTensor.
+    """
+
+    @staticmethod
+    def _validate_sizes_and_dist_attr(sizes,
+                                      dims_mapping,
+                                      topology,
+                                      processes,
+                                      rank=None,
+                                      shard_sizes=None):
+        if not (isinstance(sizes, (list, tuple)) and
+                all(map(lambda x: isinstance(x, int) and x > 0, sizes))):
+            raise ValueError(
+                "The sizes must be list or tuple and each item in sizes must be a positive integer, but got {}".
+                format(sizes))
+        if not (isinstance(dims_mapping, (list, tuple)) and all(
+                map(lambda x: isinstance(x, int) and x >= -1, dims_mapping))):
+            raise ValueError(
+                "The dims_mapping must be list or tuple and each item in dims_mapping must be >= -1, but got {}".
+                format(dims_mapping))
+        if not (isinstance(processes, (list, tuple)) and
+                all(map(lambda x: isinstance(x, int) and x >= 0, processes))):
+            raise ValueError(
+                "The processes must be list or tuple and each item in processes must be a non-negative integer, but got {}".
+                format(processes))
+        if not (isinstance(topology, (list, tuple)) and
+                all(map(lambda x: isinstance(x, int) and x > 0, topology))):
+            raise ValueError(
+                "The topology must be list or tuple and each item in topology must be a positive integer, but got {}".
+                format(topology))
+        if rank is not None and not (isinstance(rank, int) and rank >= 0):
+            raise ValueError("The rank must be >= 0, but got {}".format(rank))
+
+        # NOTE: Only even sharding is supported now
+        if shard_sizes is not None:
+            raise ValueError("Only even sharding is supported now.")
+
+    @staticmethod
+    def get_local_sizes(global_sizes,
+                        dims_mapping,
+                        topology,
+                        processes,
+                        rank=None,
+                        shard_sizes=None):
+        DistributedTensor._validate_sizes_and_dist_attr(
+            global_sizes, dims_mapping, topology, processes, rank, shard_sizes)
+
+        local_sizes = []
+        # for even sharding, the local sizes of all ranks are equal
+        for idx, item in enumerate(global_sizes):
+            if dims_mapping[idx] == -1:
+                local_sizes.append(item)
+            else:
+                local_sizes.append(item // topology[dims_mapping[idx]])
+
+        return local_sizes
+
+    @staticmethod
+    def get_local_offsets(global_sizes,
+                          dims_mapping,
+                          topology,
+                          processes,
+                          rank,
+                          shard_sizes=None):
+        local_sizes = DistributedTensor.get_local_sizes(
+            global_sizes, dims_mapping, topology, processes, rank, shard_sizes)
+        local_offsets = []
+        rank_relative = processes.index(rank)
+        coordinate = _linear_idx2coordinate(topology, rank_relative)
+
+        for i in range(len(global_sizes)):
+            if dims_mapping[i] == -1:
+                local_offsets.append(0)
+            else:
+                local_offsets.append(coordinate[dims_mapping[i]] *
+                                     local_sizes[i])
+        return local_offsets
+
+    @staticmethod
+    def get_global_sizes(local_sizes,
+                         dims_mapping,
+                         topology,
+                         processes,
+                         rank=None,
+                         shard_sizes=None):
+        DistributedTensor._validate_sizes_and_dist_attr(
+            local_sizes, dims_mapping, topology, processes, rank, shard_sizes)
+        global_sizes = []
+        for idx, item in enumerate(local_sizes):
+            if dims_mapping[idx] == -1:
+                global_sizes.append(item)
+            else:
+                global_sizes.append(item * topology[dims_mapping[idx]])
+        return global_sizes
+
+    @staticmethod
+    def get_local_shard(global_sizes,
+                        dims_mapping,
+                        topology,
+                        processes,
+                        rank,
+                        shard_sizes=None):
+        local_offsets = DistributedTensor.get_local_offsets(
+            global_sizes, dims_mapping, topology, processes, rank, shard_sizes)
+        local_sizes = DistributedTensor.get_local_sizes(
+            global_sizes, dims_mapping, topology, processes, rank, shard_sizes)
+        assert len(local_sizes) == len(
+            local_offsets
+        ), "The length of local_sizes must be equal to that of local_offsets, but got {} and {}.".format(
+            len(local_sizes), len(local_offsets))
+
+        local_end_offsets = list(
+            map(lambda x: x[0] + x[1], zip(local_offsets, local_sizes)))
+        local_shard = list(zip(local_offsets, local_end_offsets))
+        return local_shard
+
+    def __init__(self, serial_tensor, dist_attr=None, dist_context=None):
         self._serial_tensor = serial_tensor
         self._dist_attr = None
         self._batch_dim = 0
         # Reuse the dist_attr setter to initialize _dist_attr
         self.dist_attr = dist_attr
+        self._local_sizes_map = {}
+        self._local_offsets_map = {}
+        self._local_shard_map = {}
+        self._local_tensor_map = {}
+
+        from .dist_context import get_default_distributed_context
+        self._dist_context = dist_context if dist_context is not None \
+            else get_default_distributed_context()
+        # TODO: Automatically add the dist tensor to the dist_context after
+        # initialization; this will be adapted in the future.
+        # self._dist_context.add_dist_tensor_for_program(self)
 
     @property
     def serial_tensor(self):
@@ -34,6 +171,10 @@ def serial_tensor(self):
     def dist_attr(self):
         return self._dist_attr
 
+    @property
+    def dist_context(self):
+        return self._dist_context
+
     @dist_attr.setter
     def dist_attr(self, dist_attr):
         if self._dist_attr is None:
@@ -66,12 +207,150 @@ def validate_dist_attr(self):
                 return False
         return True
 
+    def local_sizes(self, rank=None):
+        rank = paddle.distributed.get_rank() if rank is None else rank
+        local_sizes = None
+        if rank in self._local_sizes_map.keys():
+            local_sizes = self._local_sizes_map[rank]
+        else:
+            global_sizes = self.serial_tensor.shape
+            dims_mapping = self.dist_attr.dims_mapping
+            shard_sizes = self.dist_attr.shard_sizes
+            processes = self.dist_attr.process_mesh.processes
+            topology = self.dist_attr.process_mesh.topology
+            local_sizes = DistributedTensor.get_local_sizes(
+                global_sizes, dims_mapping, topology, processes, rank,
+                shard_sizes)
+            self._local_sizes_map[rank] = local_sizes
+
+        return local_sizes
+
+    def local_offsets(self, rank=None):
+        rank = paddle.distributed.get_rank() if rank is None else rank
+        local_offsets = None
+        if rank in self._local_offsets_map.keys():
+            local_offsets = self._local_offsets_map[rank]
+        else:
+            global_sizes = self.serial_tensor.shape
+            dims_mapping = self.dist_attr.dims_mapping
+            shard_sizes = self.dist_attr.shard_sizes
+            processes = self.dist_attr.process_mesh.processes
+            topology = self.dist_attr.process_mesh.topology
+            local_offsets = DistributedTensor.get_local_offsets(
+                global_sizes, dims_mapping, topology, processes, rank,
+                shard_sizes)
+            self._local_offsets_map[rank] = local_offsets
+
+        return local_offsets
+
+    def global_sizes(self):
+        return self.serial_tensor.shape
+
+    def local_shard(self, rank=None):
+        rank = paddle.distributed.get_rank() if rank is None else rank
+        local_shard = None
+        if rank in self._local_shard_map.keys():
+            local_shard = self._local_shard_map[rank]
+        else:
+            global_sizes = self.serial_tensor.shape
+            dims_mapping = self.dist_attr.dims_mapping
+            shard_sizes = self.dist_attr.shard_sizes
+            processes = self.dist_attr.process_mesh.processes
+            topology = self.dist_attr.process_mesh.topology
+            local_shard = DistributedTensor.get_local_shard(
+                global_sizes, dims_mapping, topology, processes, rank,
+                shard_sizes)
+            self._local_shard_map[rank] = local_shard
+
+        return local_shard
+
+    def new_local_tensor(self, block=None, rank=None, name=None):
+        """
+        Create a new local tensor of the serial tensor corresponding to a rank.
+
+        Args:
+            block (Block): The block that will contain the new tensor. The default value is recommended; in that case the tensor is created in the block of the dist main program that corresponds to the block id of the serial tensor. Default: None.
+            rank (int): The rank id. The default value is recommended; in that case the current rank is used. Default: None.
+            name (str): The name of the new tensor. Default: None.
+        """
+
+        def _copy_kwargs(serial_tensor):
+            kwargs = {}
+            no_need_copy_args = ["self", "block", "shape", "name"]
+            arg_spec = inspect.getargspec(Variable.__init__)
+
+            for key in arg_spec.args:
+                # TODO: Check whether the attributes copied from the serial
+                # tensor are valid
+                if key in no_need_copy_args:
+                    continue
+                elif key not in kwargs:
+                    if key == "type":
+                        kwargs[key] = serial_tensor.desc.type()
+                    elif key == "dtype":
+                        kwargs[key] = serial_tensor.desc.dtype()
+                    elif key == "lod_level":
+                        kwargs[key] = serial_tensor.desc.lod_level()
+                    elif key == "persistable":
+                        kwargs[key] = serial_tensor.desc.persistable()
+                    elif key == "stop_gradient":
+                        kwargs[key] = serial_tensor.desc.stop_gradient()
+                    elif key == "need_check_feed":
+                        kwargs[key] = serial_tensor.desc.need_check_feed()
+                    # TODO: Get capacity by framework
+                    elif key == "capacity":
+                        continue
+                    else:
+                        kwargs[key] = self.serial_tensor.__dict__[key]
+
+            if isinstance(serial_tensor, Parameter):
+                kwargs["trainable"] = serial_tensor.trainable
+                kwargs["optimize_attr"] = serial_tensor.optimize_attr
+                kwargs["regularizer"] = serial_tensor.regularizer
+                kwargs["do_model_average"] = serial_tensor.do_model_average
+                kwargs["need_clip"] = serial_tensor.need_clip
+                kwargs["is_distributed"] = serial_tensor.is_distributed
+                kwargs["is_parameter"] = serial_tensor.is_parameter
+
+            return kwargs
+
+        if rank is not None and not (isinstance(rank, int) and rank >= 0):
+            raise ValueError("The rank must be >= 0, but got {}".format(rank))
+        if block is not None and not isinstance(block, Block):
+            raise TypeError("The block must be Block, but got {}.".format(
+                type(block)))
+        rank = paddle.distributed.get_rank() if rank is None else rank
+
+        if block is None:
+            block_id = self.serial_tensor.block.idx
+            block = self.dist_context.dist_main_programs[rank].block(block_id)
+
+        # copy the attributes of the serial tensor
+        kwargs = _copy_kwargs(self.serial_tensor)
+        kwargs["name"] = name
+        kwargs["shape"] = self.local_sizes(rank)
+
+        if isinstance(self.serial_tensor, Parameter):
+            kwargs.pop("persistable")
+            local_tensor = Parameter(block=block, **kwargs)
+        else:
+            local_tensor = block.create_var(**kwargs)
+
+        # TODO: Set the original id once setting original_id is approved
+        local_tensor.desc.set_original_id(self.serial_tensor.desc.id())
+        self._local_tensor_map[rank] = local_tensor
+        return local_tensor
+
+    def local_tensor(self, rank=None):
+        rank = paddle.distributed.get_rank() if rank is None else rank
+        assert rank in self._local_tensor_map, "The local tensor of rank {} has not been created.".format(
+            rank)
+        return self._local_tensor_map[rank]
+
     def __deepcopy__(self, memo):
         cls = self.__class__
         result = cls.__new__(cls)
         memo[id(self)] = result
         for k, v in self.__dict__.items():
-            if k == "_serial_tensor":
+            if k == "_serial_tensor" or k == "_local_tensor_map":
                 setattr(result, k, v)
             else:
                 setattr(result, k, copy.deepcopy(v, memo))
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 64c247e56d1d3..b46a10c8c79d8 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -94,6 +94,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner)
 list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner_gpt)
 list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_searcher)
 list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard)
+list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_dist_tensor)
 list(APPEND MIXED_DIST_TEST_OPS
     test_auto_parallel_reshard_serial)
 list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_mppp)
 list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_dpmppp)
@@ -262,6 +263,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt)
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_searcher)
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard)
+    LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_dist_tensor)
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_serial)
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_mppp)
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_dpmppp)
@@ -649,6 +651,7 @@ if(WITH_DISTRIBUTE)
     py_test_modules(test_auto_parallel_partitioner_gpt MODULES test_auto_parallel_partitioner_gpt ENVS ${dist_ENVS})
     py_test_modules(test_auto_parallel_searcher MODULES test_auto_parallel_searcher ENVS ${dist_ENVS})
     py_test_modules(test_auto_parallel_reshard MODULES test_auto_parallel_reshard ENVS ${dist_ENVS})
+    py_test_modules(test_auto_parallel_dist_tensor MODULES test_auto_parallel_dist_tensor ENVS ${dist_ENVS})
     py_test_modules(test_auto_parallel_reshard_serial MODULES test_auto_parallel_reshard_serial ENVS ${dist_ENVS})
     py_test_modules(test_auto_parallel_reshard_mppp MODULES test_auto_parallel_reshard_mppp ENVS ${dist_ENVS})
     py_test_modules(test_auto_parallel_reshard_dpmppp MODULES test_auto_parallel_reshard_dpmppp ENVS ${dist_ENVS})
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py
new file mode 100644
index 0000000000000..b21cbb5ae78bc
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py
@@ -0,0 +1,222 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import unittest
+
+import paddle
+from paddle.fluid import core
+import paddle.distributed.auto_parallel as auto
+from paddle.distributed import fleet
+from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
+from paddle.distributed.auto_parallel.partitioner import Partitioner
+from paddle.distributed.auto_parallel.dist_context import DistributedContext
+from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
+from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute
+import test_auto_parallel_reshard
+from test_auto_parallel_reshard import mlp_forward
+
+
+def get_dist_prog(train_program,
+                  startup_program,
+                  dist_context,
+                  rank_id,
+                  complete_train_program=None):
+    loss, train_program, startup_program = mlp_forward(train_program,
+                                                       startup_program)
+
+    fleet._user_defined_strategy = fleet.DistributedStrategy()
+    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
+    parallelizer = AutoParallelizer(fleet)
+    parallelizer._dist_context = dist_context
+
+    # serial forward & backward completion
+    complete_train_program = auto.complete_annotation(
+        train_program, dist_context
+    ) if complete_train_program is None else complete_train_program
+
+    # parallelizer._apply_serial_forward_pass(complete_train_program,
+    #                                         startup_program)
+
+    params_grads = parallelizer._generate_backward(
+        complete_train_program,
+        startup_program,
+        loss,
+        parameter_list=None,
+        no_grad_set=None,
+        callbacks=None)
+
+    # logical partition
+    partitioner = Partitioner(dist_context, rank_id)
+    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
+        complete_train_program, startup_program, params_grads)
+
+    partitioned_optimize_ops = parallelizer._apply_optimize(
+        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
+
+    return auto_parallel_main_prog, auto_parallel_startup_prog, complete_train_program
+
+
+class TestDistributedTensor(unittest.TestCase):
+    def test_new_local_tensor(self):
+        test_auto_parallel_reshard._global_process_mesh = auto.ProcessMesh(
+            mesh=[0, 1])
+        test_auto_parallel_reshard._global_parallel_strategy = "dp"
+        train_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        dist_context = DistributedContext()
+        rank_id = 0
+        dist_main_prog, dist_startup_prog, complete_train_program = get_dist_prog(
+            train_program, startup_program, dist_context, rank_id)
+        dist_context.dist_main_programs[rank_id] = dist_main_prog
+        dist_context.dist_startup_programs[rank_id] = dist_startup_prog
+        name = "layer_norm_1.tmp_2"
+        dist_tensor = dist_context.get_dist_tensor_for_program(
+            complete_train_program.global_block().vars[name])
+        dist_tensor._dist_context = dist_context
+        intermediate_var_0 = dist_tensor.new_local_tensor(
+            name="intermediate_var_0")
+        self.assertEqual(intermediate_var_0.shape, (2, 1024))
+        self.assertEqual(intermediate_var_0.name, "intermediate_var_0")
+
+        rank_id = 1
+        train_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        dist_main_prog, dist_startup_prog, _ = get_dist_prog(
+            train_program, startup_program, dist_context, rank_id,
+            complete_train_program)
+        dist_context.dist_main_programs[rank_id] = dist_main_prog
+        dist_context.dist_startup_programs[rank_id] = dist_startup_prog
+        name = "layer_norm_1.tmp_2"
+        dist_tensor = dist_context.get_dist_tensor_for_program(
+            complete_train_program.global_block().vars[name])
+        dist_tensor._dist_context = dist_context
+        intermediate_var_1 = dist_tensor.new_local_tensor(
+            rank=rank_id, name="intermediate_var_1")
+        self.assertEqual(intermediate_var_1.shape, (2, 1024))
+        self.assertEqual(intermediate_var_1.name, "intermediate_var_1")
+
+        name = "linear_0.w_0"
+        dist_tensor = dist_context.get_dist_tensor_for_program(
+            complete_train_program.global_block().vars[name])
+        dist_tensor._dist_context = dist_context
+        intermediate_var_1 = dist_tensor.new_local_tensor(
+            rank=rank_id, name="linear_0.w_0_intermediate")
+        self.assertEqual(intermediate_var_1.shape, (1024, 4096))
+        self.assertEqual(intermediate_var_1.name, "linear_0.w_0_intermediate")
+
+        copied_dist_context = copy.deepcopy(dist_context)
+        self.assertIsNotNone(copied_dist_context)
+        self.assertEqual(
+            id(copied_dist_context),
+            id(
+                copied_dist_context.get_dist_tensor_for_program(
+                    dist_tensor.serial_tensor).dist_context))
+
+    def test_static_method(self):
+        dims_mapping = [1, 0]
+        processes = [0, 1, 2, 3, 4, 5]
+        topology = [2, 3]
+        global_sizes = [6, 6]
+
+        # rank 0 [(0, 2), (0, 3)]
+        # rank 1 [(2, 4), (0, 3)]
+        # rank 4 [(2, 4), (3, 6)]
+        rank = 0
+        local_sizes = DistributedTensor.get_local_sizes(
+            global_sizes, dims_mapping, topology, processes)
+        self.assertEqual(local_sizes, [2, 3])
+        local_offsets = DistributedTensor.get_local_offsets(
+            global_sizes, dims_mapping, topology, processes, rank)
+        self.assertEqual(local_offsets, [0, 0])
+        local_shard = DistributedTensor.get_local_shard(
+            global_sizes, dims_mapping, topology, processes, rank)
+        self.assertEqual(local_shard, [(0, 2), (0, 3)])
+
+        rank = 1
+        local_sizes = DistributedTensor.get_local_sizes(
+            global_sizes, dims_mapping, topology, processes)
+        self.assertEqual(local_sizes, [2, 3])
+        local_offsets = DistributedTensor.get_local_offsets(
+            global_sizes, dims_mapping, topology, processes, rank)
+        self.assertEqual(local_offsets, [2, 0])
+        local_shard = DistributedTensor.get_local_shard(
+            global_sizes, dims_mapping, topology, processes, rank)
+        self.assertEqual(local_shard, [(2, 4), (0, 3)])
+
+        rank = 4
+        local_sizes = DistributedTensor.get_local_sizes(
+            global_sizes, dims_mapping, topology, processes)
+        self.assertEqual(local_sizes, [2, 3])
+        local_offsets = DistributedTensor.get_local_offsets(
+            global_sizes, dims_mapping, topology, processes, rank)
+        self.assertEqual(local_offsets, [2, 3])
+        local_shard = DistributedTensor.get_local_shard(
+            global_sizes, dims_mapping, topology, processes, rank)
+        self.assertEqual(local_shard, [(2, 4), (3, 6)])
+
+        # global sizes
+        local_sizes = [2, 3]
+        global_sizes = DistributedTensor.get_global_sizes(
+            local_sizes, dims_mapping, topology, processes)
+        self.assertEqual(global_sizes, [6, 6])
+
+    def test_instance_method(self):
+        tensor_dist_attr = TensorDistributedAttribute()
+        tensor_dist_attr.dims_mapping = [1, 0]
+        tensor_dist_attr.process_mesh = auto.ProcessMesh(
+            mesh=[[0, 1, 2], [3, 4, 5]])
+        serial_tensor = paddle.static.data(
+            name="data", shape=[6, 6], dtype='float32')
+        dist_tensor = DistributedTensor(serial_tensor, tensor_dist_attr)
+
+        # rank 0 [(0, 2), (0, 3)]
+        # rank 1 [(2, 4), (0, 3)]
+        # rank 4 [(2, 4), (3, 6)]
+        rank = 0
+        local_sizes = dist_tensor.local_sizes(rank)
+        self.assertEqual(local_sizes, [2, 3])
+        local_offsets = dist_tensor.local_offsets(rank)
+        self.assertEqual(local_offsets, [0, 0])
+        local_shard = dist_tensor.local_shard(rank)
+        self.assertEqual(local_shard, [(0, 2), (0, 3)])
+        self.assertEqual(local_sizes, dist_tensor.local_sizes(rank))
+        self.assertEqual(local_offsets, dist_tensor.local_offsets(rank))
+        self.assertEqual(local_shard, dist_tensor.local_shard(rank))
+        self.assertEqual(local_sizes, dist_tensor.local_sizes())
+        self.assertEqual(local_offsets, dist_tensor.local_offsets())
+        self.assertEqual(local_shard, dist_tensor.local_shard())
+
+        rank = 1
+        local_sizes = dist_tensor.local_sizes(rank)
+        self.assertEqual(local_sizes, [2, 3])
+        local_offsets = dist_tensor.local_offsets(rank)
+        self.assertEqual(local_offsets, [2, 0])
+        local_shard = dist_tensor.local_shard(rank)
+        self.assertEqual(local_shard, [(2, 4), (0, 3)])
+
+        rank = 4
+        local_sizes = dist_tensor.local_sizes(rank)
+        self.assertEqual(local_sizes, [2, 3])
+        local_offsets = dist_tensor.local_offsets(rank)
+        self.assertEqual(local_offsets, [2, 3])
+        local_shard = dist_tensor.local_shard(rank)
+        self.assertEqual(local_shard, [(2, 4), (3, 6)])
+
+        global_sizes = dist_tensor.global_sizes()
+        self.assertEqual(global_sizes, (6, 6))
+
+
+if __name__ == "__main__":
+    unittest.main()
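
For reference, below is a minimal standalone sketch of the even-sharding arithmetic that the new get_local_sizes/get_local_offsets static methods implement. It is plain Python, independent of Paddle, and not part of the patch; the helper names are illustrative only, and the row-major index-to-coordinate conversion is an assumption that is consistent with the values asserted in test_static_method.

def linear_idx_to_coordinate(topology, linear_idx):
    # Row-major conversion of a linear rank index into a process-mesh coordinate.
    coordinate = []
    for size in reversed(topology):
        coordinate.append(linear_idx % size)
        linear_idx //= size
    return coordinate[::-1]


def local_sizes_and_offsets(global_sizes, dims_mapping, topology, processes, rank):
    coordinate = linear_idx_to_coordinate(topology, processes.index(rank))
    sizes, offsets = [], []
    for dim, global_size in enumerate(global_sizes):
        mesh_dim = dims_mapping[dim]
        if mesh_dim == -1:
            # The dimension is not sharded: keep the global size, offset 0.
            sizes.append(global_size)
            offsets.append(0)
        else:
            # The dimension is evenly split over topology[mesh_dim] shards.
            local = global_size // topology[mesh_dim]
            sizes.append(local)
            offsets.append(coordinate[mesh_dim] * local)
    return sizes, offsets


# Mirrors the values asserted in test_static_method for rank 4:
# sizes [2, 3], offsets [2, 3], i.e. shard [(2, 4), (3, 6)].
print(local_sizes_and_offsets([6, 6], [1, 0], [2, 3], list(range(6)), 4))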