@@ -226,6 +226,7 @@ def __init__(
226226 self .engines_id = [str (uuid .uuid4 ()) for i in range (0 , len (urls ))]
227227 self .added_timestamp = int (time .time ())
228228 self .unhealthy_endpoint_hashes = []
229+ self ._running = True
229230 if static_backend_health_checks :
230231 self .start_health_check_task ()
231232 self .prefill_model_labels = prefill_model_labels
@@ -250,10 +251,13 @@ def get_unhealthy_endpoint_hashes(self) -> list[str]:
250251 return unhealthy_endpoints
251252
async def check_model_health(self):
    """
    Periodically refresh ``self.unhealthy_endpoint_hashes``.

    Runs until ``self._running`` becomes false or the task is cancelled.
    Each round stores the result of ``get_unhealthy_endpoint_hashes()`` and
    then pauses for 60 seconds. A failing probe is logged and the previous
    value is kept.
    """
    while self._running:
        try:
            self.unhealthy_endpoint_hashes = self.get_unhealthy_endpoint_hashes()
        except Exception as e:
            logger.error(e)

        # The pause lives outside the probe's try block on purpose: if the
        # probe raises, we must still wait before retrying. Otherwise a
        # persistent failure turns into a busy loop that never yields to the
        # event loop and therefore can never be cancelled.
        try:
            await asyncio.sleep(60)
        except asyncio.CancelledError:
            logger.debug("Health check task cancelled")
            break
259263
@@ -340,6 +344,40 @@ async def initialize_client_sessions(self) -> None:
340344 timeout = aiohttp .ClientTimeout (total = None ),
341345 )
342346
def close(self):
    """
    Close the service discovery module and clean up health check resources.

    Clears ``self._running``. If ``self.loop`` exists and is running, every
    task on it is cancelled and the loop is asked to stop. Then
    ``self.thread`` (if alive) is joined for up to 5 seconds and the loop is
    closed. Problems during shutdown are logged as warnings, not raised.
    """
    # Local import: Future.result() raises concurrent.futures.TimeoutError,
    # which is a different class from asyncio.TimeoutError before Python 3.11.
    import concurrent.futures

    self._running = False
    loop = getattr(self, "loop", None)
    thread = getattr(self, "thread", None)

    if loop is not None and loop.is_running():

        async def cancel_pending():
            # Cancel everything on the loop except this coroutine itself.
            me = asyncio.current_task()
            pending = [t for t in asyncio.all_tasks(loop) if t is not me]
            for task in pending:
                task.cancel()
            await asyncio.gather(*pending, return_exceptions=True)

        future = asyncio.run_coroutine_threadsafe(cancel_pending(), loop)
        try:
            future.result(timeout=15.0)
        except concurrent.futures.TimeoutError:
            logger.warning(
                "Timed out waiting for shutdown(loop might already be closed)"
            )
        except Exception as e:
            logger.warning(f"Error during health check shutdown: {e} ")
        finally:
            # Stop the loop from here, not from inside the coroutine. After
            # stop() the loop only finishes its current batch of callbacks,
            # so the callback that resolves `future` would never run and
            # result() above would always hit its timeout.
            try:
                loop.call_soon_threadsafe(loop.stop)
            except RuntimeError:
                # call_soon_threadsafe raises if the loop was closed meanwhile.
                logger.debug("Health check event loop already closed")

    if thread is not None and thread.is_alive():
        thread.join(timeout=5.0)

    if loop is not None and not loop.is_closed():
        if loop.is_running():
            # Closing a running loop raises RuntimeError; leave it open.
            logger.warning("Health check event loop is still running; not closing it")
        else:
            loop.close()
380+
343381
344382class K8sPodIPServiceDiscovery (ServiceDiscovery ):
345383 def __init__ (
0 commit comments