Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/tvm/relax/attrs/ccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,32 @@ namespace relax {
/*! \brief Attributes used in allreduce operators */
struct AllReduceAttrs : public tvm::AttrsNode<AllReduceAttrs> {
  /*! \brief The kind of reduction applied to the input data (e.g. "sum"). */
  String op_type;
  /*! \brief Whether the reduction is restricted to the worker group rather than all workers. */
  bool in_group;

  TVM_DECLARE_ATTRS(AllReduceAttrs, "relax.attrs.AllReduceAttrs") {
    // Keep this list in sync with the op types accepted by the Python-side `allreduce` wrapper.
    TVM_ATTR_FIELD(op_type).describe(
        "The type of reduction operation to be applied to the input data. Now \"sum\", \"prod\", "
        "\"min\", \"max\" and \"avg\" are supported.");
    TVM_ATTR_FIELD(in_group).describe(
        "Whether the reduction operation is performed within the worker group (true) or globally "
        "across all workers (false).");
  }
};  // struct AllReduceAttrs

/*! \brief Attributes used in allgather operators */
struct AllGatherAttrs : public tvm::AttrsNode<AllGatherAttrs> {
  /*! \brief The number of workers taking part in the allgather. */
  int num_workers;
  /*! \brief Whether the allgather is restricted to the worker group rather than all workers. */
  bool in_group;

  TVM_DECLARE_ATTRS(AllGatherAttrs, "relax.attrs.AllGatherAttrs") {
    TVM_ATTR_FIELD(num_workers)
        .describe(
            "The number of workers, also the number of parts the given buffer should be chunked "
            "into.");
    TVM_ATTR_FIELD(in_group).describe(
        "Whether the allgather operation is performed within the worker group (true) or globally "
        "across all workers (false).");
  }
};  // struct AllGatherAttrs

/*! \brief Attributes used in scatter operators */
struct ScatterCollectiveAttrs : public tvm::AttrsNode<ScatterCollectiveAttrs> {
int num_workers;
Expand Down
15 changes: 10 additions & 5 deletions include/tvm/runtime/disco/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,35 +75,40 @@ TVM_DLL NDArray DiscoEmptyNDArray(ShapeTuple shape, DataType dtype, Device devic
* \brief Perform an allreduce operation using the underlying communication library
* \param send The array send to perform allreduce on
* \param reduce_kind The kind of reduction operation (e.g. sum, avg, min, max)
* \param in_group Whether the allreduce is performed within the worker group (true) or globally across all workers (false).
* \param recv The array receives the outcome of allreduce
*/
TVM_DLL void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv);
TVM_DLL void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv);
/*!
* \brief Perform an allgather operation using the underlying communication library
* \param send The array send to perform allgather on
* \param in_group Whether the allgather is performed within the worker group (true) or globally across all workers (false).
* \param recv The array receives the outcome of allgather
*/
TVM_DLL void AllGather(NDArray send, NDArray recv);
TVM_DLL void AllGather(NDArray send, bool in_group, NDArray recv);
/*!
* \brief Perform a broadcast operation from worker-0
* \param send The buffer to be broadcasted
* \param in_group Whether the broadcast is performed within the worker group (true) or globally across all workers (false).
* \param recv The buffer receives the broadcasted array
*/
TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv);
TVM_DLL void BroadcastFromWorker0(NDArray send, bool in_group, NDArray recv);
/*!
* \brief Perform a scatter operation from worker-0, chunking the given buffer into equal parts.
* \param send For worker-0, it must be provided, and otherwise, the buffer must be None.
* The buffer will be divided into equal parts and sent to each worker accordingly.
* \param in_group Whether the scatter is performed within the worker group (true) or globally across all workers (false).
* \param recv The receiving buffer, which must not be None.
*/
TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, NDArray recv);
TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, bool in_group, NDArray recv);
/*!
* \brief Perform a gather operation to worker-0.
* \param send The sending buffer, which must not be None.
* \param in_group Whether the gather is performed within the worker group (true) or globally across all workers (false).
* \param recv For worker-0, it must be provided, and otherwise, the buffer must be None. The
* receiving buffer will be divided into equal parts and receive from each worker accordingly.
*/
TVM_DLL void GatherToWorker0(NDArray send, Optional<NDArray> recv);
TVM_DLL void GatherToWorker0(NDArray send, bool in_group, Optional<NDArray> recv);
/*!
* \brief Receive a buffer from worker-0. No-op if the current worker is worker-0.
* \param buffer The buffer to be received
Expand Down
8 changes: 6 additions & 2 deletions include/tvm/runtime/disco/disco_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,16 @@ class DiscoWorker {
* \brief Construct a worker.
* \param worker_id The id of the worker.
* \param num_workers The number of the workers.
* \param num_groups The number of the worker groups.
* \param worker_zero_data The data shared between worker-0 and the controller. It's a nullptr if
* the worker is not worker-0.
* \param channel The communication channel between the worker and the controller.
*/
explicit DiscoWorker(int worker_id, int num_workers, WorkerZeroData* worker_zero_data,
DiscoChannel* channel)
explicit DiscoWorker(int worker_id, int num_workers, int num_groups,
WorkerZeroData* worker_zero_data, DiscoChannel* channel)
: worker_id(worker_id),
num_workers(num_workers),
num_groups(num_groups),
default_device(Device{DLDeviceType::kDLCPU, 0}),
worker_zero_data(worker_zero_data),
channel(channel),
Expand All @@ -68,6 +70,8 @@ class DiscoWorker {
int worker_id;
/*! \brief Total number of workers */
int num_workers;
/*! \brief Total number of worker groups */
int num_groups;
/*! \brief The default device to allocate data if not specified */
Device default_device;
/*! \brief The name of the underlying collective communication library. */
Expand Down
8 changes: 5 additions & 3 deletions include/tvm/runtime/disco/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,13 @@ class Session : public ObjectRef {
/*!
* \brief Create a session backed by a thread pool of workers
* \param num_workers The number of workers.
* \param num_groups The number of worker groups.
*/
TVM_DLL static Session ThreadedSession(int num_workers);
TVM_DLL static Session ThreadedSession(int num_workers, int num_groups);
/*!
* \brief Create a session backed by pipe-based multiprocessing
* \param num_workers The number of workers.
* \param num_groups The number of worker groups.
* \param process_pool_creator The name of a global function that takes `num_workers` as an input,
* and returns a PackedFunc, which takes an integer `worker_id` as the input and returns None.
* When `worker_id` is 0, it shuts down the process pool; otherwise, it returns a tuple
Expand All @@ -277,8 +279,8 @@ class Session : public ObjectRef {
* \note Worker-0 is always co-located with the controller as a separate thread, and therefore
* worker-0 does not exist in the process pool.
*/
TVM_DLL static Session ProcessSession(int num_workers, String process_pool_creator,
String entrypoint);
TVM_DLL static Session ProcessSession(int num_workers, int num_groups,
String process_pool_creator, String entrypoint);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj);
};

