Skip to content

Commit 417f9f1

Browse files
junrushaoLeshengJin
andcommitted
[Unity] Disco: A Framework-Agnostic SPMD Runtime for Distributed Inference/Training
Disco is a distributed runtime that consists of a controler and a cluster of workers. The controler is responsible for managing the workers by broadcasting commands to all the workers together, and the workers are responsible for executing the commands and. The controler and workers communicate with each other through a bi-directional channel. Different from a generic system, Disco is designed to as "single-program-multiple-data" (SPMD) runtime, which means that all the workers execute the same instruction at the same time, but the data they are working on may be different. For example, in data parallelism, each worker may work on a different batches of the data, but they all execute the same set of instructions. Therefore, imagine there is a virtual machine that executes the program, the structures of workers' register files could be considered as "identical" (single program) although the values may differ (multiple data). **DRef.** Following the design above, consider the program in SPMD in a virtual ISA, then each worker is a virtual machine instance to execute the ISA maintaining its own register file. The controler denotes each of their register files with a unique integer "register id", and the workers use this id to refer to the register file that resides on itself. DRef is a control-side object backed by such a register id. The data it contains is not assumed to be directly accessible by the controler, with an exception for worker-0, which is a special worker that is always co-located with the controler. **Worker-0.** Worker-0 is a special worker that is always co-located with the controler. It is assumed that the controler can synchronize with and access the registers of worker-0. The Disco session provides multiple APIs to interact specifically with the worker-0. To shared data with other workers, a common paradigm in Disco is to copy data from the controler-side NDArray to the worker-0, and then copy it to other workers using primitives on the data plane, for example, `broadcast` and `send`. **Control plane.** The controler broadcasts commands to all the workers as control signals. For example, the control may ask all workers to load a library or call a function respectively. Common control signals include: shutdown, retrievel a global PackedFunc, call packed function, etc. The controler is assumed to keep a message channel to each worker to implement the broadcast behavior, and the message channel may vary depends on usecases. **Data plane.** The data channel is usually used to exchange data between workers, especially for tensor data which is usually large. For example, performing an allreduce operator for sharded matrix multiplication, or broadcasting for an input tensor. For efficiency, the data channel is usually backed by NCCL on NVIDIA GPUs, RCCL on AMD GPUs, or MPI on CPUs. **Session.** A Disco session is a primary interface to interact with the Disco runtime, serving as a global context that manages the control and workers. It could be implemented as a multi-threaded with a pool of workers for single-node multi-gpu scenarios, or TCP sockets for workloads that span over a cluster of nodes. **Channel.** Disco channel is a bi-directional communication channel between the controler and workers for exchanging control signals. It is no different from a generic RPC channel, but adopts TVM's PackedFunc calling convention to support polymorphic and variadic arguments. Co-Authored-by: Lesheng Jin <[email protected]>
1 parent c57da13 commit 417f9f1

File tree

28 files changed

+2614
-0
lines changed

28 files changed

+2614
-0
lines changed

