Return meaningful finish reason instead of a boolean (#285)
Return a string describing the finish reason.

Related to #130

The codebase contains usages like

```
req.aborted and req.req_status == ReqRunStatus.WAIT_IN_QUEUE
```

so it is not suitable to merge `ReqRunStatus` and `FinishStatus`. Instead, I added a separate class to indicate the finish reason.
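
For context, here is a minimal, self-contained sketch of why the run state and the finish state need to be readable at the same time. The member names follow `io_struct.py` in this commit, but the members shown are trimmed and the values are illustrative:

```
import enum

class ReqRunStatus(enum.Enum):
    # Run state of a request; other members elided, value illustrative.
    WAIT_IN_QUEUE = 0

class FinishStatus(enum.Enum):
    # Finish state of a request; trimmed to the members used here.
    NO_FINISH = 0
    FINISHED_ABORT = 3

    def is_aborted(self):
        return self == FinishStatus.FINISHED_ABORT

# A request can already be aborted while it is still waiting in the queue, so a
# single merged enum could not express "aborted AND WAIT_IN_QUEUE" at once.
req_status = ReqRunStatus.WAIT_IN_QUEUE
finish_status = FinishStatus.FINISHED_ABORT
if finish_status.is_aborted() and req_status == ReqRunStatus.WAIT_IN_QUEUE:
    print("drop the aborted request directly from the waiting queue")
```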

---------

Co-authored-by: hiworldwzj <[email protected]>
zeyugao and hiworldwzj authored Jan 8, 2024
1 parent 3030409 commit 2a56868
Showing 6 changed files with 56 additions and 33 deletions.
8 changes: 5 additions & 3 deletions lightllm/server/api_server.py
@@ -105,7 +105,7 @@ async def generate(request: Request) -> Response:
prompt_logprobs = None
prompt_token_ids = None
is_first_metadata = True
async for request_output, metadata, _ in results_generator:
async for request_output, metadata, finish_status in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await httpserver_manager.abort(request_id)
@@ -132,6 +132,7 @@ async def generate(request: Request) -> Response:
ret = {
"generated_text": ["".join(final_output)],
"count_output_tokens": count_output_tokens,
"finish_reason": finish_status.get_finish_reason()
}
if return_details:
ret["tokens"] = tokens
@@ -164,7 +165,7 @@ async def generate_stream(request: Request) -> Response:

# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output, metadata, finished in results_generator:
async for request_output, metadata, finish_status in results_generator:
ret = {
"token": {
"id": metadata.get("id", None),
@@ -173,7 +174,8 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
"special": False
},
"generated_text": None,
"finished": finished,
"finished": finish_status.is_finished(),
"finish_reason": finish_status.get_finish_reason(),
"details": None
}

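
To illustrate the effect of the api_server.py changes on the HTTP responses, the sketch below shows hypothetical payloads; only the key names mirror the handlers in the diff above, and the values and elided token fields are made up:

```
# Hypothetical payloads; only the key names follow the handlers above.
stream_chunk = {
    "token": {"id": 13, "special": False},  # remaining token fields elided
    "generated_text": None,
    "finished": True,                       # finish_status.is_finished()
    "finish_reason": "stop",                # finish_status.get_finish_reason()
    "details": None,
}

non_stream_reply = {
    "generated_text": ["Hello world."],
    "count_output_tokens": 3,
    "finish_reason": "length",              # "stop", "length", "abort", or None
}
```
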
8 changes: 4 additions & 4 deletions lightllm/server/detokenization/manager.py
@@ -3,7 +3,7 @@
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
import zmq
import zmq.asyncio
from ..io_struct import BatchTokenIdOut, ReqDetokenizationState, BatchStrOut, AbortReq
from ..io_struct import BatchTokenIdOut, ReqDetokenizationState, BatchStrOut, AbortReq, FinishStatus
from typing import Union
from .decode import decode_token
from ..tokenizer import get_tokenizer
@@ -53,7 +53,7 @@ async def handle_loop(self):

if isinstance(recv_obj, BatchTokenIdOut):
new_batch_str_out = BatchStrOut()
for req_id, new_token_id, new_gen_metadata, finished, abort in recv_obj.reqs_infs:
for req_id, new_token_id, new_gen_metadata, finish_status in recv_obj.reqs_infs:
if req_id not in self.req_id_to_out:
continue
req_out:ReqDetokenizationState = self.req_id_to_out[req_id]
@@ -73,8 +73,8 @@ async def handle_loop(self):
else:
new_text = out_text[len(req_out.output_str):]
req_out.output_str = out_text
new_batch_str_out.reqs_infs.append((req_id, new_text, new_gen_metadata, True if abort else finished, abort))
if finished or abort:
new_batch_str_out.reqs_infs.append((req_id, new_text, new_gen_metadata, finish_status))
if FinishStatus(finish_status).is_finished():
try:
del self.req_id_to_out[req_id]
except:
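
The detokenization manager now receives the finish status as a plain integer and rebuilds the enum before checking it. A minimal sketch of that round trip, with the enum trimmed to a copy of the definition added in io_struct.py below:

```
import enum

class FinishStatus(enum.Enum):
    # Trimmed copy of the enum added in io_struct.py in this commit.
    NO_FINISH = 0
    FINISHED_STOP = 1
    FINISHED_LENGTH = 2
    FINISHED_ABORT = 3

    def is_finished(self):
        return 1 <= self.value <= 3

# The router puts the plain int value on the wire ...
wire_value = FinishStatus.FINISHED_STOP.value  # 1

# ... and the detokenization manager rebuilds the enum before checking it,
# mirroring FinishStatus(finish_status).is_finished() in handle_loop above.
assert FinishStatus(wire_value).is_finished()
```
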
15 changes: 8 additions & 7 deletions lightllm/server/httpserver/manager.py
@@ -8,7 +8,7 @@

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
from ..tokenizer import get_tokenizer
from ..io_struct import BatchStrOut, AbortReq
from ..io_struct import BatchStrOut, AbortReq, FinishStatus
from ..embed_cache.utils import get_shm_name_data, create_shm

class HttpServerManager:
@@ -165,11 +165,11 @@ async def generate(self, prompt, sampling_params, request_id, multimodal_params)
if len(req_status.out_token_info_list) == 0:
continue

for out_str, metadata, finished in req_status.out_token_info_list:
for out_str, metadata, finish_status in req_status.out_token_info_list:
metadata["prompt_tokens"] = prompt_tokens
yield out_str, metadata, finished
yield out_str, metadata, finish_status

if finished:
if finish_status.is_finished():
try:
del self.req_id_to_out_inf[request_id]
await self._release_multimodal_resources(multimodal_params)
@@ -198,12 +198,13 @@ async def handle_loop(self):
assert isinstance(
recv_ans, BatchStrOut
), f"error recv type {type(recv_ans)}"
for req_id, text, metadata, finished, abort in recv_ans.reqs_infs:
for req_id, text, metadata, finish_status in recv_ans.reqs_infs:
finish_status = FinishStatus(finish_status)
try:
if not abort:
if not finish_status.is_aborted():
req_status : ReqStatus = self.req_id_to_out_inf[req_id]
async with req_status.lock:
req_status.out_token_info_list.append((text, metadata, finished))
req_status.out_token_info_list.append((text, metadata, finish_status))
req_status.event.set()
else:
del self.req_id_to_out_inf[req_id]
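
Callers of HttpServerManager.generate() now iterate over (out_str, metadata, finish_status) tuples instead of checking a boolean. A hypothetical consumer might look like this; `collect` is an illustrative helper, not part of the codebase:

```
# `collect` is a hypothetical helper; the tuple shape follows the generator above.
async def collect(results_generator):
    pieces = []
    async for out_str, metadata, finish_status in results_generator:
        pieces.append(out_str)
        if finish_status.is_finished():
            return "".join(pieces), finish_status.get_finish_reason()
    return "".join(pieces), None
```
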
43 changes: 32 additions & 11 deletions lightllm/server/io_struct.py
@@ -1,6 +1,6 @@
from .sampling_params import SamplingParams
from .multimodal_params import MultimodalParams
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
import asyncio
import enum

@@ -12,6 +12,28 @@ class ReqRunStatus(enum.Enum):
RERUNNING_FROM_KVKEEP = 4 # resumed from pause (KV cache kept)
RERUNNING_FROM_OFFLOAD = 5 # resumed from offloaded KV cache

class FinishStatus(enum.Enum):
NO_FINISH = 0 # not finished
FINISHED_STOP = 1 # finished because a STOP token was generated
FINISHED_LENGTH = 2 # finished because the maximum output length was reached
FINISHED_ABORT = 3 # finished because the request was aborted

def is_finished(self):
return 1 <= self.value <= 3

def is_aborted(self):
return self == FinishStatus.FINISHED_ABORT

def get_finish_reason(self):
if self == FinishStatus.FINISHED_STOP:
finish_reason = "stop"
elif self == FinishStatus.FINISHED_LENGTH:
finish_reason = "length"
elif self == FinishStatus.FINISHED_ABORT:
finish_reason = "abort"
else:
finish_reason = None
return finish_reason

class Req:
def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, multimodal_params: MultimodalParams, prompt_cache_len=0, prompt_cache_req_id=None):
@@ -23,10 +45,9 @@ def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, multim
self.multimodal_params = multimodal_params
self.output_ids = []
self.output_metadata_list = []
self.has_generate_finished = False
self.aborted = False

self.req_status = ReqRunStatus.WAIT_IN_QUEUE
self.finish_status = FinishStatus.NO_FINISH
self.cur_kv_len = 0 # kv length already occupied by tokens
self.prompt_cache_len = prompt_cache_len # kv cache length of a reusable shared prompt prefix; currently only used in splitfuse mode
self.prompt_cache_req_id = prompt_cache_req_id # id of the reusable request, so its kv cache can be copied into this request at init time; defaults to None
@@ -215,13 +236,13 @@ def mark_and_get_finished_req_and_preupdate_status(self, eos_id):
unfinished_req_ids, finished_req_ids = [], []
for req in self.reqs:
if req.stop_sequences_matched():
req.has_generate_finished = True
if len(req.output_ids) >= 1 and req.output_ids[-1] == eos_id and req.sample_params.ignore_eos == False:
req.has_generate_finished = True
if len(req.output_ids) >= req.max_output_len or req.aborted:
req.has_generate_finished = True
req.finish_status = FinishStatus.FINISHED_STOP
elif len(req.output_ids) >= 1 and req.output_ids[-1] == eos_id and req.sample_params.ignore_eos is False:
req.finish_status = FinishStatus.FINISHED_STOP
elif len(req.output_ids) >= req.max_output_len:
req.finish_status = FinishStatus.FINISHED_LENGTH

if req.has_generate_finished:
if req.finish_status.is_finished():
finished_req_ids.append(req.request_id)
# while marking, also update the bookkeeping for these removed requests; a bit dirty
self.batch_used_tokens -= req.get_used_tokens()
@@ -263,11 +284,11 @@ def __repr__(self):

class BatchTokenIdOut:
def __init__(self):
self.reqs_infs: List[Tuple[str, int, Dict, bool, bool]] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
self.reqs_infs: List[Tuple[str, int, Dict, int]] = [] # [req_id, new_token_id, gen_metadata, finish_status]

class BatchStrOut:
def __init__(self):
self.reqs_infs: List[Tuple[str, str, Dict, bool, bool]] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state]
self.reqs_infs: List[Tuple[str, str, Dict, int]] = [] # [req_id, token_str, gen_metadata, finish_status]

class AbortReq:
def __init__(self, req_id):
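
As a quick sanity check of the helpers defined above (assuming lightllm is importable, using the import path seen elsewhere in this commit):

```
from lightllm.server.io_struct import FinishStatus

assert FinishStatus.NO_FINISH.get_finish_reason() is None
assert FinishStatus.FINISHED_STOP.get_finish_reason() == "stop"
assert FinishStatus.FINISHED_LENGTH.get_finish_reason() == "length"
assert FinishStatus.FINISHED_ABORT.get_finish_reason() == "abort"
assert not FinishStatus.NO_FINISH.is_finished()
assert FinishStatus.FINISHED_ABORT.is_aborted()
```
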
11 changes: 5 additions & 6 deletions lightllm/server/router/manager.py
@@ -13,7 +13,7 @@
from .req_queue import ReqQueue
from rpyc.utils.classic import obtain
from lightllm.utils.infer_utils import calculate_time
from ..io_struct import BatchTokenIdOut, AbortReq, ReqRunStatus
from ..io_struct import BatchTokenIdOut, AbortReq, ReqRunStatus, FinishStatus
from .stats import Stats
from .pause_strategy import Fcfs, select_paused_reqs
from ..tokenizer import get_tokenizer
@@ -133,12 +133,10 @@ async def abort(self, request_id):
if self.running_batch is not None:
for req in self.running_batch.reqs:
if req.request_id == request_id:
req.has_generate_finished = True
req.aborted = True
req.finish_status = FinishStatus.FINISHED_ABORT
for req in self.req_queue.waiting_req_list:
if req.request_id == request_id:
req.has_generate_finished = True
req.aborted = True
req.finish_status = FinishStatus.FINISHED_ABORT
return

async def loop_for_fwd(self,):
@@ -343,7 +341,8 @@ def _send_to_detokenization_proc(self, batch: Batch, req_ans):
for req_id, (_, _, new_token_id, new_gen_metadata) in req_ans.items():
req = batch.id_to_reqs[req_id]
if new_token_id is not None:
batch_out.reqs_infs.append((req_id, new_token_id, new_gen_metadata, req.has_generate_finished, req.aborted))
# transmit the value of req.finish_status rather than the object itself to reduce the size of the serialized payload.
batch_out.reqs_infs.append((req_id, new_token_id, new_gen_metadata, req.finish_status.value))

self.send_to_detokenization.send_pyobj(batch_out)
return
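
The comment in _send_to_detokenization_proc notes that shipping the enum's value keeps the serialized message small. A rough way to see the effect (assuming lightllm is importable; exact byte counts vary with the pickle protocol):

```
import pickle
from lightllm.server.io_struct import FinishStatus

# Pickling the bare int is smaller than pickling the enum member, which must
# also reference its module and class name; hence req.finish_status.value is sent.
as_value = len(pickle.dumps(FinishStatus.FINISHED_ABORT.value))
as_member = len(pickle.dumps(FinishStatus.FINISHED_ABORT))
print(as_value, as_member)
assert as_value < as_member
```
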
4 changes: 2 additions & 2 deletions lightllm/server/router/req_queue.py
@@ -5,7 +5,7 @@
from ..io_struct import Batch, Req
from lightllm.utils.infer_utils import calculate_time
from lightllm.server.io_struct import Req
from lightllm.server.io_struct import ReqRunStatus
from lightllm.server.io_struct import ReqRunStatus, FinishStatus

class ReqQueue:

@@ -103,7 +103,7 @@ def generate_new_batch(self, current_batch:Batch):
new_batch_first_router_need_tokens = 0 # mainly a limit on large-chunk computation during prefill or splitfuse
aborted_count = 0
for req in self.waiting_req_list:
if req.aborted and req.req_status == ReqRunStatus.WAIT_IN_QUEUE:
if req.finish_status.is_aborted() and req.req_status == ReqRunStatus.WAIT_IN_QUEUE:
# Due to the complexity of the bookkeeping, only requests that have never been scheduled to run can be dropped from the queue directly on abort.
# Paused requests must be filtered by the router manager after they are resumed. Keep this approach for now, otherwise managed tokens would leak.
aborted_count += 1
