Skip to content

Commit 5aa8d9b

Browse files
shihaobaishihaobai
and
shihaobai
authored
update health check (#690)
Co-authored-by: shihaobai <[email protected]>
1 parent d6b65de commit 5aa8d9b

File tree

2 files changed

+52
-13
lines changed

2 files changed

+52
-13
lines changed

lightllm/server/api_server.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,14 @@ def get_model_name():
110110
@app.head("/health", summary="Check server health")
111111
async def healthcheck(request: Request):
112112
if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true":
113-
return JSONResponse({"message": "Error"}, status_code=404)
113+
return JSONResponse({"message": "Error"}, status_code=503)
114+
from lightllm.utils.health_check import health_check, health_obj
114115

115-
from lightllm.utils.health_check import health_check
116-
117-
if await health_check(g_objs.args, g_objs.httpserver_manager, request):
116+
asyncio.create_task(health_check(g_objs.args, g_objs.httpserver_manager, None))
117+
if health_obj.is_health():
118118
return JSONResponse({"message": "Ok"}, status_code=200)
119119
else:
120-
return JSONResponse({"message": "Error"}, status_code=404)
120+
return JSONResponse({"message": "Error"}, status_code=503)
121121

122122

123123
@app.get("/token_load", summary="Get the current server's load of tokens")

lightllm/utils/health_check.py

+47-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import base64
1+
import asyncio
22
import numpy as np
3+
from dataclasses import dataclass
34
from lightllm.server.sampling_params import SamplingParams
45
from lightllm.server.multimodal_params import MultimodalParams
56
from lightllm.server.httpserver.manager import HttpServerManager
@@ -14,13 +15,41 @@
1415
_g_health_req_id_gen.generate_id()
1516

1617

18+
@dataclass
19+
class HealthObj:
20+
_is_health: bool = True
21+
_is_health_checking: bool = False
22+
23+
def begin_check(self):
24+
self._is_health_checking = True
25+
26+
def end_check(self):
27+
self._is_health_checking = False
28+
29+
def set_unhealth(self):
30+
self._is_health = False
31+
32+
def set_health(self):
33+
self._is_health = True
34+
35+
def is_health(self):
36+
return self._is_health
37+
38+
def is_checking(self):
39+
return self._is_health_checking
40+
41+
42+
health_obj = HealthObj()
43+
44+
1745
async def health_check(args, httpserver_manager: HttpServerManager, request: Request):
46+
if health_obj.is_checking():
47+
return health_obj.is_health()
48+
health_obj.begin_check()
1849
try:
19-
2050
request_dict = {"inputs": "你好!", "parameters": {"do_sample": True, "temperature": 0.8, "max_new_tokens": 2}}
2151
if args.run_mode == "prefill":
2252
request_dict["parameters"]["max_new_tokens"] = 1
23-
2453
prompt = request_dict.pop("inputs")
2554
sample_params_dict = request_dict["parameters"]
2655
sampling_params = SamplingParams(**sample_params_dict)
@@ -29,11 +58,21 @@ async def health_check(args, httpserver_manager: HttpServerManager, request: Req
2958
sampling_params.group_request_id = -_g_health_req_id_gen.generate_id() # health monitor 的 id 是负的
3059
multimodal_params_dict = request_dict.get("multimodal_params", {})
3160
multimodal_params = MultimodalParams(**multimodal_params_dict)
32-
3361
results_generator = httpserver_manager.generate(prompt, sampling_params, multimodal_params, request)
34-
async for _, _, _, _ in results_generator:
35-
pass
36-
return True
62+
63+
async def check_timeout(results_generator):
64+
async for _, _, _, _ in results_generator:
65+
pass
66+
67+
try:
68+
await asyncio.wait_for(check_timeout(results_generator), timeout=88)
69+
health_obj.set_health()
70+
except asyncio.TimeoutError:
71+
health_obj.set_unhealth()
72+
return health_obj.is_health()
3773
except Exception as e:
3874
logger.exception(str(e))
39-
return False
75+
health_obj.set_unhealth()
76+
return health_obj.is_health()
77+
finally:
78+
health_obj.end_check()

0 commit comments

Comments
 (0)