include/tvm/relax/attrs/ccl.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/relax/attrs/ccl.h
22+
* \brief Attributes for ccl operators.
23+
*/
24+
#ifndef TVM_RELAX_ATTRS_CCL_H_
25+
#define TVM_RELAX_ATTRS_CCL_H_
26+
27+
#include <tvm/relax/expr.h>
28+
29+
namespace tvm {
30+
namespace relax {
31+
32+
/*! \brief Attributes used in allreduce operators */
33+
struct AllReduceAttrs : public tvm::AttrsNode<AllReduceAttrs> {
34+
String op_type;
35+
36+
TVM_DECLARE_ATTRS(AllReduceAttrs, "relax.attrs.AllReduceAttrs") {
37+
TVM_ATTR_FIELD(op_type).describe(
38+
"The type of reduction operation to be applied to the input data. Now only sum is "
39+
"supported.");
40+
}
41+
}; // struct AllReduceAttrs
42+
43+
} // namespace relax
44+
} // namespace tvm
45+
46+
#endif // TVM_RELAX_ATTRS_CCL_H_
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
/*!
20+
* \file session.h
21+
* \brief This file serves as the entry point of Disco and defines key data structures and
22+
* interfaces.
23+
*
24+
* Disco is a distributed runtime that consists of a controler and a cluster of workers. The
25+
* controler is responsible for managing the workers by broadcasting commands to all the workers
26+
* together, and the workers are responsible for executing the commands and. The controler and
27+
* workers communicate with each other through a bi-directional channel.
28+
*
29+
* Different from a generic system, Disco is designed to as "single-program-multiple-data" (SPMD)
30+
* runtime, which means that all the workers execute the same instruction at the same time, but the
31+
* data they are working on may be different. For example, in data parallelism, each worker may
32+
* work on a different batches of the data, but they all execute the same set of instructions.
33+
* Therefore, imagine there is a virtual machine that executes the program, the structures of
34+
* workers' register files could be considered as "identical" (single program) although the values
35+
* may differ (multiple data).
36+
*
37+
* **DRef.** Following the design above, consider the program in SPMD in a virtual ISA, then each
38+
* worker is a virtual machine instance to execute the ISA maintaining its own register file.
39+
* The controler denotes each of their register files with a unique integer "register id",
40+
* and the workers use this id to refer to the register file that resides on itself.
41+
* DRef is a control-side object backed by such a register id. The data it contains is not assumed
42+
* to be directly accessible by the controler, with an exception for worker-0, which is a special
43+
* worker that is always co-located with the controler.
44+
*
45+
* **Worker-0.** Worker-0 is a special worker that is always co-located with the controler.
46+
* It is assumed that the controler can synchronize with and access the registers of worker-0.
47+
* The Disco session provides multiple APIs to interact specifically with the worker-0.
48+
* To shared data with other workers, a common paradigm in Disco is to copy data from the
49+
* controler-side NDArray to the worker-0, and then copy it to other workers using primitives on
50+
* the data plane, for example, `broadcast` and `send`.
51+
*
52+
* **Control plane.** The controler broadcasts commands to all the workers as control signals.
53+
* For example, the control may ask all workers to load a library or call a function respectively.
54+
* Common control signals include: shutdown, retrievel a global PackedFunc, call packed function,
55+
* etc. The controler is assumed to keep a message channel to each worker to implement the broadcast
56+
* behavior, and the message channel may vary depends on usecases.
57+
*
58+
* **Data plane.** The data channel is usually used to exchange data between workers, especially for
59+
* tensor data which is usually large. For example, performing an allreduce operator for sharded
60+
* matrix multiplication, or broadcasting for an input tensor. For efficiency, the data channel is
61+
* usually backed by NCCL on NVIDIA GPUs, RCCL on AMD GPUs, or MPI on CPUs.
62+
*
63+
* **Session.** A Disco session is a primary interface to interact with the Disco runtime, serving
64+
* as a global context that manages the control and workers. It could be implemented as a
65+
* multi-threaded with a pool of workers for single-node multi-gpu scenarios, or TCP sockets for
66+
* workloads that span over a cluster of nodes.
67+
*
68+
* **Channel.** Disco channel is a bi-directional communication channel between the controler and
69+
* workers for exchanging control signals. It is no different from a generic RPC channel, but
70+
* adopts TVM's PackedFunc calling convention to support polymorphic and variadic arguments.
71+
*/
72+
#ifndef TVM_RUNTIME_DISCO_SESSION_H_
73+
#define TVM_RUNTIME_DISCO_SESSION_H_
74+
75+
#include <tvm/runtime/object.h>
76+
#include <tvm/runtime/packed_func.h>
77+
78+
#include <string>
79+
#include <utility>
80+
81+
namespace tvm {
82+
namespace runtime {
83+
84+
/*!
85+
* \brief All possible kinds of Disco commands.
86+
*/
87+
enum class DiscoAction : int32_t {
88+
kShutDown = 0,
89+
kKillReg = 1,
90+
kGetGlobalFunc = 2,
91+
kCallPacked = 3,
92+
kSyncWorker = 4,
93+
kCopyFromWorker0 = 5,
94+
kCopyToWorker0 = 6,
95+
};
96+
97+
/*! \brief Converts the enum class `DiscoAction` to string */
98+
inline std::string DiscoAction2String(DiscoAction action) {
99+
switch (action) {
100+
case DiscoAction::kShutDown:
101+
return "kShutDown";
102+
case DiscoAction::kKillReg:
103+
return "kKillReg";
104+
case DiscoAction::kGetGlobalFunc:
105+
return "kGetGlobalFunc";
106+
case DiscoAction::kCallPacked:
107+
return "kCallPacked";
108+
case DiscoAction::kSyncWorker:
109+
return "kSyncWorker";
110+
case DiscoAction::kCopyFromWorker0:
111+
return "kCopyFromWorker0";
112+
case DiscoAction::kCopyToWorker0:
113+
return "kCopyToWorker0";
114+
}
115+
LOG(FATAL) << "ValueError: Unknown DiscoAction: " << static_cast<int>(action);
116+
}
117+
118+
/*!
119+
* \brief An object that exists on all workers.
120+
*
121+
* The controler assigns a unique "register id" to each object, and the worker uses this id to
122+
* refer to the object residing on itself.
123+
*/
124+
class DRefObj : public Object {
125+
public:
126+
/*!\ brief Send dellocation command for `reg_id` */
127+
inline ~DRefObj();
128+
/*!
129+
* \brief Get the value of a DRef from a remote worker.
130+
* \param worker_id The id of the worker to be fetched from.
131+
* \return The value of the register.
132+
*/
133+
inline TVMRetValue DebugGetFromRemote(int worker_id);
134+
135+
static constexpr const char* _type_key = "runtime.disco.DRef";
136+
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeDiscoDRef;
137+
TVM_DECLARE_FINAL_OBJECT_INFO(DRefObj, Object);
138+
139+
/*! \brief The id of the register */
140+
int64_t reg_id;
141+
/*! \brief Back-pointer to the host controler session */
142+
ObjectRef session{nullptr};
143+
};
144+
145+
/*!
146+
* \brief Managed reference to DRefObj.
147+
* \sa DRefObj
148+
* \note No public constructor is provided as it is not supposed to be directly created by users.
149+
*/
150+
class DRef : public ObjectRef {
151+
public:
152+
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DRef, ObjectRef, DRefObj);
153+
};
154+
155+
/*!
156+
* \brief A Disco interactive session. It allows users to interact with the Disco command queue with
157+
* various PackedFunc calling convention.
158+
*/
159+
class SessionObj : public Object {
160+
public:
161+
virtual ~SessionObj() = default;
162+
/*!
163+
* \brief Call a PackedFunc on workers providing variadic arguments.
164+
* \tparam Args In the variadic arguments, the supported types include:
165+
* - integers and floating point numbers;
166+
* - DataType;
167+
* - Device;
168+
* - std::string;
169+
* - DRef.
170+
* Examples of unsupported types:
171+
* - NDArray, DLTensor;
172+
* - TVM Objects, including PackedFunc, Module and String;
173+
* \param func The function to be called.
174+
* \param args The variadic arguments.
175+
* \return The return value of function call
176+
*/
177+
template <typename... Args>
178+
DRef TVM_ALWAYS_INLINE CallPacked(const DRef& func, Args&&... args);
179+
/*! \brief Get a global functions on workers. */
180+
virtual DRef GetGlobalFunc(const std::string& name) = 0;
181+
/*!
182+
* \brief Copy the controler-side NDArray to worker-0
183+
* \param host_array The array to be copied to worker-0
184+
* \param remote_array The NDArray on worker-0
185+
*/
186+
virtual void CopyFromWorker0(const NDArray& host_array, const DRef& remote_array) = 0;
187+
/*!
188+
* \brief Copy an NDArray from worker-0 to the controler-side NDArray
189+
* \param host_array The array to be copied to worker-0
190+
* \param remote_array The NDArray on worker-0
191+
*/
192+
virtual void CopyToWorker0(const NDArray& host_array, const DRef& remote_array) = 0;
193+
/*!
194+
* \brief Synchrnoize the controler with a worker, and it will wait until worker finishes
195+
* executing this instruction.
196+
* \param worker_id The id of the worker to be synced with.
197+
* \note This function is usually used for worker-0, because it is the only worker that is
198+
* assumed to collocate with the controler. Syncing with other workers may not be supported.
199+
*/
200+
virtual void SyncWorker(int worker_id) = 0;
201+
/*! \brief Signal all the workers to shutdown */
202+
virtual void Shutdown() = 0;
203+
/*!
204+
* \brief Get the value of a register from a remote worker.
205+
* \param reg_id The id of the register to be fetched.
206+
* \param worker_id The id of the worker to be fetched from.
207+
* \return The value of the register.
208+
*/
209+
virtual TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) = 0;
210+
211+
static constexpr const char* _type_key = "runtime.disco.Session";
212+
TVM_DECLARE_BASE_OBJECT_INFO(SessionObj, Object);
213+
214+
struct FFI;
215+
friend struct SessionObj::FFI;
216+
friend class DRefObj;
217+
218+
protected:
219+
/*! \brief Deallocate a register id, kill it on all workers, and append it to `free_regs_`. */
220+
virtual void DeallocReg(int reg_id) = 0;
221+
/*! \brief Call packed function on each worker using a packed sequence */
222+
virtual DRef CallWithPacked(const TVMArgs& args) = 0;
223+
};
224+
225+
/*!
226+
* \brief Managed reference to SessionObj
227+
* \sa SessionObj
228+
*/
229+
class Session : public ObjectRef {
230+
public:
231+
/*! \brief Create a session backed by a thread pool of workers */
232+
static Session ThreadedSession(int num_workers);
233+
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj);
234+
};
235+
236+
/*!
237+
* \brief A bi-directional channel for controler-worker communication.
238+
* This channel is primarily used to transfer control messages but not data.
239+
*/
240+
class DiscoChannel {
241+
public:
242+
/*! \brief Send a packed sequence to the receiver */
243+
virtual void Send(const TVMArgs& args) = 0;
244+
/*! \brief Receive a packed sequence from worker */
245+
virtual TVMArgs Recv() = 0;
246+
/*! \brief Reply a packed sequence to the sender */
247+
virtual void Reply(const TVMArgs& args) = 0;
248+
/*! \brief Receive a reply from the worker */
249+
virtual TVMArgs RecvReply() = 0;
250+
};
251+
252+
// Implementation details
253+
254+
DRefObj::~DRefObj() {
255+
if (this->session.defined()) {
256+
Downcast<Session>(this->session)->DeallocReg(reg_id);
257+
}
258+
}
259+
260+
TVMRetValue DRefObj::DebugGetFromRemote(int worker_id) {
261+
return Downcast<Session>(this->session)->DebugGetFromRemote(this->reg_id, worker_id);
262+
}
263+
264+
template <typename... Args>
265+
DRef SessionObj::CallPacked(const DRef& func, Args&&... args) {
266+
constexpr int offset = 3;
267+
constexpr int kNumArgs = offset + sizeof...(Args);
268+
TVMValue values[kNumArgs];
269+
int type_codes[kNumArgs];
270+
PackArgs(values, type_codes,
271+
/*.0=*/static_cast<int>(DiscoAction::kCallPacked), // action
272+
/*.1=*/0, // reg_id, which will be updated by this->CallWithPacked
273+
/*.2=*/func, // the function to be called
274+
std::forward<Args>(args)...);
275+
return this->CallWithPacked(TVMArgs(values, type_codes, kNumArgs));
276+
}
277+
278+
} // namespace runtime
279+
} // namespace tvm
280+
#endif // TVM_RUNTIME_DISCO_SESSION_H_

python/tvm/relax/op/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from . import image
3939
from . import memory
4040
from . import nn
41+
from . import ccl
4142

4243
# Register operator gradient functions
4344
from . import _op_gradient
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=wildcard-import
18+
"""CCL related operators."""
19+
from .ccl import *
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Operators serving for Collective Communications Library (CCL) operators"""
18+
import tvm._ffi
19+
20+
tvm._ffi._init_api("relax.op.ccl", __name__)

0 commit comments

Comments
 (0)