Skip to content

Commit e3f4d75

Browse files
reasonsolo authored and govind-ramnarayan committed
[TRTLLM-7846][feat] implement etcd storage for disagg cluster (NVIDIA#8210)
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 019cd07 commit e3f4d75

File tree

4 files changed

+179
-16
lines changed

4 files changed

+179
-16
lines changed

tensorrt_llm/serve/cluster_storage.py

Lines changed: 167 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
import abc
22
import asyncio
3-
import logging
43
import time
54
from dataclasses import dataclass
65
from enum import IntEnum
76
from functools import wraps
8-
from typing import Dict, List, Optional
7+
from typing import Callable, Dict, List, Optional
98

109
import aiohttp
10+
import etcd3
1111
from fastapi import FastAPI
1212
from fastapi.responses import JSONResponse
1313
from pydantic import BaseModel
1414

15-
logger = logging.getLogger('uvicorn.error')
15+
from tensorrt_llm.logger import logger
1616

1717

1818
class StorageItem(BaseModel):
@@ -91,7 +91,7 @@ async def delete(self, key: str) -> bool:
9191
async def watch(self, key_prefix: str) -> WatchEventQueue:
9292
...
9393

94-
# unwatch the key prefix, if the key prefix is not in the watch list, raise an error
94+
# unwatch the key prefix, if the key prefix is not in the watch list, raise a KeyError
9595
async def unwatch(self, key_prefix: str) -> None:
9696
...
9797

@@ -106,12 +106,16 @@ async def get_prefix(self,
106106
def create_cluster_storage(cluster_uri, cluster_name, **kwargs):
    """Create the cluster storage backend selected by the URI prefix.

    A URI starting with ``http`` yields an ``HttpClusterStorageServer``; one
    starting with ``etcd`` yields an ``Etcd3ClusterStorage``. Extra keyword
    arguments are forwarded to the chosen constructor.

    Raises:
        ValueError: if the URI starts with neither prefix.
    """
    if cluster_uri.startswith("http"):
        return HttpClusterStorageServer(cluster_uri, cluster_name, **kwargs)
    if cluster_uri.startswith("etcd"):
        return Etcd3ClusterStorage(cluster_uri, cluster_name, **kwargs)
    # No known backend matches this URI.
    raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")
110112

111113

112-
def create_cluster_storage_client(cluster_uri, cluster_name, **kwargs):
    """Create the cluster storage client selected by the URI prefix.

    A URI starting with ``http`` yields an ``HttpClusterStorageClient``; one
    starting with ``etcd`` yields an ``Etcd3ClusterStorage``. Extra keyword
    arguments are forwarded to the chosen constructor.

    Raises:
        ValueError: if the URI starts with neither prefix.
    """
    if cluster_uri.startswith("http"):
        return HttpClusterStorageClient(cluster_uri, cluster_name, **kwargs)
    if cluster_uri.startswith("etcd"):
        return Etcd3ClusterStorage(cluster_uri, cluster_name, **kwargs)
    # No known backend matches this URI.
    raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")
116120

117121

@@ -241,7 +245,7 @@ async def unwatch(self, key_prefix: str) -> None:
241245
if key_prefix in self._watch_handles:
242246
self._watch_handles.pop(key_prefix)
243247
else:
244-
raise ValueError(
248+
raise KeyError(
245249
f"Key prefix {key_prefix} not in watch list, {self._watch_handles.keys()}"
246250
)
247251

@@ -377,3 +381,159 @@ async def watch(self, key_prefix: str) -> WatchEventQueue:
377381
async def unwatch(self, key_prefix: str) -> None:
378382
raise NotImplementedError(
379383
"Unwatch functionality not implemented for HTTP client")
384+
385+
386+
class Etcd3WatchEventQueue(WatchEventQueue):
    """Watch event queue fed by an etcd3 watch callback.

    ``add_event`` is registered as the etcd3 watch callback; it converts each
    etcd event into a ``WatchEvent`` and places it on ``self.events``.
    """

    def __init__(self,
                 key_prefix: str,
                 cancel_event: Optional[Callable[[], None]] = None):
        """
        Args:
            key_prefix: the key prefix this queue receives events for.
            cancel_event: optional callable that cancels the underlying watch.
        """
        self.key_prefix = key_prefix
        self._cancel_event = cancel_event
        self.events = asyncio.Queue()
        # Remember the loop the queue was created on so that events coming
        # from the etcd3 callback can be handed over thread-safely.
        # NOTE(review): assumes the consumer awaits on this same loop.
        try:
            self._consumer_loop = asyncio.get_running_loop()
        except RuntimeError:
            self._consumer_loop = None

    def cancel_event(self):
        """Cancel the underlying watch, if a cancel callable is set.

        The callable is cleared before it is invoked, so repeated calls
        (e.g. an explicit unwatch followed by ``__del__``) cancel only once.
        """
        cancel, self._cancel_event = self._cancel_event, None
        if cancel:
            cancel()

    def set_cancel_event(self, cancel_event: Callable[[], None]):
        """Install the callable used by ``cancel_event``."""
        self._cancel_event = cancel_event

    def __del__(self):
        # getattr guards against a partially constructed instance.
        if getattr(self, "_cancel_event", None):
            self.cancel_event()

    def add_event(self, watch_resp):
        """Convert the events of an etcd3 watch response and enqueue them.

        asyncio.Queue is not thread-safe, and python-etcd3 presumably invokes
        this callback from its own watcher thread, so events are scheduled
        onto the consumer loop with ``call_soon_threadsafe`` when that loop is
        running. Any failure (including a response object without
        ``events``) is logged and cancels the watch.
        """
        try:
            loop = getattr(self, "_consumer_loop", None)
            for event in watch_resp.events:
                # Event type is not in public interface of etcd3
                event_type = (WatchEventType.SET
                              if "Put" in event.__class__.__name__ else
                              WatchEventType.DELETE)
                item = WatchEvent(
                    storage_item=StorageItem(
                        key=event.key.decode("utf-8"),
                        value=event.value.decode("utf-8")),
                    event_type=event_type,
                )
                if loop is not None and loop.is_running():
                    # Thread-safe hand-over; also wakes up the loop.
                    loop.call_soon_threadsafe(self.events.put_nowait, item)
                else:
                    # No running loop to hand over to; enqueue directly.
                    self.events.put_nowait(item)
        except Exception as e:
            logger.error(f"Error adding event: {e}")
            self.cancel_event()
422+
423+
424+
class Etcd3ClusterStorage(ClusterStorage):
    """Cluster storage backed by an etcd v3 server through the ``etcd3`` client.

    The same class is used for both the "server" and the "client" role.
    NOTE: the etcd3 client calls are synchronous and are invoked directly
    from the async methods below.
    """

    # Used when the URI carries no explicit port (etcd's default client port).
    _DEFAULT_PORT = 2379

    def __init__(self,
                 cluster_uri: str,
                 cluster_name: str,
                 one_single_lease: bool = False):
        """
        Args:
            cluster_uri: ``etcd://host[:port]``.
            cluster_name: accepted for interface parity; not used here.
            one_single_lease: when True, every key written with a TTL shares
                one lease instead of getting a lease per key.
        """
        # Set up bookkeeping first so __del__ is safe if connecting fails.
        self._client = None
        self._leases = {}
        self._instance_lease = None
        self._watch_handles = {}
        self._one_single_lease = one_single_lease

        address = cluster_uri.replace("etcd://", "").rstrip("/")
        host, sep, port = address.rpartition(":")
        if not sep:
            # No port given: the whole address is the host.
            host, port = address, self._DEFAULT_PORT
        self._client = etcd3.client(host, int(port))

    def __del__(self):
        handles = getattr(self, "_watch_handles", None)
        if handles:
            # Dropping the handles lets their finalizers cancel the watches.
            handles.clear()
        client = getattr(self, "_client", None)
        if client is not None:
            client.close()

    def _get_lease(self, key: str, ttl: int = -1) -> Optional[etcd3.Lease]:
        """Return the lease to attach to ``key``, or None when ``ttl <= 0``.

        Leases are created on first use and cached; the TTL of an existing
        lease is not changed by later calls.
        """
        if ttl <= 0:
            return None
        if self._one_single_lease:
            # Create the shared lease lazily; previously it stayed None, so
            # TTLs were silently dropped and expire() failed.
            if self._instance_lease is None:
                self._instance_lease = self.client.lease(ttl)
            return self._instance_lease
        if key not in self._leases:
            self._leases[key] = self.client.lease(ttl)
        return self._leases[key]

    @property
    def client(self):
        """The underlying etcd3 client."""
        return self._client

    async def start(self):
        # nothing to do
        ...

    async def stop(self):
        # nothing to do
        ...

    async def set(self,
                  key: str,
                  value: str,
                  overwrite_if_exists: bool = False,
                  ttl: int = -1) -> bool:
        """Store ``value`` under ``key``.

        Without ``overwrite_if_exists`` the result of ``put_if_not_exists``
        is returned. Returns False when etcd3 raises.
        """
        try:
            lease = self._get_lease(key, ttl)
            if not overwrite_if_exists:
                return self.client.put_if_not_exists(key, value, lease=lease)
            self.client.put(key, value, lease=lease)
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error setting key {key}: {e}")
            return False
        return True

    async def get(self, key: str) -> Optional[str]:
        """Return the decoded value of ``key``; None if absent or on error."""
        try:
            data, _meta = self.client.get(key)
            # Compare with None so a stored empty value is returned as "".
            return data.decode('utf-8') if data is not None else None
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error getting key {key}: {e}")
            return None

    async def delete(self, key: str) -> bool:
        """Delete ``key``; returns False only when etcd3 raises."""
        try:
            self.client.delete(key)
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error deleting key {key}: {e}")
            return False
        # Forget the per-key lease so a later set() creates a fresh one.
        self._leases.pop(key, None)
        return True

    async def expire(self, key: str, ttl: int) -> bool:
        """Refresh the lease associated with ``key``.

        Raises:
            ValueError: if ``ttl`` is not positive.
        """
        if ttl <= 0:
            raise ValueError(f"TTL must be greater than 0, got {ttl}")
        try:
            lease = self._get_lease(key, ttl)
            # TTL will be ignored since it can only be set when creating a lease
            self.client.refresh_lease(lease_id=lease.id)
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error refreshing lease {key}: {e}")
            return False
        return True

    async def get_prefix(self,
                         key_prefix: str,
                         keys_only: bool = False) -> Dict[str, str]:
        """Return ``{key: value}`` for keys under ``key_prefix``.

        With ``keys_only`` every value is ``""``. Returns ``{}`` on error.
        """
        try:
            resp = self.client.get_prefix(key_prefix, keys_only=keys_only)
            return {
                metadata.key.decode("utf-8"):
                "" if keys_only else v.decode("utf-8")
                for v, metadata in resp
            }
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error getting keys {key_prefix}: {e}")
            return {}

    async def watch(self, key_prefix: str) -> Optional[WatchEventQueue]:
        """Watch ``key_prefix`` and return its event queue.

        An existing queue is reused for an already-watched prefix. Returns
        None when etcd3 raises.
        """
        try:
            if key_prefix in self._watch_handles:
                return self._watch_handles[key_prefix]
            watch_handle = Etcd3WatchEventQueue(key_prefix=key_prefix)
            watch_id = self.client.add_watch_prefix_callback(
                key_prefix, watch_handle.add_event)
            watch_handle.set_cancel_event(
                lambda: self.client.cancel_watch(watch_id))
            self._watch_handles[key_prefix] = watch_handle
            return watch_handle
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error watching key {key_prefix}: {e}")
            return None

    async def unwatch(self, key_prefix: str) -> None:
        """Cancel the watch on ``key_prefix``.

        Raises:
            KeyError: if the prefix is not being watched.
        """
        handle = self._watch_handles.pop(key_prefix)
        if handle:
            handle.cancel_event()

tensorrt_llm/serve/disagg_auto_scaling.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class WorkerInfo:
2121
role: ServerRole = ServerRole.CONTEXT
2222

2323

24-
def get_worker_key_prefix(cluster_name: str) -> str:
    """Return the storage key prefix for the workers of ``cluster_name``."""
    prefix = f"/trtllm-disagg/{cluster_name}/workers"
    return prefix
2626

2727

@@ -47,13 +47,14 @@ def __init__(self, config: DisaggClusterConfig, storage: ClusterStorage):
4747

4848
def __del__(self):
    """Best-effort cleanup: schedule ``stop()`` on the current event loop.

    Nothing is scheduled when no event loop is available or the loop is
    already closed, so the finalizer never raises and no un-awaited
    coroutine object is left behind.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No event loop for this thread.
        return
    if loop is None or loop.is_closed():
        return
    asyncio.run_coroutine_threadsafe(self.stop(), loop)
5252

5353
async def start(self) -> None:
    """Start the underlying cluster storage."""
    storage = self._cluster_storage
    await storage.start()
5555

5656
async def stop(self) -> None:
    """Stop watching workers, then stop the underlying cluster storage."""
    await self.unwatch_workers()
    storage = self._cluster_storage
    await storage.stop()
5859

5960
async def cluster_info(self) -> Dict[str, Any]:
@@ -104,8 +105,9 @@ async def watch_workers(self, get_existing_first: bool = True):
104105
return workers
105106

106107
async def unwatch_workers(self) -> None:
    """Stop watching the worker key prefix; a no-op when no watch is active."""
    if not self._watch_handle:
        return
    await self._cluster_storage.unwatch(self.worker_key_prefix)
    self._watch_handle = None
109111

110112
async def get_worker_events(
111113
self) -> List[Tuple[WorkerInfo, WatchEventType]]:

tests/unittest/disaggregated/test_cluster_storage.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ async def test_watch(self, storage_server_client, storage_client):
132132
async def test_unwatch(self, storage_server_client, storage_client):
    """A watched prefix can be unwatched once; a second unwatch raises KeyError."""
    watch_handle = await storage_server_client.watch("test_key")
    assert watch_handle
    await storage_server_client.unwatch("test_key")
    # The prefix is no longer in the watch list.
    with pytest.raises(KeyError):
        await storage_server_client.unwatch("test_key")
137137

138138
@pytest.mark.threadleak(enabled=False)
@@ -210,8 +210,7 @@ def storage_server(self):
210210

211211

212212
class TestEtcdClusterStorage(TestClusterStorage):
213-
# Disable this test until Etcd functionality is ready.
214-
__test__ = False
213+
__test__ = True
215214

216215
@pytest.fixture(scope="class")
217216
def storage_server(self):

tests/unittest/disaggregated/test_disagg_cluster_manager_worker.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
INACTIVE_TIMEOUT = 4
1919
HEARTBEAT_INTERVAL = 2
2020

21-
storage_types = ["http"]
21+
storage_types = ["http", "etcd"]
2222

2323

2424
def get_uri(storage_type):
@@ -77,7 +77,9 @@ async def cluster_manager(config, storage_server):
7777

7878

7979
@pytest.mark.parametrize("config", storage_types, indirect=True)
80-
@pytest.mark.threadleak(enabled=False)
80+
@pytest.mark.threadleak(
81+
enabled=False
82+
) # ignore thread leak for python-etcd3 watch thread, there is no way to stop it
8183
@pytest.mark.asyncio(scope="module")
8284
async def test_init_workers_first(config, storage_server):
8385
try:

0 commit comments

Comments
 (0)