11import  abc 
22import  asyncio 
3- import  logging 
43import  time 
54from  dataclasses  import  dataclass 
65from  enum  import  IntEnum 
76from  functools  import  wraps 
8- from  typing  import  Dict , List , Optional 
7+ from  typing  import  Callable ,  Dict , List , Optional 
98
109import  aiohttp 
10+ import  etcd3 
1111from  fastapi  import  FastAPI 
1212from  fastapi .responses  import  JSONResponse 
1313from  pydantic  import  BaseModel 
1414
15- logger  =   logging . getLogger ( 'uvicorn.error' ) 
15+ from   tensorrt_llm . logger  import   logger 
1616
1717
1818class  StorageItem (BaseModel ):
@@ -91,7 +91,7 @@ async def delete(self, key: str) -> bool:
9191    async  def  watch (self , key_prefix : str ) ->  WatchEventQueue :
9292        ...
9393
94-     # unwatch the key prefix, if the key prefix is not in the watch list, raise an error  
94+     # unwatch the key prefix, if the key prefix is not in the watch list, raise a KeyError  
9595    async  def  unwatch (self , key_prefix : str ) ->  None :
9696        ...
9797
@@ -106,12 +106,16 @@ async def get_prefix(self,
def create_cluster_storage(cluster_uri, cluster_name, **kwargs):
    """Create the cluster storage backend selected by ``cluster_uri``.

    Args:
        cluster_uri: Storage URI. A URI starting with ``http`` selects
            ``HttpClusterStorageServer``; one starting with ``etcd`` selects
            ``Etcd3ClusterStorage``.
        cluster_name: Cluster name forwarded to the backend constructor.
        **kwargs: Extra keyword arguments forwarded to the backend
            constructor.

    Returns:
        The constructed storage backend.

    Raises:
        ValueError: If ``cluster_uri`` matches none of the supported
            prefixes.
    """
    if cluster_uri.startswith("http"):
        return HttpClusterStorageServer(cluster_uri, cluster_name, **kwargs)
    if cluster_uri.startswith("etcd"):
        return Etcd3ClusterStorage(cluster_uri, cluster_name, **kwargs)
    raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")
110112
111113
def create_cluster_storage_client(cluster_uri, cluster_name, **kwargs):
    """Create the cluster storage client selected by ``cluster_uri``.

    Args:
        cluster_uri: Storage URI. A URI starting with ``http`` selects
            ``HttpClusterStorageClient``; one starting with ``etcd`` selects
            ``Etcd3ClusterStorage``.
        cluster_name: Cluster name forwarded to the client constructor.
        **kwargs: Extra keyword arguments forwarded to the client
            constructor.

    Returns:
        The constructed storage client.

    Raises:
        ValueError: If ``cluster_uri`` matches none of the supported
            prefixes.
    """
    if cluster_uri.startswith("http"):
        return HttpClusterStorageClient(cluster_uri, cluster_name, **kwargs)
    if cluster_uri.startswith("etcd"):
        return Etcd3ClusterStorage(cluster_uri, cluster_name, **kwargs)
    raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")
116120
117121
@@ -241,7 +245,7 @@ async def unwatch(self, key_prefix: str) -> None:
241245            if  key_prefix  in  self ._watch_handles :
242246                self ._watch_handles .pop (key_prefix )
243247            else :
244-                 raise  ValueError (
248+                 raise  KeyError (
245249                    f"Key prefix { key_prefix } { self ._watch_handles .keys ()}  
246250                )
247251
@@ -377,3 +381,159 @@ async def watch(self, key_prefix: str) -> WatchEventQueue:
377381    async  def  unwatch (self , key_prefix : str ) ->  None :
378382        raise  NotImplementedError (
379383            "Unwatch functionality not implemented for HTTP client" )
384+ 
385+ 
class Etcd3WatchEventQueue(WatchEventQueue):
    """Watch queue filled from etcd3 watch responses.

    Responses handed to :meth:`add_event` are converted to ``WatchEvent``
    objects and placed on ``self.events``. An optional cancel callback
    is invoked on explicit cancel, on conversion errors, and when the
    object is garbage collected.
    """

    def __init__(self,
                 key_prefix: str,
                 cancel_event: Optional[Callable[[], None]] = None):
        """Create an empty queue for ``key_prefix``.

        Args:
            key_prefix: Key prefix this queue is associated with.
            cancel_event: Optional zero-argument callable run by
                :meth:`cancel_event`.
        """
        self.key_prefix = key_prefix
        self._cancel_event = cancel_event
        self.events = asyncio.Queue()

    def cancel_event(self):
        """Run the cancel callback, if one is set, at most once."""
        # getattr: __del__ can run on an instance whose __init__ never
        # completed, in which case the attribute does not exist.
        cancel = getattr(self, "_cancel_event", None)
        # Clear before calling so the explicit cancel, the error path in
        # add_event and __del__ do not cancel the same watch repeatedly.
        self._cancel_event = None
        if cancel:
            cancel()

    def set_cancel_event(self, cancel_event: Callable[[], None]):
        """Install (or replace) the callback run by :meth:`cancel_event`."""
        self._cancel_event = cancel_event

    def __del__(self):
        self.cancel_event()

    def add_event(self, watch_resp):
        """Convert one etcd3 watch response into queued ``WatchEvent`` items.

        Any failure is logged and triggers :meth:`cancel_event`.

        Args:
            watch_resp: Object exposing an ``events`` iterable whose items
                have ``key`` and ``value`` byte strings.
                NOTE(review): python-etcd3 may also pass an exception
                object to watch callbacks; that case is routed to the
                same error handling below - confirm against the
                installed etcd3 version.
        """
        try:
            if isinstance(watch_resp, Exception):
                # Report the real failure instead of an AttributeError on
                # the missing `.events`.
                raise watch_resp
            for event in watch_resp.events:
                # Event type is not in public interface of etcd3
                is_put = "Put" in event.__class__.__name__
                event_type = (WatchEventType.SET
                              if is_put else WatchEventType.DELETE)
                item = StorageItem(key=event.key.decode("utf-8"),
                                   value=event.value.decode("utf-8"))
                self.events.put_nowait(
                    WatchEvent(storage_item=item, event_type=event_type))
            # Relies on private asyncio attributes: pokes the loop the
            # queue is bound to, if any.
            # NOTE(review): presumably needed to wake the loop because
            # this runs on a non-loop thread - confirm.
            loop = self.events._loop
            if loop:
                loop._write_to_self()
        except Exception as e:
            logger.error(f"Error adding event: {e}")
            self.cancel_event()
422+ 
423+ 
class Etcd3ClusterStorage(ClusterStorage):
    """Cluster storage implemented on top of an etcd3 client.

    The etcd3 client calls are made directly inside the coroutines.
    Failures raised as ``etcd3.Etcd3Exception`` are logged and reported
    through the return value (``False`` / ``None`` / ``{}``) rather than
    propagated.
    """

    def __init__(self,
                 cluster_uri: str,
                 cluster_name: str,
                 one_single_lease: bool = False):
        """Connect to the etcd endpoint named by ``cluster_uri``.

        Args:
            cluster_uri: ``etcd://host:port`` style URI.
            cluster_name: Accepted for signature parity with the factory
                functions; not used by this class.
            one_single_lease: When true, every key written with a
                positive TTL shares one lease instead of getting its own.

        Raises:
            ValueError: If ``cluster_uri`` has no ``:port`` part.
        """
        # Initialise plain state first so __del__ stays safe even when
        # parsing the URI or creating the client raises.
        self._client = None
        self._leases = {}
        self._instance_lease = None
        self._watch_handles = {}
        self._one_single_lease = one_single_lease

        address = cluster_uri.replace("etcd://", "")
        host, port = address.rsplit(":", 1)
        self._client = etcd3.client(host, port)

    def __del__(self):
        handles = getattr(self, "_watch_handles", None)
        if handles is not None:
            handles.clear()
        client = getattr(self, "_client", None)
        if client is not None:
            client.close()

    def _get_lease(self, key: str, ttl: int = -1) -> Optional[etcd3.Lease]:
        """Return the lease to attach to ``key``, creating it on first use.

        Args:
            key: Key the lease is looked up under (per-key mode only).
            ttl: Lease TTL; only used when a lease is created.

        Returns:
            ``None`` when ``ttl <= 0``, otherwise the shared lease
            (single-lease mode) or the lease cached for ``key``.
        """
        if ttl <= 0:
            return None
        if self._one_single_lease:
            # Create the shared lease lazily; it was previously never
            # created, so TTLs were dropped in single-lease mode.
            if self._instance_lease is None:
                self._instance_lease = self.client.lease(ttl)
            return self._instance_lease
        if key not in self._leases:
            self._leases[key] = self.client.lease(ttl)
        return self._leases[key]

    @property
    def client(self):
        """The underlying etcd3 client."""
        return self._client

    async def start(self):
        """No-op: nothing to do."""
        ...

    async def stop(self):
        """No-op: nothing to do."""
        ...

    async def set(self,
                  key: str,
                  value: str,
                  overwrite_if_exists: bool = False,
                  ttl: int = -1) -> bool:
        """Store ``value`` under ``key``.

        Args:
            key: Key to write.
            value: Value to store.
            overwrite_if_exists: When false, the write goes through
                ``put_if_not_exists`` and its result is returned.
            ttl: A positive value attaches a lease to the key; otherwise
                no lease is attached.

        Returns:
            ``True`` on success; ``False`` on an etcd3 error, or the
            result of ``put_if_not_exists`` when not overwriting.
        """
        try:
            lease = self._get_lease(key, ttl)
            if overwrite_if_exists:
                self.client.put(key, value, lease=lease)
                return True
            return self.client.put_if_not_exists(key, value, lease=lease)
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error setting key {key}: {e}")
            return False

    async def get(self, key: str) -> Optional[str]:
        """Return the UTF-8 decoded value of ``key``.

        Returns:
            The decoded value, or ``None`` when the stored data is empty
            or missing, or when an etcd3 error occurs.
        """
        try:
            data, _meta = self.client.get(key)
            return data.decode('utf-8') if data else None
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error getting key {key}: {e}")
            return None

    async def delete(self, key: str) -> bool:
        """Delete ``key``; return ``False`` only on an etcd3 error."""
        try:
            self.client.delete(key)
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error deleting key {key}: {e}")
            return False
        return True

    async def expire(self, key: str, ttl: int) -> bool:
        """Refresh the lease associated with ``key``.

        Args:
            key: Key whose lease is refreshed.
            ttl: Must be positive. Only used if a lease has to be created.

        Returns:
            ``True`` on success, ``False`` on an etcd3 error.

        Raises:
            ValueError: If ``ttl`` is not greater than 0.
        """
        if ttl <= 0:
            raise ValueError(f"TTL must be greater than 0, got {ttl}")
        try:
            lease = self._get_lease(key, ttl)
            # TTL will be ignored since it can only be set when creating a lease
            # NOTE(review): python-etcd3's refresh_lease appears to be a
            # generator, so it is consumed here to make sure the
            # keep-alive request is actually issued - confirm against the
            # installed etcd3 version.
            list(self.client.refresh_lease(lease_id=lease.id))
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error refreshing lease {key}: {e}")
            return False
        return True

    async def get_prefix(self,
                         key_prefix: str,
                         keys_only: bool = False) -> Dict[str, str]:
        """Return all entries whose key starts with ``key_prefix``.

        Args:
            key_prefix: Prefix to look up.
            keys_only: When true, values in the result are empty strings.

        Returns:
            Mapping of decoded key to decoded value; ``{}`` on an etcd3
            error.
        """
        try:
            resp = self.client.get_prefix(key_prefix, keys_only=keys_only)
            return {
                metadata.key.decode("utf-8"):
                "" if keys_only else value.decode("utf-8")
                for value, metadata in resp
            }
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error getting keys {key_prefix}: {e}")
            return {}

    async def watch(self, key_prefix: str) -> Optional[WatchEventQueue]:
        """Watch ``key_prefix`` and return the queue receiving its events.

        Calling this again for a prefix that is already watched returns
        the existing queue.

        Returns:
            The watch queue, or ``None`` on an etcd3 error.
        """
        try:
            if key_prefix in self._watch_handles:
                return self._watch_handles[key_prefix]
            watch_handle = Etcd3WatchEventQueue(key_prefix=key_prefix)
            watch_id = self.client.add_watch_prefix_callback(
                key_prefix, watch_handle.add_event)
            # The watch id only exists after registration, so the cancel
            # callback is attached afterwards.
            watch_handle.set_cancel_event(
                lambda: self.client.cancel_watch(watch_id))
            self._watch_handles[key_prefix] = watch_handle
            return watch_handle
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error watching key {key_prefix}: {e}")
            return None

    async def unwatch(self, key_prefix: str) -> None:
        """Stop watching ``key_prefix``.

        Raises:
            KeyError: If ``key_prefix`` is not currently watched.
        """
        handle = self._watch_handles.pop(key_prefix)
        if handle:
            handle.cancel_event()
0 commit comments