Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions tensorrt_llm/executor/ipc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import hashlib
import hmac
import os
Expand Down Expand Up @@ -154,6 +155,52 @@ async def put_async(self, obj: Any):

nvtx_mark("ipc.send", color="blue", category="IPC")

async def put_async_noblock(self, obj: Any):
self.setup_lazily()
try:
if self.use_hmac_encryption:
data = pickle.dumps(obj) # nosec B301
signed_data = self._sign_data(data)
await self.socket.send(signed_data, flags=zmq.NOBLOCK)
else:
await self.socket.send_pyobj(obj, flags=zmq.NOBLOCK)
except Exception as e:
logger.error(f"Error sending object: {e}")
logger.error(traceback.format_exc())
raise e

async def put_async_with_timeout(self, obj: Any, timeout: float = 5.0):
"""
Send an object with timeout to detect connection failures.

Args:
obj: The object to send
timeout: Timeout in seconds for the send operation

Raises:
zmq.Again: If send operation times out (peer may be disconnected)
Exception: Other send errors
"""
self.setup_lazily()
try:
if self.use_hmac_encryption:
data = pickle.dumps(obj) # nosec B301
signed_data = self._sign_data(data)
# Use asyncio.wait_for to implement timeout instead of zmq.NOBLOCK
await asyncio.wait_for(self.socket.send(signed_data),
timeout=timeout)
else:
await asyncio.wait_for(self.socket.send_pyobj(obj),
timeout=timeout)
except asyncio.TimeoutError:
# Convert timeout to zmq.Again to maintain compatibility with existing error handling
raise zmq.Again(
"Send operation timed out - peer may be disconnected")
except Exception as e:
logger.error(f"Error sending object: {e}")
logger.error(traceback.format_exc())
raise e

def get(self) -> Any:
self.setup_lazily()

Expand Down Expand Up @@ -196,6 +243,9 @@ async def get_async(self) -> Any:
obj = await self.socket.recv_pyobj()
return obj

async def get_async_noblock(self, timeout: float = 0.5) -> Any:
return await asyncio.wait_for(self.get_async(), timeout)

def close(self):
if self.socket:
self.socket.close()
Expand Down
10 changes: 10 additions & 0 deletions tensorrt_llm/executor/rpc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .rpc_client import RPCClient
from .rpc_common import (RPCCancelled, RPCError, RPCParams, RPCRequest,
RPCResponse, RPCStreamingError, RPCTimeout)
from .rpc_server import RPCServer, Server

__all__ = [
"RPCClient", "RPCServer", "Server", "RPCError", "RPCTimeout",
"RPCCancelled", "RPCStreamingError", "RPCRequest", "RPCResponse",
"RPCParams"
]
Loading