diff --git a/python/paddle/distributed/auto_parallel/mapper.py b/python/paddle/distributed/auto_parallel/mapper.py
new file mode 100644
index 0000000000000..f015cf4477195
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/mapper.py
@@ -0,0 +1,294 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+import operator
+import functools
+import json
+import paddle
+from collections import deque
+from .graph import Node
+from .graph import Edge
+from .graph import Graph
+from .cluster import DeviceType
+from .process_group import get_process_group
+
+
+def is_collective_comm_op(op):
+    comm_list = [
+        "c_allreduce_sum", "c_allreduce_min", "c_allreduce_max",
+        "c_allreduce_prod", "c_reduce_sum", "c_reduce_min", "c_reduce_max",
+        "c_reduce_prod", "c_broadcast", "c_allgather"
+    ]
+    if op.type in comm_list:
+        return True
+    else:
+        return False
+
+
+def is_p2p_comm_op(op):
+    comm_list = ["send_v2", "recv_v2"]
+    if op.type in comm_list:
+        return True
+    else:
+        return False
+
+
+def get_dtype_bytes(dtype):
+    num_bytes = 0
+    if dtype == paddle.float64:
+        num_bytes = 8
+    elif dtype == paddle.float32:
+        num_bytes = 4
+    elif dtype == paddle.float16:
+        num_bytes = 2
+    elif dtype == paddle.bfloat16:
+        num_bytes = 2
+    elif dtype == paddle.int64:
+        num_bytes = 8
+    elif dtype == paddle.int32:
+        num_bytes = 4
+    elif dtype == paddle.int16:
+        num_bytes = 2
+    elif dtype == paddle.int8:
+        num_bytes = 1
+    elif dtype == paddle.uint8:
+        num_bytes = 1
+    else:
+        raise ValueError("Unrecognized dtype {}.".format(dtype))
+    return num_bytes
+
+
+def get_comm_volume(comm_op, src_rank, tgt_rank):
+    comm_volume = None
+    if src_rank == tgt_rank:
+        return comm_volume
+    comm_op_type = comm_op.type
+    if comm_op_type != "recv_v2":
+        tensor_name = comm_op.input_arg_names[0]
+    else:
+        tensor_name = comm_op.output_arg_names[0]
+    tensor = comm_op.block._find_var_recursive(tensor_name)
+    assert tensor is not None
+    tensor_shape = tensor.shape
+    # Skip the batch dim
+    new_tensor_shape = []
+    for val in tensor_shape:
+        if val == -1:
+            print("Warning: -1 in the tensor shape.")
+            new_tensor_shape.append(1)
+        else:
+            new_tensor_shape.append(val)
+    tensor_size = functools.reduce(operator.mul, new_tensor_shape, 1)
+    tensor_bytes = tensor_size * get_dtype_bytes(tensor.dtype)
+    if "c_allreduce" in comm_op_type:
+        comm_volume = 2 * tensor_bytes
+    elif "c_allgather" in comm_op_type:
+        comm_volume = tensor_bytes
+    elif "c_broadcast" in comm_op_type:
+        if comm_op.attr("root") == src_rank:
+            comm_volume = tensor_bytes
+        else:
+            comm_volume = None
+    elif "c_reduce" in comm_op_type:
+        if comm_op.attr("root_id") == src_rank:
+            comm_volume = None
+        else:
+            comm_volume = tensor_bytes
+    elif "send_v2" in comm_op_type:
+        if comm_op.attr("peer") == tgt_rank:
+            comm_volume = tensor_bytes
+        else:
+            comm_volume = None
+    elif "recv_v2" in comm_op_type:
+        comm_volume = None
+    else:
+        raise ValueError("Unrecognized communication operator.")
+    return comm_volume
+
+
+def analyze_comm_requirements_from_op(op, rank):
+    comm_requirements_to_ranks = {}
+    if is_collective_comm_op(op):
+        process_group_id = op.attr("ring_id")
+        process_group = get_process_group(process_group_id)
+        if rank not in process_group.ranks:
+            return comm_requirements_to_ranks
+        for tgt_rank in process_group.ranks:
+            comm_volume = get_comm_volume(op, rank, tgt_rank)
+            if comm_volume is not None:
+                comm_requirements_to_ranks[tgt_rank] = {}
+                comm_requirements_to_ranks[tgt_rank][
+                    "comm_volume"] = comm_volume
+    elif is_p2p_comm_op(op):
+        tgt_rank = op.attr("peer")
+        comm_volume = get_comm_volume(op, rank, tgt_rank)
+        if comm_volume is not None:
+            comm_requirements_to_ranks[tgt_rank] = {}
+            comm_requirements_to_ranks[tgt_rank]["comm_volume"] = comm_volume
+    else:
+        comm_requirements_to_ranks = {}
+    return comm_requirements_to_ranks
+
+
+def analyze_requirements_for_program(program, rank):
+    resource_requirements = {}
+    comm_requirements_to_ranks = {}
+    # only support device_type and only support GPU for now
+    resource_requirements["device_type"] = DeviceType.GPU
+    for block in program.blocks:
+        for op in block.ops:
+            cur_comm_requirements_to_ranks = analyze_comm_requirements_from_op(
+                op, rank)
+            for tgt_rank, link_info in cur_comm_requirements_to_ranks.items():
+                if tgt_rank in comm_requirements_to_ranks:
+                    comm_requirements_to_ranks[tgt_rank][
+                        "comm_volume"] += link_info["comm_volume"]
+                else:
+                    comm_requirements_to_ranks[tgt_rank] = {}
+                    comm_requirements_to_ranks[tgt_rank][
+                        "comm_volume"] = link_info["comm_volume"]
+    return resource_requirements, comm_requirements_to_ranks
+
+
+def build_process_graph(distributed_program):
+    graph = Graph()
+    for src_rank, src_program in distributed_program.items():
+        resource_requirements, comm_requirements_to_ranks = analyze_requirements_for_program(
+            src_program, src_rank)
+        graph.add_node(src_rank, resource_requirements=resource_requirements)
+        for tgt_rank, comm_requirements in comm_requirements_to_ranks.items():
+            graph.add_edge(
+                src_rank, tgt_rank, comm_requirements=comm_requirements)
+    return graph
+
+
+def build_cluster_graph(cluster):
+    graph = Graph()
+    for machine in cluster.machines.values():
+        for device in machine.devices.values():
+            graph.add_node(device.global_id, device=device)
+        for link in machine.links.values():
+            graph.add_edge(
+                link.source.global_id, link.target.global_id, link=link)
+    return graph
+
+
+def mapping(distributed_program, cluster):
+    # A very simple mapping algorithm only for GPUs.
+    # Here we assume one process will be mapped to one GPU.
+    # In the future, more mapping configurations and algorithms will be supported.
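+    # Outline of the procedure below: build a process graph (ranks as nodes,
+    # communication volume on the edges) and a cluster graph (devices as
+    # nodes, link bandwidth on the edges), then traverse the process graph in
+    # BFS order. Each rank is pinned to an unoccupied device of the required
+    # type; its unvisited neighbor ranks, ordered by communication volume,
+    # are matched against the devices reachable from the chosen device,
+    # ordered by link bandwidth.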
+    process_graph = build_process_graph(distributed_program)
+
+    cluster_graph = build_cluster_graph(cluster)
+
+    for cur_rank_node in process_graph:
+        cur_rank_node["visited"] = False
+
+    for cur_device_node in cluster_graph:
+        cur_device_node["occupied"] = False
+
+    def sort_by_comm_volume(rank_edge):
+        return rank_edge["comm_requirements"]["comm_volume"]
+
+    def sort_by_comm_bandwidth(device_edge):
+        return device_edge["link"].bandwidth
+
+    def select_unvisited_rank_node(rank_node_list):
+        selected_rank_node = None
+        for rank_node in rank_node_list:
+            if rank_node["visited"] is False:
+                selected_rank_node = rank_node
+        return selected_rank_node
+
+    queue = deque()
+    root_rank_node = select_unvisited_rank_node(
+        list(process_graph.nodes.values()))
+    while root_rank_node is not None:
+        queue.append(root_rank_node)
+        while queue:
+            cur_rank_node = queue.popleft()
+            if cur_rank_node["visited"]:
+                continue
+            device_type = cur_rank_node["resource_requirements"]["device_type"]
+            cur_device_node = None
+            for device_node in cluster_graph.nodes.values():
+                if (device_node["device"].type == device_type) and (
+                        not device_node["occupied"]):
+                    device_node["occupied"] = True
+                    cur_rank_node["visited"] = True
+                    cur_rank_node["device"] = device_node["device"]
+                    cur_device_node = device_node
+                    break
+            assert cur_device_node, "Cannot find a device to satisfy the requirement."
+
+            nbr_rank_edges = []
+            for nbr_rank_node_id, nbr_rank_edge in process_graph.adjs[
+                    cur_rank_node.id].items():
+                assert nbr_rank_edge.src_id == cur_rank_node.id and nbr_rank_edge.tgt_id == nbr_rank_node_id
+                queue.append(process_graph.nodes[nbr_rank_node_id])
+                nbr_rank_edges.append(nbr_rank_edge)
+            nbr_rank_edges.sort(key=sort_by_comm_volume)
+
+            nbr_device_edges = []
+            for nbr_device_edge in cluster_graph.adjs[
+                    cur_device_node.id].values():
+                nbr_device_edges.append(nbr_device_edge)
+            nbr_device_edges.sort(key=sort_by_comm_bandwidth)
+
+            for nbr_rank_edge in nbr_rank_edges:
+                nbr_rank_node = process_graph.nodes[nbr_rank_edge.tgt_id]
+                # Skip neighbor ranks that have already been mapped to a device.
+                if nbr_rank_node["visited"]:
+                    continue
+                device_type = nbr_rank_node["resource_requirements"][
+                    "device_type"]
+                for nbr_device_edge in nbr_device_edges:
+                    nbr_device_node = cluster_graph.nodes[
+                        nbr_device_edge.tgt_id]
+                    if (nbr_device_node["device"].type == device_type) and (
+                            not nbr_device_node["occupied"]):
+                        nbr_device_node["occupied"] = True
+                        nbr_rank_node["visited"] = True
+                        nbr_rank_node["device"] = nbr_device_node["device"]
+                        break
+        root_rank_node = select_unvisited_rank_node(
+            list(process_graph.nodes.values()))
+
+    rank_mapping = {}
+    for rank, rank_node in process_graph.nodes.items():
+        device = rank_node["device"]
+        machine = device.machine
+        if machine.id in rank_mapping:
+            rank_mapping[machine.id]["hostname"] = machine.hostname
+            rank_mapping[machine.id]["addr"] = machine.addr
+            rank_mapping[machine.id]["port"] = machine.port
+            if rank not in rank_mapping[machine.id]["ranks"]:
+                rank_mapping[machine.id]["ranks"][rank] = []
+                rank_mapping[machine.id]["ranks"][rank].append(device.local_id)
+            else:
+                rank_mapping[machine.id]["ranks"][rank].append(device.local_id)
+        else:
+            rank_mapping[machine.id] = {}
+            rank_mapping[machine.id]["hostname"] = machine.hostname
+            rank_mapping[machine.id]["addr"] = machine.addr
+            rank_mapping[machine.id]["port"] = machine.port
+            rank_mapping[machine.id]["ranks"] = {}
+            rank_mapping[machine.id]["ranks"][rank] = []
+            rank_mapping[machine.id]["ranks"][rank].append(device.local_id)
+    for machine_mapping in rank_mapping.values():
+        for rank_devices in machine_mapping["ranks"].values():
+            rank_devices.sort()
+
+    return rank_mapping
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 79567f438b695..099dadd617390 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -144,6 +144,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
     LIST(REMOVE_ITEM TEST_OPS test_disable_signal_handler)
     LIST(REMOVE_ITEM TEST_OPS test_fleet_executor)
     LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_multi_devices)
+    LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_mapper)
     LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_task_node)
 endif()
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
new file mode 100644
index 0000000000000..7b60a9753bd6d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
@@ -0,0 +1,609 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import os
+import json
+import collections
+import math
+import numpy as np
+
+import paddle
+import paddle.nn as nn
+import paddle.fluid as fluid
+import paddle.nn.functional as F
+import paddle.tensor as tensor
+import paddle.utils as utils
+import paddle.static as static
+from paddle.fluid import core
+from paddle.fluid import layers
+from paddle.fluid.framework import in_dygraph_mode
+from paddle.nn.layer.transformer import _convert_param_attr_to_list
+from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
+from paddle.distributed import fleet
+
+import paddle.distributed.auto_parallel as auto
+from paddle.distributed.auto_parallel.dist_context import DistributedContext
+from paddle.distributed.auto_parallel.partitioner import Partitioner
+from paddle.distributed.auto_parallel.reshard import reshard
+from paddle.distributed.auto_parallel.process_group import get_all_process_groups
+from paddle.distributed.auto_parallel.process_group import new_process_group
+from paddle.distributed.auto_parallel.cluster import Cluster
+from paddle.distributed.auto_parallel.cluster import DeviceType
+from paddle.distributed.auto_parallel.cluster import LinkType
+from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
+from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
+from paddle.distributed.auto_parallel.mapper import build_process_graph
+from paddle.distributed.auto_parallel.mapper import build_cluster_graph
+from paddle.distributed.auto_parallel.mapper import mapping
+from paddle.distributed.auto_parallel.mapper import get_dtype_bytes
+from paddle.distributed.auto_parallel.mapper import get_comm_volume
+
+paddle.enable_static()
+_global_parallel_strategy = None
+_global_process_mesh = None
+_global_num_stages = None
+
+cluster_json = """
+{
+  "machines": [
+    {
+      "hostname": "machine0",
+      "addr": "0.0.0.1",
+      "port": "768",
+      "devices": [
+        {
+          "global_id": 0,
+          "local_id": 0,
+          "type": "GPU",
+          "model": "A100-SXM4-40GB",
+          "sp_gflops": 19500,
+          "dp_gflops": 9700,
+          "memory": 40
+        },
+        {
+          "global_id": 1,
+          "local_id": 1,
+          "type": "GPU",
+          "model": "A100-SXM4-40GB",
+          "sp_gflops": 19500,
+          "dp_gflops": 9700,
+          "memory": 40
+        },
+        {
+          "global_id": 2,
+          "local_id": 2,
+          "type": "GPU",
+          "model": "A100-SXM4-40GB",
+          "sp_gflops": 19500,
+          "dp_gflops": 9700,
+          "memory": 40
+        },
+        {
+          "global_id": 3,
+          "local_id": 3,
+          "type": "GPU",
+          "model": "A100-SXM4-40GB",
+          "sp_gflops": 19500,
+          "dp_gflops": 9700,
+          "memory": 40
+        },
+        {
+          "global_id": 4,
+          "local_id": 0,
+          "type": "NIC"
+        }
+      ],
+      "links": [
+        {
+          "source_global_id": 0,
+          "target_global_id": 1,
+          "type": "NVL",
+          "bandwidth": 42
+        },
+        {
+          "source_global_id": 0,
+          "target_global_id": 2,
+          "type": "NVL",
+          "bandwidth": 42
+        },
+        {
+          "source_global_id": 0,
+          "target_global_id": 3,
+          "type": "NVL",
+          "bandwidth": 42
+        },
+        {
+          "source_global_id": 0,
+          "target_global_id": 4,
+          "type": "PHB",
+          "bandwidth": 12
+        },
+        {
+          "source_global_id": 1,
+          "target_global_id": 0,
+          "type": "NVL",
+          "bandwidth": 42
+        },
+        {
+          "source_global_id": 1,
+          "target_global_id": 2,
+          "type": "NVL",
+          "bandwidth": 42
+        },
+        {
+          "source_global_id": 1,
+          "target_global_id": 3,
+          "type": "NVL",
+          "bandwidth": 42
+        },
+        {
+          "source_global_id": 1,
+          "target_global_id": 4,
+          "type": "PHB",
+          "bandwidth": 12
+        },
+        {
+          "source_global_id": 2,
+          "target_global_id": 0,
+          "type": "NVL",
+          "bandwidth": 42
+        },
+        {
+          "source_global_id": 2,
+          "target_global_id": 1,
+          "type": "NVL",
+          "bandwidth": 42
+        },
+        {
+          "source_global_id": 2,
+          "target_global_id": 3,
+          "type": "NVL",
+          "bandwidth": 42
+        },
+        {
+          "source_global_id": 2,
+          "target_global_id": 4,
+          "type": "PHB",
+          "bandwidth": 12
+        },
+        {
+          "source_global_id": 3,
+          "target_global_id": 0,
+          "type": "NVL",
+          "bandwidth": 42
+        },
+        {
+          "source_global_id": 3,
+          "target_global_id": 1,
+          "type": "NVL",
+          "bandwidth": 42
+        },
+        {
+          "source_global_id": 3,
+          "target_global_id": 2,
+          "type": "NVL",
+          "bandwidth": 42
+        },
+        {
+          "source_global_id": 3,
+          "target_global_id": 4,
+          "type": "PHB",
+          "bandwidth": 12
+        },
+        {
+          "source_global_id": 4,
+          "target_global_id": 9,
+          "type": "NET",
+          "bandwidth": 1
+        }
+      ]
+    },
+    {
+      "hostname": "machine1",
+      "addr": "0.0.0.2",
+      "port": "768",
+      "devices": [
+        {
+          "global_id": 5,
+          "local_id": 0,
+          "type": "GPU",
+          "model": "Tesla V100-SXM2-32GB",
+          "sp_gflops": 15700,
+          "dp_gflops": 7800,
+          "memory": 32
+        },
+        {
+          "global_id": 6,
+          "local_id": 1,
+          "type": "GPU",
+          "model": "Tesla V100-SXM2-32GB",
+          "sp_gflops": 15700,
+          "dp_gflops": 7800,
+          "memory": 32
+        },
+        {
+          "global_id": 7,
+          "local_id": 2,
+          "type": "GPU",
+          "model": "Tesla V100-SXM2-32GB",
+          "sp_gflops": 15700,
+          "dp_gflops": 7800,
+          "memory": 32
+        },
+        {
+          "global_id": 8,
+          "local_id": 3,
+          "type": "GPU",
+          "model": "Tesla V100-SXM2-32GB",
+          "sp_gflops": 15700,
+          "dp_gflops": 7800,
+          "memory": 32
+        },
+        {
+          "global_id": 9,
+          "local_id": 0,
+          "type": "NIC"
+        }
+      ],
+      "links": [
+        {
+          "source_global_id": 5,
+          "target_global_id": 6,
+          "type": "NVL",
+          "bandwidth": 42
+        },
+        {
+          "source_global_id": 5,
+          "target_global_id": 7,
+          "type": "NVL",
+          "bandwidth": 42
+        },
+        {
+          "source_global_id": 5,
+          "target_global_id": 8,
+          "type": "NVL",
+          "bandwidth": 42
+        },
+        {
+          "source_global_id": 5,
+          "target_global_id": 9,
+          "type": "PHB",
+ "bandwidth": 12 + }, + { + "source_global_id": 6, + "target_global_id": 5, + "type": "NVL", + "bandwidth": 42 + }, + { + "source_global_id": 6, + "target_global_id": 7, + "type": "NVL", + "bandwidth": 42 + }, + { + "source_global_id": 6, + "target_global_id": 8, + "type": "NVL", + "bandwidth": 42 + }, + { + "source_global_id": 6, + "target_global_id": 9, + "type": "PHB", + "bandwidth": 12 + }, + { + "source_global_id": 7, + "target_global_id": 5, + "type": "NVL", + "bandwidth": 42 + }, + { + "source_global_id": 7, + "target_global_id": 6, + "type": "NVL", + "bandwidth": 42 + }, + { + "source_global_id": 7, + "target_global_id": 8, + "type": "NVL", + "bandwidth": 42 + }, + { + "source_global_id": 7, + "target_global_id": 9, + "type": "PHB", + "bandwidth": 12 + }, + { + "source_global_id": 8, + "target_global_id": 5, + "type": "NVL", + "bandwidth": 42 + }, + { + "source_global_id": 8, + "target_global_id": 6, + "type": "NVL", + "bandwidth": 42 + }, + { + "source_global_id": 8, + "target_global_id": 7, + "type": "NVL", + "bandwidth": 42 + }, + { + "source_global_id": 8, + "target_global_id": 9, + "type": "PHB", + "bandwidth": 12 + }, + { + "source_global_id": 9, + "target_global_id": 4, + "type": "NET", + "bandwidth": 1 + } + ] + } + ] +} +""" + + +class MLPLayer(nn.Layer): + def __init__(self, + hidden_size=64, + intermediate_size=4 * 64, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + np.random.seed(2021) + arr0 = np.random.normal(0, 0.02, size=(d_model, dim_feedforward)) + arr1 = np.random.normal(0, 0.02, size=(dim_feedforward, d_model)) + arr2 = np.random.normal(0, 0.02, size=(d_model, dim_feedforward)) + arr3 = np.random.normal(0, 0.02, size=(dim_feedforward, d_model)) + weight_attr0 = paddle.ParamAttr(initializer=NumpyArrayInitializer(arr0)) + weight_attr1 = paddle.ParamAttr(initializer=NumpyArrayInitializer(arr1)) + weight_attr2 = paddle.ParamAttr(initializer=NumpyArrayInitializer(arr2)) + weight_attr3 = paddle.ParamAttr(initializer=NumpyArrayInitializer(arr3)) + bias_attr = None + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr0, bias_attr=bias_attr) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr1, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + self.linear2 = nn.Linear( + d_model, dim_feedforward, weight_attr2, bias_attr=bias_attr) + self.linear3 = nn.Linear( + dim_feedforward, d_model, weight_attr3, bias_attr=bias_attr) + + def forward(self, input): + if _global_parallel_strategy == "dp_mp_pp": + auto.shard_tensor( + self.linear0.weight, + dist_attr={ + "process_mesh": _global_process_mesh[0], + "dims_mapping": [-1, 1] + }) + auto.shard_tensor( + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh[0], + "dims_mapping": [1, -1] + }) + auto.shard_tensor( + self.linear2.weight, + dist_attr={ + "process_mesh": _global_process_mesh[1], + "dims_mapping": [-1, 1] + }) + auto.shard_tensor( + self.linear3.weight, + dist_attr={ + "process_mesh": _global_process_mesh[1], + "dims_mapping": [1, -1] + }) + + out = self.norm(input) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = self.linear1(out) + + out = self.linear2(out) + out = F.gelu(out, approximate=True) + out = self.linear3(out) + return out + + +def mlp_forward(train_program, start_program): + with static.program_guard(train_program,start_program), \ + utils.unique_name.guard(): + batch_size = 4 + hidden_size = 64 + input = static.data( + name="input", 
+        label = static.data(
+            name="label", shape=[batch_size, 1], dtype='float32')
+
+        if _global_parallel_strategy == "dp_mp_pp":
+            auto.shard_tensor(
+                input,
+                dist_attr={
+                    "process_mesh": _global_process_mesh[0],
+                    "dims_mapping": [0, -1]
+                })
+        mlp = MLPLayer(
+            hidden_size=hidden_size,
+            intermediate_size=4 * hidden_size,
+            initializer_range=0.02)
+        predict = mlp(input)
+        error_cost = paddle.nn.functional.square_error_cost(predict, label)
+        loss = paddle.mean(error_cost)
+    return loss, train_program, start_program
+
+
+def get_dist_prog(train_program, startup_program, dist_context, rank_id):
+    loss, train_program, startup_program = mlp_forward(train_program,
+                                                       startup_program)
+
+    dist_strategy = fleet.DistributedStrategy()
+
+    # auto completion
+    complete_train_program = auto.complete_annotation(train_program,
+                                                      dist_context)
+    partitioner = Partitioner(dist_strategy, dist_context, rank_id)
+    # logical partition
+    dist_train_program, dist_startup_prog = partitioner.transpile_forward(
+        complete_train_program, startup_program)
+    dist_params_grads = partitioner.apply_backward(
+        loss, complete_train_program, startup_program, dist_train_program,
+        dist_startup_prog)
+    optimizer = paddle.fluid.optimizer.AdamOptimizer()
+    opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
+                                         dist_train_program, dist_startup_prog)
+    reshard(dist_train_program, dist_startup_prog, rank_id, dist_context)
+    return dist_train_program, dist_startup_prog
+
+
+def is_in_machine(device_local_id, machine):
+    for device in machine.devices.values():
+        if device_local_id == device.local_id:
+            return True
+    return False
+
+
+def get_device_local_ids(machine):
+    local_ids = []
+    for device in machine.devices.values():
+        local_ids.append(device.local_id)
+    return local_ids
+
+
+class TestAutoParallelMapper(unittest.TestCase):
+    def test_mapper_dp_mp_pp(self):
+        cluster_json_file = ""
+        cluster_json_object = json.loads(cluster_json)
+        with open("./auto_parallel_cluster.json", "w") as cluster_json_file:
+            json.dump(cluster_json_object, cluster_json_file)
+        cluster = Cluster()
+        cluster.build_from_file("./auto_parallel_cluster.json")
+        os.remove("./auto_parallel_cluster.json")
+
+        global _global_parallel_strategy
+        _global_parallel_strategy = "dp_mp_pp"
+        global _global_num_stages
+        _global_num_stages = 2
+        global _global_process_mesh
+        _global_process_mesh = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
+        processes = [0, 1, 2, 3, 4, 5, 6, 7]
+
+        dist_programs = {}
+        for rank_id in processes:
+            train_program = static.Program()
+            startup_program = static.Program()
+            dist_context = DistributedContext()
+            dist_train_program, dist_startup_prog = get_dist_prog(
+                train_program, startup_program, dist_context, rank_id)
+            # if rank_id == 0:
+            #     print_program_with_dist_attr(dist_train_program, dist_context)
+            dist_programs[rank_id] = dist_train_program
+
+        rank_mapping = mapping(dist_programs, cluster)
+
+        all_mapped_ranks = set()
+        for machine_id, machine_mapping in rank_mapping.items():
+            machine = cluster.machines[machine_id]
+            machine_mapped_ranks = set()
+            machine_mapped_device_local_ids = set()
+            for rank, device_ids in machine_mapping["ranks"].items():
+                # Only allow one process to one device mapping
+                self.assertEqual(len(device_ids), 1)
+                self.assertTrue(is_in_machine(device_ids[0], machine))
+                machine_mapped_ranks.add(rank)
+                machine_mapped_device_local_ids.add(device_ids[0])
+            self.assertEqual(
+                len(machine_mapped_ranks), len(machine_mapped_device_local_ids))
+            all_mapped_ranks.update(machine_mapped_ranks)
+        self.assertEqual(set(processes), all_mapped_ranks)
+
+    def test_mapper_misc(self):
+        self.assertEqual(get_dtype_bytes(paddle.float64), 8)
+        self.assertEqual(get_dtype_bytes(paddle.float32), 4)
+        self.assertEqual(get_dtype_bytes(paddle.float16), 2)
+        self.assertEqual(get_dtype_bytes(paddle.bfloat16), 2)
+        self.assertEqual(get_dtype_bytes(paddle.int64), 8)
+        self.assertEqual(get_dtype_bytes(paddle.int32), 4)
+        self.assertEqual(get_dtype_bytes(paddle.int16), 2)
+        self.assertEqual(get_dtype_bytes(paddle.int8), 1)
+        self.assertEqual(get_dtype_bytes(paddle.uint8), 1)
+        self.assertRaises(ValueError, get_dtype_bytes, "unknown type")
+        train_program = static.Program()
+        startup_program = static.Program()
+        ring_id = 0
+        root_id = 0
+        nranks = 2
+        with fluid.program_guard(train_program, startup_program):
+            input = layers.data(name="input", shape=[10, 10], dtype='float32')
+            output = train_program.current_block().create_var(
+                name="outofbroadcast",
+                dtype='float32',
+                type=core.VarDesc.VarType.LOD_TENSOR,
+                persistable=False,
+                stop_gradient=False)
+            broadcast_op = train_program.global_block().append_op(
+                type="c_broadcast",
+                inputs={'X': input},
+                attrs={'ring_id': ring_id,
+                       'root': root_id},
+                outputs={'Out': output})
+            self.assertEqual(get_comm_volume(broadcast_op, 0, 1), 400)
+            self.assertEqual(get_comm_volume(broadcast_op, 1, 0), None)
+            allgather_op = train_program.global_block().append_op(
+                type="c_allgather",
+                inputs={'X': input},
+                attrs={'ring_id': ring_id,
+                       'nranks': nranks},
+                outputs={'Out': output})
+            self.assertEqual(get_comm_volume(allgather_op, 0, 1), 400)
+            self.assertEqual(get_comm_volume(allgather_op, 0, 0), None)
+            reduce_op = train_program.global_block().append_op(
+                type="c_reduce_sum",
+                inputs={'X': input},
+                attrs={'ring_id': ring_id,
+                       'root_id': root_id},
+                outputs={'Out': output})
+            self.assertEqual(get_comm_volume(reduce_op, 0, 1), None)
+            self.assertEqual(get_comm_volume(reduce_op, 1, 0), 400)
+            cast_op = train_program.global_block().append_op(
+                type="cast",
+                inputs={"X": input},
+                outputs={"Out": output},
+                attrs={
+                    "in_dtype": fluid.core.VarDesc.VarType.FP32,
+                    "out_dtype": fluid.core.VarDesc.VarType.FP32
+                })
+            self.assertRaises(ValueError, get_comm_volume, cast_op, 0, 1)
+
+
+if __name__ == '__main__':
+    unittest.main()