Expand Down
15 changes: 8 additions & 7 deletions python/tvm/exec/disco_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,22 +99,23 @@ def fget_item(param_name: str, param_index: int) -> NDArray:

def main():
"""Main worker function"""
if len(sys.argv) != 5:
print("Usage: <worker_id> <num_workers> <read_fd> <write_fd>")
if len(sys.argv) != 6:
print("Usage: <worker_id> <num_workers> <num_groups> <read_fd> <write_fd>")
return
worker_id = int(sys.argv[1])
num_workers = int(sys.argv[2])
num_groups = int(sys.argv[3])
if sys.platform == "win32":
import msvcrt # pylint: disable=import-outside-toplevel,import-error

reader = msvcrt.open_osfhandle(int(sys.argv[3]), os.O_BINARY)
writer = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY)
reader = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY)
writer = msvcrt.open_osfhandle(int(sys.argv[5]), os.O_BINARY)
else:
reader = int(sys.argv[3])
writer = int(sys.argv[4])
reader = int(sys.argv[4])
writer = int(sys.argv[5])

worker_func = get_global_func("runtime.disco.WorkerProcess")
worker_func(worker_id, num_workers, reader, writer)
worker_func(worker_id, num_workers, num_groups, reader, writer)


if __name__ == "__main__":
Expand Down
13 changes: 9 additions & 4 deletions python/tvm/relax/frontend/nn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1668,16 +1668,21 @@ def interpolate(
)


def ccl_allreduce(x: Tensor, op_type: str = "sum", name="ccl_allreduce"):
def ccl_allreduce(x: Tensor, op_type: str = "sum", in_group: bool = True, name="ccl_allreduce"):
"""CCL Allreduce operator

Parameters
----------
x : Tensor
x : relax.Expr
The input tensor.
op_type: str

op_type : str
The type of reduction operation to be applied to the input data.
Now "sum", "prod", "min", "max" and "avg" are supported.

in_group : bool
Whether the reduction is performed within the worker group (True, the default) or globally across all workers (False).

name : str
Name hint for this operation.

Expand All @@ -1686,7 +1691,7 @@ def ccl_allreduce(x: Tensor, op_type: str = "sum", name="ccl_allreduce"):
result : Tensor
The result tensor of allreduce.
"""
return wrap_nested(_op.ccl.allreduce(x._expr, op_type), name)
return wrap_nested(_op.ccl.allreduce(x._expr, op_type, in_group), name)


def ccl_broadcast_from_worker0(x: Tensor, name="broadcast_from_worker"):
Expand Down
24 changes: 13 additions & 11 deletions python/tvm/relax/op/ccl/ccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,26 @@
# specific language governing permissions and limitations
# under the License.
"""Relax Collective Communications Library (CCL) operators"""
from typing import Union
from tvm.relax import PrimValue

from . import _ffi_api
from ...expr import Expr
from ....ir import PrimExpr


