Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 167 additions & 7 deletions tensorrt_llm/serve/cluster_storage.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import abc
import asyncio
import logging
import time
from dataclasses import dataclass
from enum import IntEnum
from functools import wraps
from typing import Dict, List, Optional
from typing import Callable, Dict, List, Optional

import aiohttp
import etcd3
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel

logger = logging.getLogger('uvicorn.error')
from tensorrt_llm.logger import logger


class StorageItem(BaseModel):
Expand Down Expand Up @@ -91,7 +91,7 @@ async def delete(self, key: str) -> bool:
async def watch(self, key_prefix: str) -> WatchEventQueue:
...

# unwatch the key prefix, if the key prefix is not in the watch list, raise an error
# unwatch the key prefix, if the key prefix is not in the watch list, raise a KeyError
async def unwatch(self, key_prefix: str) -> None:
...

Expand All @@ -106,12 +106,16 @@ async def get_prefix(self,
def create_cluster_storage(cluster_uri, cluster_name, **kwargs):
    """Build the server-side cluster storage backend selected by the URI scheme.

    Args:
        cluster_uri: Storage URI. A URI starting with ``http`` selects
            ``HttpClusterStorageServer``; one starting with ``etcd`` selects
            ``Etcd3ClusterStorage``.
        cluster_name: Name of the cluster, forwarded to the backend.
        **kwargs: Extra keyword arguments forwarded to the backend constructor.

    Returns:
        The constructed storage backend.

    Raises:
        ValueError: If the URI starts with neither ``http`` nor ``etcd``.
    """
    if cluster_uri.startswith("http"):
        storage_cls = HttpClusterStorageServer
    elif cluster_uri.startswith("etcd"):
        storage_cls = Etcd3ClusterStorage
    else:
        raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")
    return storage_cls(cluster_uri, cluster_name, **kwargs)


def create_cluster_storage_client(cluster_uri, cluster_name):
def create_cluster_storage_client(cluster_uri, cluster_name, **kwargs):
if cluster_uri.startswith("http"):
return HttpClusterStorageClient(cluster_uri, cluster_name)
return HttpClusterStorageClient(cluster_uri, cluster_name, **kwargs)
elif cluster_uri.startswith("etcd"):
return Etcd3ClusterStorage(cluster_uri, cluster_name, **kwargs)
raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")


Expand Down Expand Up @@ -241,7 +245,7 @@ async def unwatch(self, key_prefix: str) -> None:
if key_prefix in self._watch_handles:
self._watch_handles.pop(key_prefix)
else:
raise ValueError(
raise KeyError(
f"Key prefix {key_prefix} not in watch list, {self._watch_handles.keys()}"
)

Expand Down Expand Up @@ -377,3 +381,159 @@ async def watch(self, key_prefix: str) -> WatchEventQueue:
async def unwatch(self, key_prefix: str) -> None:
raise NotImplementedError(
"Unwatch functionality not implemented for HTTP client")


class Etcd3WatchEventQueue(WatchEventQueue):
    """Watch-event queue fed by a python-etcd3 prefix-watch callback.

    ``add_event`` is registered as the etcd3 watch callback, so it is expected
    to run on the etcd3 watcher thread rather than on the asyncio loop that
    consumes ``events``. The class therefore hands events over to the loop in
    a thread-safe way whenever the queue is already bound to a running loop.
    """

    def __init__(self,
                 key_prefix: str,
                 cancel_event: Optional[Callable[[], None]] = None):
        """Create an empty queue for ``key_prefix``.

        Args:
            key_prefix: The watched key prefix.
            cancel_event: Optional zero-argument callable that cancels the
                underlying etcd watch; may also be set later through
                ``set_cancel_event``.
        """
        self.key_prefix = key_prefix
        self._cancel_event = cancel_event
        self.events = asyncio.Queue()

    def cancel_event(self):
        """Cancel the underlying watch, at most once per installed callback."""
        # getattr keeps __del__ safe on an instance whose __init__ did not
        # complete. The callback is cleared before it is invoked so repeated
        # calls (unwatch, error path, __del__) cancel the watch only once.
        cancel = getattr(self, "_cancel_event", None)
        self._cancel_event = None
        if cancel:
            cancel()

    def set_cancel_event(self, cancel_event: Callable[[], None]):
        """Install (or replace) the callable used to cancel the watch."""
        self._cancel_event = cancel_event

    def __del__(self):
        self.cancel_event()

    def _enqueue(self, watch_event) -> bool:
        """Put one event on the queue; return True if it was a direct put.

        asyncio.Queue is not thread-safe. When the queue is bound to a running
        loop, the put is scheduled on that loop with ``call_soon_threadsafe``,
        which also wakes the loop. Otherwise no loop can be relied on to run
        the callback, so the item is appended directly.
        """
        loop = getattr(self.events, "_loop", None)
        if loop is not None and loop.is_running():
            loop.call_soon_threadsafe(self.events.put_nowait, watch_event)
            return False
        self.events.put_nowait(watch_event)
        return True

    def add_event(self, watch_resp):
        """etcd3 watch callback: translate a watch response into WatchEvents.

        Any failure (including ``watch_resp`` not carrying ``events``) is
        logged and cancels the watch.
        """
        try:
            direct_put = False
            for event in watch_resp.events:
                # Event type is not in public interface of etcd3
                is_put = "Put" in event.__class__.__name__
                event_type = (WatchEventType.SET
                              if is_put else WatchEventType.DELETE)
                direct_put |= self._enqueue(
                    WatchEvent(
                        storage_item=StorageItem(
                            key=event.key.decode("utf-8"),
                            value=event.value.decode("utf-8")),
                        event_type=event_type,
                    ))
            if direct_put:
                # A consumer may have started waiting while we were putting
                # directly from this thread; poke its loop so it notices.
                loop = getattr(self.events, "_loop", None)
                if loop:
                    loop._write_to_self()
        except Exception as e:
            logger.error(f"Error adding event: {e}")
            self.cancel_event()


class Etcd3ClusterStorage(ClusterStorage):
    """ClusterStorage backend backed by an etcd v3 server via python-etcd3.

    Keys written with a positive TTL are attached to an etcd lease. By default
    every key gets its own lease; with ``one_single_lease=True`` all keys
    written by this instance share one lease.

    NOTE: the etcd3 client used here is synchronous, so the ``async`` methods
    perform their etcd call inline without awaiting.
    """

    def __init__(self,
                 cluster_uri: str,
                 cluster_name: str,
                 one_single_lease: bool = False):
        """Connect to the etcd server named by ``cluster_uri``.

        Args:
            cluster_uri: URI of the form ``etcd://host:port``.
            cluster_name: Cluster name (accepted for interface parity; not
                used by this backend).
            one_single_lease: Share one lease across all keys of this
                instance instead of one lease per key.

        Raises:
            ValueError: If the URI has no ``host:port`` part or the port is
                not an integer.
        """
        # Initialise plain state first so __del__ is safe even if connecting
        # below raises.
        self._client = None
        self._leases = {}
        self._instance_lease = None
        self._watch_handles = {}
        self._one_single_lease = one_single_lease

        address = cluster_uri.replace("etcd://", "")
        host, sep, port = address.rpartition(":")
        if not sep or not host:
            raise ValueError(
                f"Invalid etcd cluster URI, expected etcd://host:port, got {cluster_uri}"
            )
        self._client = etcd3.client(host, int(port))

    def __del__(self):
        handles = getattr(self, "_watch_handles", None)
        if handles is not None:
            handles.clear()
        client = getattr(self, "_client", None)
        if client is not None:
            client.close()

    def _get_lease(self, key: str, ttl: int = -1) -> Optional[etcd3.Lease]:
        """Return the lease to attach to ``key``, creating it on first use.

        Returns ``None`` when ``ttl`` is not positive (no expiry requested).
        A lease's TTL is fixed when it is created; later calls with a
        different ``ttl`` reuse the cached lease unchanged.
        """
        if ttl <= 0:
            return None
        if self._one_single_lease:
            # Create the shared lease lazily; without this the shared-lease
            # mode would always yield None and silently drop the TTL.
            if self._instance_lease is None:
                self._instance_lease = self.client.lease(ttl)
            return self._instance_lease
        if key not in self._leases:
            self._leases[key] = self.client.lease(ttl)
        return self._leases[key]

    @property
    def client(self):
        """The underlying etcd3 client."""
        return self._client

    async def start(self):
        # nothing to do
        ...

    async def stop(self):
        # nothing to do
        ...

    async def set(self,
                  key: str,
                  value: str,
                  overwrite_if_exists: bool = False,
                  ttl: int = -1) -> bool:
        """Write ``key``; return True on success.

        When ``overwrite_if_exists`` is False the result of the etcd
        put-if-not-exists call is returned. etcd errors are logged and
        reported as False.
        """
        try:
            lease = self._get_lease(key, ttl)
            if not overwrite_if_exists:
                return self.client.put_if_not_exists(key, value, lease=lease)
            self.client.put(key, value, lease=lease)
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error setting key {key}: {e}")
            return False
        return True

    async def get(self, key: str) -> Optional[str]:
        """Return the UTF-8 decoded value of ``key``, or None if absent/on error."""
        try:
            data, _meta = self.client.get(key)
            return data.decode('utf-8') if data else None
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error getting key {key}: {e}")
            return None

    async def delete(self, key: str) -> bool:
        """Delete ``key``; return False only if etcd raised an error."""
        try:
            self.client.delete(key)
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error deleting key {key}: {e}")
            return False
        return True

    async def expire(self, key: str, ttl: int) -> bool:
        """Refresh the lease associated with ``key``.

        Raises:
            ValueError: If ``ttl`` is not positive.
        """
        if ttl <= 0:
            raise ValueError(f"TTL must be greater than 0, got {ttl}")
        try:
            lease = self._get_lease(key, ttl)
            # TTL will be ignored since it can only be set when creating a lease.
            # python-etcd3's client.refresh_lease() is a generator, so calling
            # it without iterating sends nothing; Lease.refresh() drains it
            # and actually performs the keep-alive.
            lease.refresh()
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error refreshing lease {key}: {e}")
            return False
        return True

    async def get_prefix(self,
                         key_prefix: str,
                         keys_only: bool = False) -> Dict[str, str]:
        """Return ``{key: value}`` for all keys under ``key_prefix``.

        With ``keys_only`` the values are empty strings. Returns an empty
        dict on etcd errors.
        """
        try:
            resp = self.client.get_prefix(key_prefix, keys_only=keys_only)
            return {
                metadata.key.decode("utf-8"):
                "" if keys_only else v.decode("utf-8")
                for v, metadata in resp
            }
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error getting keys {key_prefix}: {e}")
            return {}

    async def watch(self, key_prefix: str) -> Optional[WatchEventQueue]:
        """Start (or reuse) a watch on ``key_prefix`` and return its queue.

        Returns None if etcd raised an error while registering the watch.
        """
        try:
            if key_prefix in self._watch_handles:
                return self._watch_handles[key_prefix]
            watch_handle = Etcd3WatchEventQueue(key_prefix=key_prefix)
            watch_id = self.client.add_watch_prefix_callback(
                key_prefix, watch_handle.add_event)
            watch_handle.set_cancel_event(
                lambda: self.client.cancel_watch(watch_id))
            self._watch_handles[key_prefix] = watch_handle
            return watch_handle
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error watching key {key_prefix}: {e}")
            return None

    async def unwatch(self, key_prefix: str) -> None:
        """Stop watching ``key_prefix``.

        Raises:
            KeyError: If ``key_prefix`` is not currently watched.
        """
        handle = self._watch_handles.pop(key_prefix)
        if handle is not None:
            handle.cancel_event()
10 changes: 6 additions & 4 deletions tensorrt_llm/serve/disagg_auto_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class WorkerInfo:
role: ServerRole = ServerRole.CONTEXT


def get_worker_key_prefix(cluster_name: str):
def get_worker_key_prefix(cluster_name: str) -> str:
return f"/trtllm-disagg/{cluster_name}/workers"


Expand All @@ -47,13 +47,14 @@ def __init__(self, config: DisaggClusterConfig, storage: ClusterStorage):

def __del__(self):
if asyncio.get_event_loop():
asyncio.run_coroutine_threadsafe(self.unwatch_workers(),
asyncio.run_coroutine_threadsafe(self.stop(),
asyncio.get_event_loop())

async def start(self) -> None:
await self._cluster_storage.start()

async def stop(self) -> None:
await self.unwatch_workers()
await self._cluster_storage.stop()

async def cluster_info(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -104,8 +105,9 @@ async def watch_workers(self, get_existing_first: bool = True):
return workers

async def unwatch_workers(self) -> None:
await self._cluster_storage.unwatch(self.worker_key_prefix)
self._watch_handle = None
if self._watch_handle:
await self._cluster_storage.unwatch(self.worker_key_prefix)
self._watch_handle = None

async def get_worker_events(
self) -> List[Tuple[WorkerInfo, WatchEventType]]:
Expand Down
5 changes: 2 additions & 3 deletions tests/unittest/disaggregated/test_cluster_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ async def test_watch(self, storage_server_client, storage_client):
async def test_unwatch(self, storage_server_client, storage_client):
assert await storage_server_client.watch("test_key")
await storage_server_client.unwatch("test_key")
with pytest.raises(ValueError):
with pytest.raises(KeyError):
await storage_server_client.unwatch("test_key")

@pytest.mark.threadleak(enabled=False)
Expand Down Expand Up @@ -210,8 +210,7 @@ def storage_server(self):


class TestEtcdClusterStorage(TestClusterStorage):
# Disable this test until Etcd functionality is ready.
__test__ = False
__test__ = True

@pytest.fixture(scope="class")
def storage_server(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
INACTIVE_TIMEOUT = 4
HEARTBEAT_INTERVAL = 2

storage_types = ["http"]
storage_types = ["http", "etcd"]


def get_uri(storage_type):
Expand Down Expand Up @@ -77,7 +77,9 @@ async def cluster_manager(config, storage_server):


@pytest.mark.parametrize("config", storage_types, indirect=True)
@pytest.mark.threadleak(enabled=False)
@pytest.mark.threadleak(
enabled=False
) # ignore thread leak for python-etcd3 watch thread, there is no way to stop it
@pytest.mark.asyncio(scope="module")
async def test_init_workers_first(config, storage_server):
try:
Expand Down