Skip to content

Commit 1b6c00d

Browse files
authored
[Disco] Implement SocketSession (#17182)
* [Disco] Implement SocketSession Implements SocketSession that connects multiple local worker processes/threads over multiple distributed nodes via TCP socket. * doc * lint * resolve conflict * lint * add local worker id * lint * lint * disable for hexagon * remove from header
1 parent 08d7519 commit 1b6c00d

File tree

15 files changed

+676
-110
lines changed

15 files changed

+676
-110
lines changed

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,12 @@ if(BUILD_FOR_HEXAGON)
387387
add_definitions(-DDMLC_CXX11_THREAD_LOCAL=0)
388388
endif()
389389

390+
# the distributed disco runtime is disabled for hexagon
391+
if (NOT BUILD_FOR_HEXAGON)
392+
tvm_file_glob(GLOB RUNTIME_DISCO_DISTRIBUTED_SRCS src/runtime/disco/distributed/*.cc)
393+
list(APPEND RUNTIME_SRCS ${RUNTIME_DISCO_DISTRIBUTED_SRCS})
394+
endif()
395+
390396
# Package runtime rules
391397
if(NOT USE_RTTI)
392398
add_definitions(-DDMLC_ENABLE_RTTI=0)

include/tvm/runtime/disco/disco_worker.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class DiscoWorker {
5252
explicit DiscoWorker(int worker_id, int num_workers, int num_groups,
5353
WorkerZeroData* worker_zero_data, DiscoChannel* channel)
5454
: worker_id(worker_id),
55+
local_worker_id(worker_id),
5556
num_workers(num_workers),
5657
num_groups(num_groups),
5758
default_device(Device{DLDeviceType::kDLCPU, 0}),
@@ -68,6 +69,9 @@ class DiscoWorker {
6869

6970
/*! \brief The id of the worker.*/
7071
int worker_id;
72+
/*! \brief The local id of the worker. This can be different from worker_id if the session is
73+
* composed of multiple sub-sessions. */
74+
int local_worker_id;
7175
/*! \brief Total number of workers */
7276
int num_workers;
7377
/*! \brief Total number of workers */

include/tvm/runtime/disco/session.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ class Session : public ObjectRef {
281281
*/
282282
TVM_DLL static Session ProcessSession(int num_workers, int num_groups,
283283
String process_pool_creator, String entrypoint);
284+
284285
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Session, ObjectRef, SessionObj);
285286
};
286287

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name
18+
"""Launch disco session in the remote node and connect to the server."""
19+
import sys
20+
import tvm
21+
from . import disco_worker as _ # pylint: disable=unused-import
22+
23+
24+
if __name__ == "__main__":
25+
if len(sys.argv) != 4:
26+
print("Usage: <server_host> <server_port> <num_workers>")
27+
sys.exit(1)
28+
29+
server_host = sys.argv[1]
30+
server_port = int(sys.argv[2])
31+
num_workers = int(sys.argv[3])
32+
func = tvm.get_global_func("runtime.disco.RemoteSocketSession")
33+
func(server_host, server_port, num_workers)

python/tvm/runtime/disco/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@
2222
ProcessSession,
2323
Session,
2424
ThreadedSession,
25+
SocketSession,
2526
)

python/tvm/runtime/disco/session.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,29 @@ def _configure_structlog(self) -> None:
574574
func(config, os.getpid())
575575

576576

577+
@register_func("runtime.disco.create_socket_session_local_workers")
578+
def _create_socket_session_local_workers(num_workers) -> Session:
579+
"""Create the local session for each distributed node over socket session."""
580+
return ProcessSession(num_workers)
581+
582+
583+
@register_object("runtime.disco.SocketSession")
584+
class SocketSession(Session):
585+
"""A Disco session backed by socket-based multi-node communication."""
586+
587+
def __init__(
588+
self, num_nodes: int, num_workers_per_node: int, num_groups: int, host: str, port: int
589+
) -> None:
590+
self.__init_handle_by_constructor__(
591+
_ffi_api.SocketSession, # type: ignore # pylint: disable=no-member
592+
num_nodes,
593+
num_workers_per_node,
594+
num_groups,
595+
host,
596+
port,
597+
)
598+
599+
577600
@register_func("runtime.disco._configure_structlog")
578601
def _configure_structlog(pickled_config: bytes, parent_pid: int) -> None:
579602
"""Configure structlog for all disco workers

src/runtime/disco/bcast_session.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,16 @@ class BcastSessionObj : public SessionObj {
6565
* \param TVMArgs The input arguments in TVM's PackedFunc calling convention
6666
*/
6767
virtual void BroadcastPacked(const TVMArgs& args) = 0;
68+
69+
/*!
70+
* \brief Send a packed sequence to a worker. This function is usually called by the controller to
71+
* communicate with worker-0, because the worker-0 is assumed to be always collocated with the
72+
* controller. Sending to other workers may not be supported.
73+
* \param worker_id The worker id to send the packed sequence to.
74+
* \param args The packed sequence to send.
75+
*/
76+
virtual void SendPacked(int worker_id, const TVMArgs& args) = 0;
77+
6878
/*!
6979
* \brief Receive a packed sequence from a worker. This function is usually called by the
7080
* controller to communicate with worker-0, because the worker-0 is assumed to be always
@@ -83,6 +93,16 @@ class BcastSessionObj : public SessionObj {
8393

8494
struct Internal;
8595
friend struct Internal;
96+
friend class SocketSessionObj;
97+
friend class RemoteSocketSession;
98+
};
99+
100+
/*!
101+
* \brief Managed reference to BcastSessionObj.
102+
*/
103+
class BcastSession : public Session {
104+
public:
105+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BcastSession, Session, BcastSessionObj);
86106
};
87107

88108
} // namespace runtime

src/runtime/disco/disco_worker.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,15 +120,15 @@ struct DiscoWorker::Impl {
120120
}
121121

122122
static void CopyFromWorker0(DiscoWorker* self, int reg_id) {
123-
if (self->worker_zero_data != nullptr) {
123+
if (self->worker_id == 0) {
124124
NDArray tgt = GetNDArrayFromHost(self);
125125
NDArray src = GetReg(self, reg_id);
126126
tgt.CopyFrom(src);
127127
}
128128
}
129129

130130
static void CopyToWorker0(DiscoWorker* self, int reg_id) {
131-
if (self->worker_zero_data != nullptr) {
131+
if (self->worker_id == 0) {
132132
NDArray src = GetNDArrayFromHost(self);
133133
NDArray tgt = GetReg(self, reg_id);
134134
tgt.CopyFrom(src);

0 commit comments

Comments
 (0)