def allreduce(x, op_type: str = "sum"): # pylint: disable=invalid-name
def allreduce(x, op_type: str = "sum", in_group: bool = True): # pylint: disable=invalid-name
"""Allreduce operator

Parameters
----------
x : relax.Expr
The input tensor.
op_type: str

op_type : str
The type of reduction operation to be applied to the input data.
Now "sum", "prod", "min", "max" and "avg" are supported.

in_group : bool
Whether the reduction is performed within the worker group (True, the default) or globally across all workers (False).

Returns
-------
result : relax.Expr
Expand All @@ -44,28 +45,29 @@ def allreduce(x, op_type: str = "sum"): # pylint: disable=invalid-name
"Allreduce only supports limited reduction operations, "
f"including {supported_op_types}, but got {op_type}."
)
return _ffi_api.allreduce(x, op_type) # type: ignore # pylint: disable=no-member
return _ffi_api.allreduce(x, op_type, in_group) # type: ignore # pylint: disable=no-member


def allgather(x, num_workers: Union[int, PrimExpr, PrimValue]): # pylint: disable=invalid-name
def allgather(x, num_workers: int, in_group: bool = True): # pylint: disable=invalid-name
"""AllGather operator

Parameters
----------
x : relax.Expr
The input tensor.

num_worker : Union[int, PrimExpr, PrimValue]
num_workers : int
The number of workers to gather data from.

in_group : bool
Whether the allgather is performed within the worker group (True, the default) or globally across all workers (False).

Returns
-------
result : relax.Expr
The result of allgather.
"""
if not isinstance(num_workers, PrimValue):
num_workers = PrimValue(num_workers)
return _ffi_api.allgather(x, num_workers) # type: ignore # pylint: disable=no-member
return _ffi_api.allgather(x, num_workers, in_group) # type: ignore # pylint: disable=no-member


def broadcast_from_worker0(x: Expr) -> Expr:
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/relax/transform/legalize_ops/ccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _allreduce(_bb: BlockBuilder, call: Call) -> Expr:
)
return call_dps_packed(
"runtime.disco.allreduce",
[call.args[0], ShapeExpr([op_type_map[op_type_str]])],
[call.args[0], ShapeExpr([op_type_map[op_type_str]]), call.attrs.in_group],
out_sinfo=call.args[0].struct_info,
)

Expand All @@ -57,12 +57,12 @@ def _allgather(_bb: BlockBuilder, call: Call) -> Expr:
arg_shape = arg_sinfo.shape.struct_info
for i, shape_value in enumerate(arg_shape.values):
if i == 0:
output_shape.append(shape_value * call.args[1].value)
output_shape.append(shape_value * call.attrs.num_workers)
else:
output_shape.append(shape_value)
return call_dps_packed(
"runtime.disco.allgather",
call.args[0],
[call.args[0], call.attrs.in_group],
out_sinfo=TensorStructInfo(
shape=output_shape,
dtype=arg_sinfo.dtype,
Expand All @@ -75,7 +75,7 @@ def _allgather(_bb: BlockBuilder, call: Call) -> Expr:
def _broadcast_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
return call_dps_packed(
"runtime.disco.broadcast_from_worker0",
call.args[0],
[call.args[0], False],
out_sinfo=call.args[0].struct_info,
)

Expand Down Expand Up @@ -116,7 +116,7 @@ def _scatter_from_worker0(_bb: BlockBuilder, call: Call) -> Expr:
output_shape = output_shape[1:]
return call_dps_packed(
"runtime.disco.scatter_from_worker0",
transpose_var,
[transpose_var, False],
out_sinfo=TensorStructInfo(
shape=output_shape,
dtype=call.args[0].struct_info.dtype,
Expand Down
10 changes: 8 additions & 2 deletions python/tvm/runtime/disco/process_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class DiscoPopenWorker:
num_workers : int
The total number of workers.

num_groups : int
The total number of worker groups.

stdout: Union[None, int, IO[Any]]
The standard output streams handler specified for the popen process.

Expand All @@ -49,12 +52,14 @@ def __init__( # pylint: disable=too-many-arguments
self,
worker_id: int,
num_workers: int,
num_groups: int,
entrypoint: str = "tvm.exec.disco_worker",
stdout=None,
stderr=None,
):
self.worker_id = worker_id
self.num_workers = num_workers
self.num_groups = num_groups
self.entrypoint = entrypoint
self._proc = None
self._stdout = stdout
Expand Down Expand Up @@ -118,6 +123,7 @@ def start(self):
self.entrypoint,
str(self.worker_id),
str(self.num_workers),
str(self.num_groups),
]
if sys.platform == "win32":
import msvcrt # pylint: disable=import-error,import-outside-toplevel
Expand Down Expand Up @@ -172,9 +178,9 @@ def _kill_child_processes(pid):


@register_func("runtime.disco.create_process_pool")
def _create_process_pool(num_workers: int, entrypoint: str):
def _create_process_pool(num_workers: int, num_groups: int, entrypoint: str):
"""Create a process pool where the workers' are [1, num_workers)."""
pool = [DiscoPopenWorker(i, num_workers, entrypoint) for i in range(1, num_workers)]
pool = [DiscoPopenWorker(i, num_workers, num_groups, entrypoint) for i in range(1, num_workers)]

def result_func(worker_id: int):
nonlocal pool
Expand Down
Loading