diff --git a/python/mlc_llm/cli/worker.py b/python/mlc_llm/cli/worker.py index 0975853865..c4fe53abd9 100644 --- a/python/mlc_llm/cli/worker.py +++ b/python/mlc_llm/cli/worker.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=invalid-name """Internal DiscoWorker for Disco ProcessSession.""" + import os import sys @@ -31,23 +32,40 @@ def main(): """Main worker function""" - if len(sys.argv) != 5: - print("Usage: ") + + if len(sys.argv) == 5 or len(sys.argv) == 6: + *args, read_fd, write_fd = map(int, sys.argv[1:]) + else: + print( + f"Expected exactly either 4 or 5 arguments, " + f"but received {len(sys.argv)-1} arguments.: {sys.argv}" + ) + # The argument was added in + # https://github.com/apache/tvm/pull/17180. This script + # currently checks the number of arguments present, to + # determine whether `num_groups` was provided. This allows + # the worker.py script provided by MLC-LLM to be compatible + # with either pre-17180 or post-17180 arguments. + # + # After the TVM version used by MLC-LLM includes #17180, the + # usage can be updated to always require `len(sys.argv)==6`. + print("Usage (without num groups): ") + print( + "Usage (with num groups): " + ) return - worker_id = int(sys.argv[1]) - num_workers = int(sys.argv[2]) if sys.platform == "win32": import msvcrt # pylint: disable=import-outside-toplevel,import-error - reader = msvcrt.open_osfhandle(int(sys.argv[3]), os.O_BINARY) - writer = msvcrt.open_osfhandle(int(sys.argv[4]), os.O_BINARY) + reader = msvcrt.open_osfhandle(read_fd, os.O_BINARY) + writer = msvcrt.open_osfhandle(write_fd, os.O_BINARY) else: - reader = int(sys.argv[3]) - writer = int(sys.argv[4]) + reader = read_fd + writer = write_fd worker_func = get_global_func("runtime.disco.WorkerProcess") - worker_func(worker_id, num_workers, reader, writer) + worker_func(*args, reader, writer) if __name__ == "__main__":