|
1 | 1 | import abc |
2 | 2 | import asyncio |
3 | | -import logging |
4 | 3 | import time |
5 | 4 | from dataclasses import dataclass |
6 | 5 | from enum import IntEnum |
7 | 6 | from functools import wraps |
8 | | -from typing import Dict, List, Optional |
| 7 | +from typing import Callable, Dict, List, Optional |
9 | 8 |
|
10 | 9 | import aiohttp |
| 10 | +import etcd3 |
11 | 11 | from fastapi import FastAPI |
12 | 12 | from fastapi.responses import JSONResponse |
13 | 13 | from pydantic import BaseModel |
14 | 14 |
|
15 | | -logger = logging.getLogger('uvicorn.error') |
| 15 | +from tensorrt_llm.logger import logger |
16 | 16 |
|
17 | 17 |
|
18 | 18 | class StorageItem(BaseModel): |
@@ -106,12 +106,16 @@ async def get_prefix(self, |
106 | 106 | def create_cluster_storage(cluster_uri, cluster_name, **kwargs): |
107 | 107 | if cluster_uri.startswith("http"): |
108 | 108 | return HttpClusterStorageServer(cluster_uri, cluster_name, **kwargs) |
| 109 | + elif cluster_uri.startswith("etcd"): |
| 110 | + return Etcd3ClusterStorage(cluster_uri, cluster_name, **kwargs) |
109 | 111 | raise ValueError(f"Invalid cluster storage URI: {cluster_uri}") |
110 | 112 |
|
111 | 113 |
|
112 | | -def create_cluster_storage_client(cluster_uri, cluster_name): |
| 114 | +def create_cluster_storage_client(cluster_uri, cluster_name, **kwargs): |
113 | 115 | if cluster_uri.startswith("http"): |
114 | | - return HttpClusterStorageClient(cluster_uri, cluster_name) |
| 116 | + return HttpClusterStorageClient(cluster_uri, cluster_name, **kwargs) |
| 117 | + elif cluster_uri.startswith("etcd"): |
| 118 | + return Etcd3ClusterStorage(cluster_uri, cluster_name, **kwargs) |
115 | 119 | raise ValueError(f"Invalid cluster storage URI: {cluster_uri}") |
116 | 120 |
|
117 | 121 |
|
@@ -241,7 +245,7 @@ async def unwatch(self, key_prefix: str) -> None: |
241 | 245 | if key_prefix in self._watch_handles: |
242 | 246 | self._watch_handles.pop(key_prefix) |
243 | 247 | else: |
244 | | - raise ValueError( |
| 248 | + raise KeyError( |
245 | 249 | f"Key prefix {key_prefix} not in watch list, {self._watch_handles.keys()}" |
246 | 250 | ) |
247 | 251 |
|
@@ -377,3 +381,151 @@ async def watch(self, key_prefix: str) -> WatchEventQueue: |
377 | 381 | async def unwatch(self, key_prefix: str) -> None: |
378 | 382 | raise NotImplementedError( |
379 | 383 | "Unwatch functionality not implemented for HTTP client") |
| 384 | + |
| 385 | + |
| 386 | +class Etcd3WatchEventQueue(WatchEventQueue): |
| 387 | + |
| 388 | + def __init__(self, |
| 389 | + key_prefix: str, |
| 390 | + cancel_event: Callable[[], None] = None): |
| 391 | + self.key_prefix = key_prefix |
| 392 | + self._cancel_event = cancel_event |
| 393 | + self.events = asyncio.Queue() |
| 394 | + |
| 395 | + def cancel_event(self): |
| 396 | + if self._cancel_event: |
| 397 | + self._cancel_event() |
| 398 | + |
| 399 | + def set_cancel_event(self, cancel_event: Callable[[], None]): |
| 400 | + self._cancel_event = cancel_event |
| 401 | + |
| 402 | + def __del__(self): |
| 403 | + self.cancel_event() |
| 404 | + |
| 405 | + def add_event(self, watch_resp): |
| 406 | + try: |
| 407 | + for event in watch_resp.events: |
| 408 | + # Event type is not in public interface of etcd3 |
| 409 | + event_type = WatchEventType.SET if "Put" in event.__class__.__name__ else WatchEventType.DELETE |
| 410 | + self.events.put_nowait( |
| 411 | + WatchEvent( |
| 412 | + storage_item=StorageItem( |
| 413 | + key=event.key.decode("utf-8"), |
| 414 | + value=event.value.decode("utf-8")), |
| 415 | + event_type=event_type, |
| 416 | + )) |
| 417 | + if self.events._loop: |
| 418 | + self.events._loop._write_to_self() |
| 419 | + except Exception as e: |
| 420 | + logger.error(f"Error adding event: {e}") |
| 421 | + self.cancel_event() |
| 422 | + |
| 423 | + |
| 424 | +class Etcd3ClusterStorage(ClusterStorage): |
| 425 | + |
| 426 | + def __init__(self, |
| 427 | + cluster_uri: str, |
| 428 | + cluster_name: str, |
| 429 | + one_single_lease: bool = False): |
| 430 | + cluster_uri = cluster_uri.replace("etcd://", "") |
| 431 | + host, port = cluster_uri.rsplit(":", 1) |
| 432 | + self._client = etcd3.client(host, port) |
| 433 | + self._leases = {} |
| 434 | + self._instance_lease = None |
| 435 | + self._watch_handles = {} |
| 436 | + self._one_single_lease = one_single_lease |
| 437 | + |
| 438 | + def __del__(self): |
| 439 | + self._watch_handles.clear() |
| 440 | + self._client.close() |
| 441 | + |
| 442 | + def _get_lease(self, key: str, ttl: int = -1) -> etcd3.Lease: |
| 443 | + if ttl <= 0: |
| 444 | + return None |
| 445 | + if self._one_single_lease: |
| 446 | + return self._instance_lease |
| 447 | + if key not in self._leases: |
| 448 | + self._leases[key] = self.client.lease(ttl) |
| 449 | + return self._leases[key] |
| 450 | + |
| 451 | + @property |
| 452 | + def client(self): |
| 453 | + return self._client |
| 454 | + |
| 455 | + async def set(self, |
| 456 | + key: str, |
| 457 | + value: str, |
| 458 | + overwrite_if_exists: bool = False, |
| 459 | + ttl: int = -1) -> bool: |
| 460 | + try: |
| 461 | + lease = self._get_lease(key, ttl) |
| 462 | + if not overwrite_if_exists: |
| 463 | + return self.client.put_if_not_exists(key, value, lease=lease) |
| 464 | + else: |
| 465 | + self.client.put(key, value, lease=lease) |
| 466 | + except etcd3.Etcd3Exception as e: |
| 467 | + logger.error(f"Error setting key {key}: {e}") |
| 468 | + return False |
| 469 | + return True |
| 470 | + |
| 471 | + async def get(self, key: str) -> str: |
| 472 | + try: |
| 473 | + data, meta = self.client.get(key) |
| 474 | + return data.decode('utf-8') if data else None |
| 475 | + except etcd3.Etcd3Exception as e: |
| 476 | + logger.error(f"Error getting key {key}: {e}") |
| 477 | + return None |
| 478 | + |
| 479 | + async def delete(self, key: str) -> bool: |
| 480 | + try: |
| 481 | + self.client.delete(key) |
| 482 | + except etcd3.Etcd3Exception as e: |
| 483 | + logger.error(f"Error deleting key {key}: {e}") |
| 484 | + return False |
| 485 | + return True |
| 486 | + |
| 487 | + async def expire(self, key: str, ttl: int) -> bool: |
| 488 | + if ttl <= 0: |
| 489 | + raise ValueError(f"TTL must be greater than 0, got {ttl}") |
| 490 | + try: |
| 491 | + lease = self._get_lease(key, ttl) |
| 492 | + # TTL will be ignored since it can only be set when creating a lease |
| 493 | + self.client.refresh_lease(lease_id=lease.id) |
| 494 | + except etcd3.Etcd3Exception as e: |
| 495 | + logger.error(f"Error refreshing lease {key}: {e}") |
| 496 | + return False |
| 497 | + return True |
| 498 | + |
| 499 | + async def get_prefix(self, |
| 500 | + key_prefix: str, |
| 501 | + keys_only: bool = False) -> Dict[str, str]: |
| 502 | + try: |
| 503 | + resp = self.client.get_prefix(key_prefix, keys_only=keys_only) |
| 504 | + return { |
| 505 | + metadata.key.decode("utf-8"): |
| 506 | + "" if keys_only else v.decode("utf-8") |
| 507 | + for v, metadata in resp |
| 508 | + } |
| 509 | + except etcd3.Etcd3Exception as e: |
| 510 | + logger.error(f"Error getting keys {key_prefix}: {e}") |
| 511 | + return {} |
| 512 | + |
| 513 | + async def watch(self, key_prefix: str) -> WatchEventQueue: |
| 514 | + try: |
| 515 | + if key_prefix in self._watch_handles: |
| 516 | + return self._watch_handles[key_prefix] |
| 517 | + watch_handle = Etcd3WatchEventQueue(key_prefix=key_prefix) |
| 518 | + watch_id = self.client.add_watch_prefix_callback( |
| 519 | + key_prefix, watch_handle.add_event) |
| 520 | + watch_handle.set_cancel_event( |
| 521 | + lambda: self.client.cancel_watch(watch_id)) |
| 522 | + self._watch_handles[key_prefix] = watch_handle |
| 523 | + return watch_handle |
| 524 | + except etcd3.Etcd3Exception as e: |
| 525 | + logger.error(f"Error watching key {key_prefix}: {e}") |
| 526 | + return None |
| 527 | + |
| 528 | + async def unwatch(self, key_prefix: str) -> None: |
| 529 | + handle = self._watch_handles.pop(key_prefix) |
| 530 | + if handle: |
| 531 | + handle.cancel_event() |
0 commit comments