From 317f06229c0e99473d0aa967fdb02aab93e8b51b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 8 May 2024 14:30:05 +0000 Subject: [PATCH] [Disco] Implement `num_workers` property for `disco.Session` Prior to this commit, while the `num_workers` argument was provided to the `disco.Session` object, it could not be determined from an existing `disco.Session` object. As a result, functions that interacted with a multi-GPU setup frequently required separate `num_workers` and `disco_session` arguments, which could erroneously be out of sync (e.g. passing the incorrect `num_workers`, or omitting the `disco_session` argument when `num_workers>1`). To remove this class of errors, this commit adds a `disco.Session.num_workers` property. The separate `num_workers` argument is no longer necessary, as it can be determined from the `disco.Session` instance. --- include/tvm/runtime/disco/session.h | 2 ++ python/tvm/runtime/disco/session.py | 5 +++++ src/runtime/disco/process_session.cc | 2 ++ src/runtime/disco/session.cc | 2 ++ src/runtime/disco/threaded_session.cc | 2 ++ tests/python/disco/test_session.py | 7 +++++++ 6 files changed, 20 insertions(+) diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h index 3d4c3e4ea1a3..71fcce75b292 100644 --- a/include/tvm/runtime/disco/session.h +++ b/include/tvm/runtime/disco/session.h @@ -197,6 +197,8 @@ class SessionObj : public Object { * The thirtd element is the function to be called. */ TVM_DLL virtual DRef CallWithPacked(const TVMArgs& args) = 0; + /*! \brief Get the number of workers in the session. */ + TVM_DLL virtual int64_t GetNumWorkers() = 0; /*! \brief Get a global functions on workers. */ TVM_DLL virtual DRef GetGlobalFunc(const std::string& name) = 0; /*! 
diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index ee151db7166c..18329eb3f5bd 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -146,6 +146,11 @@ def shutdown(self): """Shut down the Disco session""" _ffi_api.SessionShutdown(self) # type: ignore # pylint: disable=no-member + @property + def num_workers(self) -> int: + """Return the number of workers in the session""" + return _ffi_api.SessionGetNumWorkers(self) # type: ignore # pylint: disable=no-member + def get_global_func(self, name: str) -> DRef: """Get a global function on workers. diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc index 6474db479e94..dfcf36989c00 100644 --- a/src/runtime/disco/process_session.cc +++ b/src/runtime/disco/process_session.cc @@ -153,6 +153,8 @@ class ProcessSessionObj final : public BcastSessionObj { ~ProcessSessionObj() { Kill(); } + int64_t GetNumWorkers() { return workers_.size() + 1; } /* Reports the session's worker count as workers_.size() + 1. The +1 presumably accounts for worker 0, which DebugGetFromRemote below special-cases (worker_id == 0) -- TODO confirm that workers_ excludes worker 0. */ + TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) { if (worker_id == 0) { this->SyncWorker(worker_id); diff --git a/src/runtime/disco/session.cc b/src/runtime/disco/session.cc index e74d3819fe04..00f28a7b9f6a 100644 --- a/src/runtime/disco/session.cc +++ b/src/runtime/disco/session.cc @@ -37,6 +37,8 @@ TVM_REGISTER_GLOBAL("runtime.disco.DRefDebugGetFromRemote") .set_body_method(&DRefObj::DebugGetFromRemote); TVM_REGISTER_GLOBAL("runtime.disco.DRefDebugCopyFrom") .set_body_method(&DRefObj::DebugCopyFrom); +TVM_REGISTER_GLOBAL("runtime.disco.SessionGetNumWorkers") + .set_body_method(&SessionObj::GetNumWorkers); /* Registers SessionObj::GetNumWorkers under the global name that the Python num_workers property calls as _ffi_api.SessionGetNumWorkers. NOTE(review): set_body_method is written without an explicit template argument here and in the neighbouring registrations -- confirm none was lost from this copy of the patch. */ TVM_REGISTER_GLOBAL("runtime.disco.SessionGetGlobalFunc") .set_body_method(&SessionObj::GetGlobalFunc); TVM_REGISTER_GLOBAL("runtime.disco.SessionCopyFromWorker0") diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc index c1f2f8539337..7a76a45ed539 100644 --- a/src/runtime/disco/threaded_session.cc +++ 
b/src/runtime/disco/threaded_session.cc @@ -154,6 +154,8 @@ class ThreadedSessionObj final : public BcastSessionObj { workers_.clear(); } + int64_t GetNumWorkers() { return workers_.size(); } /* Reports the session's worker count as workers_.size(), with no adjustment. DebugGetFromRemote below indexes workers_ directly by worker_id, so workers_ appears to hold every worker -- TODO confirm. */ + TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) { this->SyncWorker(worker_id); return this->workers_.at(worker_id).worker->register_file.at(reg_id); diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py index 40dcb04911c9..ef8ea2e70a25 100644 --- a/tests/python/disco/test_session.py +++ b/tests/python/disco/test_session.py @@ -220,6 +220,13 @@ def transpose_2( np.testing.assert_equal(z_nd, x_np) +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("num_workers", [1, 2, 4]) +def test_num_workers(session_kind, num_workers): + sess = session_kind(num_workers=num_workers) + assert sess.num_workers == num_workers + + if __name__ == "__main__": test_int(di.ProcessSession) test_float(di.ProcessSession)