Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/runtime/disco/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ class SessionObj : public Object {
* The thirtd element is the function to be called.
*/
TVM_DLL virtual DRef CallWithPacked(const TVMArgs& args) = 0;
/*! \brief Get the number of workers in the session. */
TVM_DLL virtual int64_t GetNumWorkers() = 0;
/*! \brief Get a global functions on workers. */
TVM_DLL virtual DRef GetGlobalFunc(const std::string& name) = 0;
/*!
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/runtime/disco/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ def shutdown(self):
"""Shut down the Disco session"""
_ffi_api.SessionShutdown(self) # type: ignore # pylint: disable=no-member

@property
def num_workers(self) -> int:
"""Return the number of workers in the session"""
return _ffi_api.SessionGetNumWorkers(self) # type: ignore # pylint: disable=no-member

def get_global_func(self, name: str) -> DRef:
"""Get a global function on workers.

Expand Down
2 changes: 2 additions & 0 deletions src/runtime/disco/process_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ class ProcessSessionObj final : public BcastSessionObj {

~ProcessSessionObj() { Kill(); }

int64_t GetNumWorkers() { return workers_.size() + 1; }

TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) {
if (worker_id == 0) {
this->SyncWorker(worker_id);
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/disco/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ TVM_REGISTER_GLOBAL("runtime.disco.DRefDebugGetFromRemote")
.set_body_method<DRef>(&DRefObj::DebugGetFromRemote);
TVM_REGISTER_GLOBAL("runtime.disco.DRefDebugCopyFrom")
.set_body_method<DRef>(&DRefObj::DebugCopyFrom);
TVM_REGISTER_GLOBAL("runtime.disco.SessionGetNumWorkers")
.set_body_method<Session>(&SessionObj::GetNumWorkers);
TVM_REGISTER_GLOBAL("runtime.disco.SessionGetGlobalFunc")
.set_body_method<Session>(&SessionObj::GetGlobalFunc);
TVM_REGISTER_GLOBAL("runtime.disco.SessionCopyFromWorker0")
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/disco/threaded_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ class ThreadedSessionObj final : public BcastSessionObj {
workers_.clear();
}

int64_t GetNumWorkers() { return workers_.size(); }

TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) {
this->SyncWorker(worker_id);
return this->workers_.at(worker_id).worker->register_file.at(reg_id);
Expand Down
7 changes: 7 additions & 0 deletions tests/python/disco/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,13 @@ def transpose_2(
np.testing.assert_equal(z_nd, x_np)


@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("num_workers", [1, 2, 4])
def test_num_workers(session_kind, num_workers):
sess = session_kind(num_workers=num_workers)
assert sess.num_workers == num_workers


if __name__ == "__main__":
test_int(di.ProcessSession)
test_float(di.ProcessSession)
Expand Down