
Commit 227fd64

init
Signed-off-by: Superjomn <[email protected]>
1 parent 9fe63dd commit 227fd64

File tree: 3 files changed, +74 −28 lines

tensorrt_llm/executor/executor.py
Lines changed: 65 additions & 16 deletions

@@ -375,6 +375,55 @@ def _create_ray_executor(
             tp_size=tp_size,
             kv_connector_config=kv_connector_config)

+    @staticmethod
+    def _create_rpc_executor(
+            worker_kwargs: Dict,
+            model_world_size: int,
+            mpi_session: Optional[MpiSession],
+            postproc_worker_config: PostprocWorkerConfig,
+            is_llm_executor: bool,
+            kv_connector_config: Optional[KvCacheConnectorConfig] = None):
+        """Create RPC-based executor (GenerationExecutorRpcProxy)."""
+        from .rpc_proxy import GenerationExecutorRpcProxy
+        return GenerationExecutorRpcProxy(
+            worker_kwargs,
+            model_world_size=model_world_size,
+            mpi_session=mpi_session,
+            postproc_worker_config=postproc_worker_config,
+            is_llm_executor=is_llm_executor,
+            kv_connector_config=kv_connector_config)
+
+    @staticmethod
+    def _create_ipc_executor(
+            worker_kwargs: Dict,
+            model_world_size: int,
+            mpi_session: Optional[MpiSession],
+            postproc_worker_config: PostprocWorkerConfig,
+            is_llm_executor: bool,
+            use_worker: bool = False,
+            kv_connector_config: Optional[KvCacheConnectorConfig] = None):
+        """Create IPC-based executor (GenerationExecutorProxy or GenerationExecutorWorker).
+
+        Args:
+            use_worker: If True, creates GenerationExecutorWorker (single process).
+                If False, creates GenerationExecutorProxy (multi-process with IPC).
+        """
+        if use_worker:
+            from .worker import GenerationExecutorWorker
+            return GenerationExecutorWorker(
+                **worker_kwargs,
+                is_llm_executor=is_llm_executor,
+                kv_connector_config=kv_connector_config)
+        else:
+            from .proxy import GenerationExecutorProxy
+            return GenerationExecutorProxy(
+                worker_kwargs,
+                model_world_size=model_world_size,
+                mpi_session=mpi_session,
+                postproc_worker_config=postproc_worker_config,
+                is_llm_executor=is_llm_executor,
+                kv_connector_config=kv_connector_config)
+
     @staticmethod
     def create(
         engine: Union[Path, Engine],

@@ -394,10 +443,6 @@ def create(
         llm_args: Optional[BaseLlmArgs] = None,
         **args,
     ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
-        # local imports to avoid cyclic importing
-        from .proxy import GenerationExecutorProxy
-        from .worker import GenerationExecutorWorker
-
         if world_size == 0:
             world_size = mpi_world_size()

@@ -437,7 +482,7 @@ def create(
                 is_llm_executor=is_llm_executor,
                 tp_size=args.get("tp_size", 1),
                 kv_connector_config=kv_connector_config)
-        elif orchestrator_type is not None:
+        elif orchestrator_type is not None and orchestrator_type != "rpc":
             raise ValueError(
                 f"Unsupported orchestrator_type: {orchestrator_type}")

@@ -452,21 +497,21 @@ def create(
             assert mpi_session is not None, "reuse_mpi_comm requires an external MPI session"

         if orchestrator_is_rpc:
-            from .rpc_proxy import GenerationExecutorRpcProxy
-            return GenerationExecutorRpcProxy(
+            return GenerationExecutor._create_rpc_executor(
                 worker_kwargs,
                 model_world_size=model_world_size,
                 mpi_session=mpi_session,
                 postproc_worker_config=postproc_worker_config,
                 is_llm_executor=is_llm_executor,
                 kv_connector_config=kv_connector_config)

-        return GenerationExecutorProxy(
+        return GenerationExecutor._create_ipc_executor(
             worker_kwargs,
             model_world_size=model_world_size,
             mpi_session=mpi_session,
             postproc_worker_config=postproc_worker_config,
             is_llm_executor=is_llm_executor,
+            use_worker=False,
             kv_connector_config=kv_connector_config)

         # WAR: For the performance of gathering logits, we use single process worker

@@ -478,18 +523,21 @@ def create(
                 "Using single process worker for TP1, this may hurt streaming generation performance."
             )
             if orchestrator_is_rpc:
-                from .rpc_proxy import GenerationExecutorRpcProxy
-                return GenerationExecutorRpcProxy(
+                return GenerationExecutor._create_rpc_executor(
                     worker_kwargs,
                     model_world_size=model_world_size,
                     mpi_session=mpi_session,
                     postproc_worker_config=postproc_worker_config,
                     is_llm_executor=is_llm_executor,
                     kv_connector_config=kv_connector_config)

-            return GenerationExecutorWorker(
-                **worker_kwargs,
+            return GenerationExecutor._create_ipc_executor(
+                worker_kwargs,
+                model_world_size=model_world_size,
+                mpi_session=mpi_session,
+                postproc_worker_config=postproc_worker_config,
                 is_llm_executor=is_llm_executor,
+                use_worker=True,
                 kv_connector_config=kv_connector_config)

         # For single-gpu case:

@@ -498,34 +546,35 @@ def create(
         # `if __name__ == "__main__":`.
         if not platform.system() == 'Windows':
             if orchestrator_is_rpc:
-                from .rpc_proxy import GenerationExecutorRpcProxy
-                return GenerationExecutorRpcProxy(
+                return GenerationExecutor._create_rpc_executor(
                     worker_kwargs,
                     model_world_size=model_world_size,
                     mpi_session=mpi_session,
                     postproc_worker_config=postproc_worker_config,
                     is_llm_executor=is_llm_executor,
                     kv_connector_config=kv_connector_config)

-            return GenerationExecutorProxy(
+            return GenerationExecutor._create_ipc_executor(
                 worker_kwargs,
                 model_world_size=model_world_size,
                 mpi_session=None,  # use mpi4py
                 postproc_worker_config=postproc_worker_config,
                 is_llm_executor=is_llm_executor,
+                use_worker=False,
                 kv_connector_config=kv_connector_config)
         else:
             ctx = multiprocessing.get_context("spawn")
             # The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot.
             mpi_session = ProcessPoolExecutorSession(n_workers=1,
                                                      mp_context=ctx)
             # TODO: add rpc worker here
-            return GenerationExecutorProxy(
+            return GenerationExecutor._create_ipc_executor(
                 worker_kwargs,
                 model_world_size=model_world_size,
                 mpi_session=mpi_session,
                 postproc_worker_config=postproc_worker_config,
                 is_llm_executor=is_llm_executor,
+                use_worker=False,
                 kv_connector_config=kv_connector_config)

     def wait_first_completed(
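For orientation, the diff above routes every construction path in GenerationExecutor.create() through the two new static helpers. The standalone sketch below is illustrative only: select_executor_factory is a hypothetical name that is not part of this commit, and the exact placement of the Ray branch is inferred from the context lines. It mirrors how orchestrator_type and the single-process TP1 special case select between _create_rpc_executor and _create_ipc_executor.

from typing import Optional

def select_executor_factory(orchestrator_type: Optional[str],
                            model_world_size: int,
                            single_process_tp1: bool) -> str:
    # Mirror of the dispatch order in GenerationExecutor.create() after this commit.
    if orchestrator_type == "ray":
        return "_create_ray_executor"
    if orchestrator_type is not None and orchestrator_type != "rpc":
        raise ValueError(f"Unsupported orchestrator_type: {orchestrator_type}")
    if orchestrator_type == "rpc":
        # GenerationExecutorRpcProxy, regardless of world size.
        return "_create_rpc_executor"
    if model_world_size == 1 and single_process_tp1:
        # WAR for logits gathering: single-process GenerationExecutorWorker.
        return "_create_ipc_executor(use_worker=True)"
    # Default: multi-process GenerationExecutorProxy over IPC.
    return "_create_ipc_executor(use_worker=False)"

assert select_executor_factory("rpc", 2, False) == "_create_rpc_executor"
assert select_executor_factory(None, 1, True) == "_create_ipc_executor(use_worker=True)"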

tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py
Lines changed: 9 additions & 10 deletions

@@ -1,18 +1,19 @@
 import pytest
+from utils.util import skip_ray

-# isort: off
-from .test_llm import tinyllama_logits_processor_test_harness, llama_model_path
 from tensorrt_llm import LLM
+from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy
 from tensorrt_llm.llmapi import KvCacheConfig
 from tensorrt_llm.lora_helper import LoraConfig
-from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness, check_phi3_lora_fused_modules_output_tp2_identical_to_tp1
-from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness
-from .test_llm import _test_llm_capture_request_error
-from utils.util import skip_ray
-# isort: on
-from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy
 from tensorrt_llm.sampling_params import SamplingParams

+from .lora_test_utils import (
+    check_llama_7b_multi_lora_from_request_test_harness,
+    check_phi3_lora_fused_modules_output_tp2_identical_to_tp1)
+from .test_llm import (_test_llm_capture_request_error, llama_model_path,
+                       tinyllama_logits_processor_test_harness)
+from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness
+
 global_kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)

@@ -71,7 +72,6 @@ def test_phi3_lora_fused_modules_output_on_tp2_identical_to_tp1() -> None:
         cuda_graph_config=None)


-@pytest.mark.skip(reason="https://nvbugs/5560921")
 @skip_ray
 @pytest.mark.gpu2
 def test_llm_rpc_tp2():

@@ -90,7 +90,6 @@ def test_llm_rpc_tp2():
         assert len(res.outputs[0].token_ids) == 10


-@pytest.mark.skip(reason="https://nvbugs/5560921")
 @skip_ray
 @pytest.mark.gpu2
 @pytest.mark.asyncio
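For reference, the RPC smoke tests that this commit re-enables (test_llm_rpc, test_llm_rpc_tp2, test_llm_rpc_streaming) follow roughly the pattern sketched below. This is a hedged illustration, not code from the diff: the LLM constructor arguments, in particular orchestrator_type="rpc" and tensor_parallel_size, are assumptions about how the RPC path is selected; only the GenerationExecutorRpcProxy import, the free_gpu_memory_fraction=0.4 kv-cache config, and the ten-token assertion come from the files above.

from tensorrt_llm import LLM
from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.sampling_params import SamplingParams

def rpc_smoke_test(model_path: str, tp_size: int = 2) -> None:
    # Assumption: orchestrator_type="rpc" is forwarded down to
    # GenerationExecutor.create(), which now dispatches to _create_rpc_executor().
    with LLM(model=model_path,
             tensor_parallel_size=tp_size,
             orchestrator_type="rpc",
             kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4)) as llm:
        # The executor behind the LLM should be the RPC proxy.
        assert isinstance(llm._executor, GenerationExecutorRpcProxy)
        res = llm.generate(["The future of AI is"],
                           sampling_params=SamplingParams(max_tokens=10))[0]
        assert len(res.outputs[0].token_ids) == 10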

tests/unittest/llmapi/test_llm_pytorch.py
Lines changed: 0 additions & 2 deletions

@@ -960,7 +960,6 @@ def test_max_num_token_check(self):
             llm.generate([ids])


-@pytest.mark.skip(reason="https://nvbugs/5560921")
 @skip_ray
 def test_llm_rpc():
     # TODO: remove the with-statement when shutdown hang issue is fixed

@@ -978,7 +977,6 @@ def test_llm_rpc():
         assert len(res.outputs[0].token_ids) == 10


-@pytest.mark.skip(reason="https://nvbugs/5560921")
 @skip_ray
 @pytest.mark.asyncio
 async def test_llm_rpc_streaming():
