81 changes: 65 additions & 16 deletions tensorrt_llm/executor/executor.py
@@ -375,6 +375,55 @@ def _create_ray_executor(
tp_size=tp_size,
kv_connector_config=kv_connector_config)

@staticmethod
def _create_rpc_executor(
worker_kwargs: Dict,
model_world_size: int,
mpi_session: Optional[MpiSession],
postproc_worker_config: PostprocWorkerConfig,
is_llm_executor: bool,
kv_connector_config: Optional[KvCacheConnectorConfig] = None):
"""Create RPC-based executor (GenerationExecutorRpcProxy)."""
from .rpc_proxy import GenerationExecutorRpcProxy
return GenerationExecutorRpcProxy(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
kv_connector_config=kv_connector_config)

@staticmethod
def _create_ipc_executor(
worker_kwargs: Dict,
model_world_size: int,
mpi_session: Optional[MpiSession],
postproc_worker_config: PostprocWorkerConfig,
is_llm_executor: bool,
use_worker: bool = False,
kv_connector_config: Optional[KvCacheConnectorConfig] = None):
"""Create IPC-based executor (GenerationExecutorProxy or GenerationExecutorWorker).

Args:
use_worker: If True, creates GenerationExecutorWorker (single process).
If False, creates GenerationExecutorProxy (multi-process with IPC).
"""
if use_worker:
from .worker import GenerationExecutorWorker
return GenerationExecutorWorker(
**worker_kwargs,
is_llm_executor=is_llm_executor,
kv_connector_config=kv_connector_config)
else:
from .proxy import GenerationExecutorProxy
return GenerationExecutorProxy(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
kv_connector_config=kv_connector_config)

@staticmethod
def create(
engine: Union[Path, Engine],
@@ -394,10 +443,6 @@ def create(
llm_args: Optional[BaseLlmArgs] = None,
**args,
) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
# local imports to avoid cyclic importing
from .proxy import GenerationExecutorProxy
from .worker import GenerationExecutorWorker

if world_size == 0:
world_size = mpi_world_size()

@@ -437,7 +482,7 @@ def create(
is_llm_executor=is_llm_executor,
tp_size=args.get("tp_size", 1),
kv_connector_config=kv_connector_config)
elif orchestrator_type is not None:
elif orchestrator_type is not None and orchestrator_type != "rpc":
raise ValueError(
f"Unsupported orchestrator_type: {orchestrator_type}")

@@ -452,21 +497,21 @@ def create(
assert mpi_session is not None, "reuse_mpi_comm requires an external MPI session"

if orchestrator_is_rpc:
from .rpc_proxy import GenerationExecutorRpcProxy
return GenerationExecutorRpcProxy(
return GenerationExecutor._create_rpc_executor(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
kv_connector_config=kv_connector_config)

return GenerationExecutorProxy(
return GenerationExecutor._create_ipc_executor(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
use_worker=False,
kv_connector_config=kv_connector_config)

# WAR: For the performance of gathering logits, we use single process worker
@@ -478,18 +523,21 @@ def create(
"Using single process worker for TP1, this may hurt streaming generation performance."
)
if orchestrator_is_rpc:
from .rpc_proxy import GenerationExecutorRpcProxy
return GenerationExecutorRpcProxy(
return GenerationExecutor._create_rpc_executor(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
kv_connector_config=kv_connector_config)

return GenerationExecutorWorker(
**worker_kwargs,
return GenerationExecutor._create_ipc_executor(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
use_worker=True,
kv_connector_config=kv_connector_config)

# For single-gpu case:
@@ -498,34 +546,35 @@ def create(
# `if __name__ == "__main__":`.
if not platform.system() == 'Windows':
if orchestrator_is_rpc:
from .rpc_proxy import GenerationExecutorRpcProxy
return GenerationExecutorRpcProxy(
return GenerationExecutor._create_rpc_executor(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
kv_connector_config=kv_connector_config)

return GenerationExecutorProxy(
return GenerationExecutor._create_ipc_executor(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=None, # use mpi4py
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
use_worker=False,
kv_connector_config=kv_connector_config)
else:
ctx = multiprocessing.get_context("spawn")
# The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot.
mpi_session = ProcessPoolExecutorSession(n_workers=1,
mp_context=ctx)
# TODO: add rpc worker here
return GenerationExecutorProxy(
return GenerationExecutor._create_ipc_executor(
worker_kwargs,
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
use_worker=False,
kv_connector_config=kv_connector_config)

def wait_first_completed(
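
Taken together, the executor.py hunks route every construction path in create() through the three static helpers instead of building the executor classes inline. The lines below are a condensed, non-authoritative sketch of the resulting dispatch: helper names, the error branch, and the use_worker values are taken from the diff above, while needs_multiprocess and needs_single_process_worker stand in for gating conditions that the collapsed hunks do not show.

    from typing import Optional

    def dispatch(orchestrator_type: Optional[str],
                 needs_multiprocess: bool,
                 needs_single_process_worker: bool,
                 is_windows: bool) -> str:
        """Return the helper that create() would delegate to (sketch only)."""
        if orchestrator_type == "ray":
            return "_create_ray_executor"
        if orchestrator_type is not None and orchestrator_type != "rpc":
            raise ValueError(f"Unsupported orchestrator_type: {orchestrator_type}")
        orchestrator_is_rpc = orchestrator_type == "rpc"

        if needs_multiprocess:
            # Leader/proxy path: RPC proxy if requested, otherwise IPC proxy.
            return ("_create_rpc_executor" if orchestrator_is_rpc
                    else "_create_ipc_executor(use_worker=False)")  # GenerationExecutorProxy

        if needs_single_process_worker:
            # WAR for gathering logits / TP1: single-process worker unless RPC is requested.
            return ("_create_rpc_executor" if orchestrator_is_rpc
                    else "_create_ipc_executor(use_worker=True)")   # GenerationExecutorWorker

        # Single-GPU default on non-Windows: RPC proxy if requested, else IPC proxy via mpi4py.
        if not is_windows:
            return ("_create_rpc_executor" if orchestrator_is_rpc
                    else "_create_ipc_executor(use_worker=False)")
        # Windows: IPC proxy over a spawn-based ProcessPoolExecutorSession (RPC not wired up yet).
        return "_create_ipc_executor(use_worker=False)"

For example, dispatch("rpc", True, False, False) returns "_create_rpc_executor", matching the first RPC branch in the diff.
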
2 changes: 0 additions & 2 deletions tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py
@@ -62,7 +62,6 @@ def test_llama_7b_multi_lora_tp2():
cuda_graph_config=None)


@pytest.mark.skip(reason="https://nvbugs/5560921")
@skip_ray
@pytest.mark.gpu2
def test_llm_rpc_tp2():
@@ -81,7 +80,6 @@ def test_llm_rpc_tp2():
assert len(res.outputs[0].token_ids) == 10


@pytest.mark.skip(reason="https://nvbugs/5560921")
@skip_ray
@pytest.mark.gpu2
@pytest.mark.asyncio
2 changes: 0 additions & 2 deletions tests/unittest/llmapi/test_llm_pytorch.py
@@ -960,7 +960,6 @@ def test_max_num_token_check(self):
llm.generate([ids])


@pytest.mark.skip(reason="https://nvbugs/5560921")
@skip_ray
def test_llm_rpc():
# TODO: remove the with-statement when shutdown hang issue is fixed
@@ -978,7 +977,6 @@ def test_llm_rpc():
assert len(res.outputs[0].token_ids) == 10


@pytest.mark.skip(reason="https://nvbugs/5560921")
@skip_ray
@pytest.mark.asyncio
async def test_llm_rpc_streaming():
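
The test-side changes only drop the @pytest.mark.skip(reason="https://nvbugs/5560921") markers, re-enabling the RPC orchestrator tests. For orientation, a minimal sketch of the shape of such a test, mirroring the assertion visible in the diff; the model path and the orchestrator-selection kwarg are assumptions here, not taken from the diff:

    # Hypothetical sketch only: the orchestrator_type kwarg name and the model
    # path are assumptions; the assertion mirrors the one shown in the diff.
    from tensorrt_llm import LLM, SamplingParams

    def test_llm_rpc_sketch():
        # The with-statement mirrors the TODO in the diff about the shutdown hang.
        with LLM(model="/path/to/llama-model", orchestrator_type="rpc") as llm:
            res = llm.generate("A B C", SamplingParams(max_tokens=10))
            assert len(res.outputs[0].token_ids) == 10
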