diff --git a/tensorrt_llm/executor/ray_executor.py b/tensorrt_llm/executor/ray_executor.py index 512ac4a98db..579aac0a715 100644 --- a/tensorrt_llm/executor/ray_executor.py +++ b/tensorrt_llm/executor/ray_executor.py @@ -82,6 +82,8 @@ def __init__(self, is_llm_executor=is_llm_executor) self.init_rpc_executor() + # Inject the generated HMAC key into worker_kwargs for workers + worker_kwargs['hmac_key'] = self.hmac_key worker_kwargs['rpc_addr'] = self.rpc_addr self.create_workers(RayGPUWorker, worker_kwargs) self.setup_engine_remote() diff --git a/tensorrt_llm/executor/ray_gpu_worker.py b/tensorrt_llm/executor/ray_gpu_worker.py index 6b820e01267..48f036abeb0 100644 --- a/tensorrt_llm/executor/ray_gpu_worker.py +++ b/tensorrt_llm/executor/ray_gpu_worker.py @@ -168,6 +168,7 @@ def __init__( tokenizer: Optional[TokenizerBase] = None, llm_args: Optional[BaseLlmArgs] = None, rpc_addr: Optional[str] = None, + hmac_key: Optional[bytes] = None, ) -> None: global logger from tensorrt_llm.logger import logger @@ -191,7 +192,7 @@ def __init__( if rpc_addr is None: raise RuntimeError( "RPC mode enabled but no rpc_addr provided to RayGPUWorker") - self.init_rpc_worker(self.global_rank, rpc_addr) + self.init_rpc_worker(self.global_rank, rpc_addr, hmac_key) self.start_rpc_server() def setup_engine(self): diff --git a/tensorrt_llm/executor/rpc/rpc_client.py b/tensorrt_llm/executor/rpc/rpc_client.py index 0fd19de78c5..c9e60e8cf7d 100644 --- a/tensorrt_llm/executor/rpc/rpc_client.py +++ b/tensorrt_llm/executor/rpc/rpc_client.py @@ -108,7 +108,8 @@ def __init__(self, self._client_socket = ZeroMqQueue(address=(address, hmac_key), is_server=False, is_async=True, - use_hmac_encryption=False, + use_hmac_encryption=hmac_key + is not None, socket_type=socket_type, name="rpc_client") self._pending_futures = {} diff --git a/tensorrt_llm/executor/rpc/rpc_server.py b/tensorrt_llm/executor/rpc/rpc_server.py index 6635fb6876c..00fb23e94dc 100644 --- a/tensorrt_llm/executor/rpc/rpc_server.py +++ b/tensorrt_llm/executor/rpc/rpc_server.py @@ -108,7 +108,8 @@ def bind(self, address: str = "tcp://*:5555") -> None: self._client_socket = ZeroMqQueue(address=(address, self._hmac_key), is_server=True, is_async=True, - use_hmac_encryption=False, + use_hmac_encryption=self._hmac_key + is not None, socket_type=socket_type, name="rpc_server") logger.info(f"RPCServer is bound to {self._address}") diff --git a/tensorrt_llm/executor/rpc_proxy.py b/tensorrt_llm/executor/rpc_proxy.py index 655d77ea7e1..09f93afb80c 100644 --- a/tensorrt_llm/executor/rpc_proxy.py +++ b/tensorrt_llm/executor/rpc_proxy.py @@ -48,6 +48,8 @@ def __init__( self._create_mpi_session(model_world_size, mpi_session) + # Inject the generated HMAC key into worker_kwargs for workers + worker_kwargs['hmac_key'] = self.hmac_key self.worker_kwargs = worker_kwargs self.launch_workers() diff --git a/tensorrt_llm/executor/rpc_proxy_mixin.py b/tensorrt_llm/executor/rpc_proxy_mixin.py index f3d4b88c57b..c7d7716f4f3 100644 --- a/tensorrt_llm/executor/rpc_proxy_mixin.py +++ b/tensorrt_llm/executor/rpc_proxy_mixin.py @@ -1,6 +1,7 @@ import asyncio import atexit import json +import os import threading from typing import Callable, List, Optional @@ -29,7 +30,8 @@ class RpcExecutorMixin: def init_rpc_executor(self): self.rpc_addr = get_unique_ipc_addr() - self.rpc_client = RPCClient(self.rpc_addr) + self.hmac_key = os.urandom(32) + self.rpc_client = RPCClient(self.rpc_addr, hmac_key=self.hmac_key) self._results = {} self._shutdown_event = threading.Event() diff --git a/tensorrt_llm/executor/rpc_worker.py b/tensorrt_llm/executor/rpc_worker.py index 13c1f8d1eb0..665e8a07234 100644 --- a/tensorrt_llm/executor/rpc_worker.py +++ b/tensorrt_llm/executor/rpc_worker.py @@ -155,7 +155,10 @@ def main_task( color="yellow") # Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client # Set num_workers to larger than 1 since there are some streaming tasks runs infinitely, such as await_responses_async. - rpc_server = RPCServer(worker, num_workers=worker.num_workers) + hmac_key = kwargs.get("hmac_key") + rpc_server = RPCServer(worker, + num_workers=worker.num_workers, + hmac_key=hmac_key) rpc_server.bind(rpc_addr) rpc_server.start() logger_debug(f"[worker] RPC server {mpi_rank()} is started", diff --git a/tensorrt_llm/executor/rpc_worker_mixin.py b/tensorrt_llm/executor/rpc_worker_mixin.py index 14effdd8213..cab53e6b1d2 100644 --- a/tensorrt_llm/executor/rpc_worker_mixin.py +++ b/tensorrt_llm/executor/rpc_worker_mixin.py @@ -25,10 +25,11 @@ class RpcWorkerMixin: # This can be overridden by setting num_workers in the inheriting class NUM_WORKERS = 6 - def init_rpc_worker(self, rank: int, rpc_addr: Optional[str]): + def init_rpc_worker(self, rank: int, rpc_addr: Optional[str], hmac_key: Optional[bytes] = None): if rpc_addr is None: raise RuntimeError("RPC mode enabled but no rpc_addr provided to worker") + self.hmac_key = hmac_key self.rank = rank self.shutdown_event = Event() self._response_queue = Queue() @@ -41,7 +42,7 @@ def start_rpc_server(self): if self.rank == 0: # Use num_workers if set on the instance, otherwise use class default num_workers = getattr(self, "num_workers", RpcWorkerMixin.NUM_WORKERS) - self.rpc_server = RPCServer(self, num_workers=num_workers) + self.rpc_server = RPCServer(self, num_workers=num_workers, hmac_key=self.hmac_key) self.rpc_server.bind(self.rpc_addr) self.rpc_server.start() diff --git a/tests/unittest/executor/test_rpc_proxy.py b/tests/unittest/executor/test_rpc_proxy.py index d61bfd5198d..113e18b15ea 100644 --- a/tests/unittest/executor/test_rpc_proxy.py +++ b/tests/unittest/executor/test_rpc_proxy.py @@ -95,6 +95,43 @@ def test_tp2(self, num_reqs): assert similar(tokenizer.decode(result.outputs[0].token_ids), 'E F G H I J K L') + def test_hmac_key_generation(self): + """Test that HMAC key is automatically generated and properly propagated.""" + tokenizer = TransformersTokenizer.from_pretrained(model_path) + prompt = "A B C D" + prompt_token_ids = tokenizer.encode(prompt) + max_tokens = 8 + + with self.create_proxy(tp_size=1) as proxy: + assert proxy.hmac_key is not None, "HMAC key should be generated" + assert len( + proxy.hmac_key + ) == 32, f"HMAC key should be 32 bytes, got {len(proxy.hmac_key)}" + + # Verify key is properly stored in worker_kwargs + assert 'hmac_key' in proxy.worker_kwargs, "HMAC key should be in worker_kwargs" + assert proxy.worker_kwargs[ + 'hmac_key'] is not None, "HMAC key in worker_kwargs should not be None" + + # Verify both references point to the same key object + assert proxy.hmac_key is proxy.worker_kwargs['hmac_key'], \ + "HMAC key should be the same object in both locations" + + logger_debug( + f"[Test] HMAC key verified: length={len(proxy.hmac_key)} bytes", + color="green") + + # Verify RPC communication works with the generated key + sampling_params = SamplingParams(max_tokens=max_tokens) + result = proxy.generate(prompt_token_ids, sampling_params) + assert similar( + tokenizer.decode(result.outputs[0].token_ids), 'E F G H I J K L' + ), "Generation should work with auto-generated HMAC key" + + logger_debug( + f"[Test] HMAC key test passed: RPC communication successful", + color="green") + if __name__ == "__main__": TestRpcProxy().test_tp1(20)