Merged
4 changes: 2 additions & 2 deletions lmdeploy/messages.py
@@ -530,8 +530,8 @@ class EngineOutput:

Args:
status (ResponseType): the response type.
token_ids (List[int]): the output token ids.
num_token (int): the number of output tokens, which is equal to `len(token_ids)`
token_ids (List[int]): the newly generated token ids in each iteration.
num_token (int): the newly generated token number, equal to `len(token_ids)`
logprobs (List[Dict[int, float]]): the top logprobs for each output
position.
cache_block_ids (List[int]): send cache blocks back for migration in
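The docstring update reflects the behavioral change made throughout this PR: `EngineOutput.token_ids` now carries only the tokens produced since the previous iteration, so consumers accumulate deltas instead of receiving the full history each time. A minimal consumer sketch under the new semantics (`generate` is a hypothetical async generator standing in for any engine instance's streaming API):

```python
# Minimal sketch of consuming delta outputs; `generate` is a hypothetical
# async generator that yields EngineOutput objects.
from lmdeploy.messages import EngineOutput


async def collect(generate):
    all_token_ids = []                      # cumulative tokens for the request
    async for out in generate():            # each `out` is an EngineOutput
        # token_ids holds only this iteration's new tokens, and
        # num_token == len(out.token_ids) describes the same delta.
        all_token_ids.extend(out.token_ids)
    return all_token_ids
```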
5 changes: 2 additions & 3 deletions lmdeploy/metrics/stats.py
@@ -198,7 +198,7 @@ def update_from_output(self, outputs: EngineOutput, req_state: RequestState):
outputs (EngineOutput): The output from the engine containing information about the current iteration.
req_state (RequestState): The state of the request, including timestamps and token counts.
"""
new_generation_tokens = outputs.num_token - req_state.generation_tokens
new_generation_tokens = outputs.num_token
if new_generation_tokens == 0:
return
self.new_generation_tokens = new_generation_tokens
@@ -213,9 +213,8 @@ def update_from_output(self, outputs: EngineOutput, req_state: RequestState):
# update the latest token generation time
req_state.lastest_token_time = outputs.req_metrics.token_timestamp
# update the number of generated tokens
req_state.generation_tokens = outputs.num_token
req_state.generation_tokens += outputs.num_token

if outputs.status != ResponseType.SUCCESS:
req_state.finish_reason = outputs.status
req_state.finish_time = self.iteration_timestamp
req_state.generation_tokens = outputs.num_token
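Since `outputs.num_token` is now already a per-iteration delta, the stats path no longer subtracts the previous running total; it simply accumulates with `+=`. A rough sketch of the new bookkeeping with made-up iteration counts:

```python
# Made-up iteration counts: the engine reports 2, 3 and 1 new tokens.
deltas = [2, 3, 1]            # outputs.num_token per iteration (delta semantics)

generation_tokens = 0         # plays the role of req_state.generation_tokens
for num_token in deltas:
    new_generation_tokens = num_token    # no subtraction of the previous total
    generation_tokens += num_token       # running total is accumulated instead
assert generation_tokens == 6
```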
12 changes: 7 additions & 5 deletions lmdeploy/pytorch/engine/engine_instance.py
@@ -144,31 +144,33 @@ async def async_stream_infer(self,
)
logger.debug(f'session[{session_id}] add message: num_input_ids={len(input_ids)}.')
resp = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg)
output_offset = 0

while True:
resp = await self.req_sender.async_recv(resp)

cache_block_ids = resp.data.get('cache_block_ids', None) if resp.data else None
req_metrics = resp.data.get('req_metrics', None) if resp.data else None
logprobs = resp.data.get('logprobs', None) if resp.data else None
logprobs = resp.data.pop('logprobs', None) if resp.data else None
if resp.type == ResponseType.SUCCESS:
token_ids = resp.data['token_ids'].tolist()
num_ids = len(token_ids)
num_ids = len(token_ids) - output_offset
logger.debug(f'session[{session_id}] success: num_out_ids={num_ids}.')
yield EngineOutput(resp.type,
token_ids,
token_ids[output_offset:],
num_ids,
cache_block_ids=cache_block_ids,
req_metrics=req_metrics,
logprobs=logprobs)
output_offset = len(token_ids)
elif resp.type == ResponseType.FINISH:
resp_data = resp.data
token_ids = resp_data['token_ids'].tolist()
logits = resp_data['logits']
num_ids = len(token_ids)
num_ids = len(token_ids) - output_offset
logger.debug(f'session[{session_id}] finish: num_out_ids={num_ids}.')
yield EngineOutput(resp.type,
token_ids,
token_ids[output_offset:],
num_ids,
logits=logits,
cache_block_ids=cache_block_ids,
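`async_stream_infer` still receives the cumulative `token_ids` from each response, so it tracks an `output_offset` and yields only the unseen tail. A simplified sketch of that bookkeeping with fabricated responses:

```python
# Simplified sketch of the output_offset bookkeeping; `responses` stands in for
# the cumulative token id lists carried by successive engine responses.
responses = [[1, 2], [1, 2, 3, 4], [1, 2, 3, 4, 5]]

output_offset = 0
for token_ids in responses:
    num_ids = len(token_ids) - output_offset   # tokens generated this iteration
    delta = token_ids[output_offset:]          # what EngineOutput now carries
    output_offset = len(token_ids)
    print(num_ids, delta)                      # 2 [1, 2] / 2 [3, 4] / 1 [5]
```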
29 changes: 29 additions & 0 deletions lmdeploy/pytorch/engine/mp_engine/base_worker.py
@@ -3,6 +3,7 @@
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, List, Optional

from lmdeploy.messages import EngineOutput
from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest,
DistServeInitRequest)
from lmdeploy.utils import get_logger
@@ -127,3 +128,31 @@ async def instance_async_stream_infer(self, *args, **kwargs):
"""Send stream inference request."""
async for result in self.instance_pool.async_stream_infer(*args, **kwargs):
yield result


class EngineOutputGather:
"""Helper class to gather incremental engine output."""

def __init__(self):
self._output = dict()

def get(self, stream_id):
if stream_id not in self._output:
self._output[stream_id] = EngineOutput(status=None, token_ids=[], num_token=0, logprobs=[])
return self._output[stream_id]

def add(self, stream_id, result):
if not isinstance(result, EngineOutput):
return
output = self.get(stream_id)
output.token_ids.extend(result.token_ids or [])
output.logprobs.extend(result.logprobs or [])

def pop(self, stream_id, result):
if not isinstance(result, EngineOutput):
return result
output = self._output.pop(stream_id)
result.token_ids = output.token_ids or []
result.logprobs = output.logprobs or None
result.num_token = len(output.token_ids)
return result
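`EngineOutputGather` exists because the MP-engine stream proxies keep only the most recent result per stream (see the ray_engine.py and zmq_rpc.py hunks below); with delta outputs, any iteration the reader misses would silently drop tokens. The producer registers every delta with `add`, and `pop` rewrites the last observed result into everything gathered since the previous read. A small usage sketch with made-up token ids:

```python
# Usage sketch for EngineOutputGather; stream id and token ids are made up.
from lmdeploy.messages import EngineOutput
from lmdeploy.pytorch.engine.mp_engine.base_worker import EngineOutputGather

gather = EngineOutputGather()
stream_id = 0

# Producer side: every delta is registered, even ones the reader never sees.
gather.add(stream_id, EngineOutput(status=None, token_ids=[1, 2], num_token=2))
gather.add(stream_id, EngineOutput(status=None, token_ids=[3], num_token=1))

# Reader side: it only observed the latest delta, but pop folds the missed
# tokens back in and fixes num_token accordingly.
latest = EngineOutput(status=None, token_ids=[3], num_token=1)
merged = gather.pop(stream_id, latest)
print(merged.token_ids, merged.num_token)   # [1, 2, 3] 3
```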
6 changes: 5 additions & 1 deletion lmdeploy/pytorch/engine/mp_engine/ray_engine.py
@@ -11,7 +11,7 @@
from lmdeploy.utils import get_logger

from .base import MPEngine
from .base_worker import EngineWorkerBase
from .base_worker import EngineOutputGather, EngineWorkerBase

logger = get_logger('lmdeploy')

@@ -35,12 +35,14 @@ def __init__(self,
self._stream_id = 0
self._stream_aiter = dict()
self._stream_task = dict()
self._engine_output_gather = EngineOutputGather()

async def _stream_task_wrapper(self, stream_id: int, func: str, *args, **kwargs):
"""Create a stream task."""
method = getattr(self, func)
event = self._stream_aiter[stream_id][0]
async for result in method(*args, **kwargs):
self._engine_output_gather.add(stream_id, result)
self._stream_aiter[stream_id][1] = (result, False)
event.set()
self._stream_aiter[stream_id][1] = (result, True)
@@ -67,6 +69,8 @@ async def get_stream_task_result(self, stream_id: int):
result, stopped = self._stream_aiter[stream_id][1]
event.clear()

result = self._engine_output_gather.pop(stream_id, result)

if stopped:
self._stream_aiter.pop(stream_id, None)
self._stream_task.pop(stream_id, None)
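Note that `pop` removes the per-stream accumulator, so each call to `get_stream_task_result` returns only the deltas gathered since the previous call; the stream keeps its delta semantics end to end even when intermediate results are overwritten. A sketch of that reset-between-reads behavior:

```python
# Sketch of how accumulation resets between reads of the same stream.
from lmdeploy.messages import EngineOutput
from lmdeploy.pytorch.engine.mp_engine.base_worker import EngineOutputGather

gather = EngineOutputGather()

gather.add(0, EngineOutput(status=None, token_ids=[1], num_token=1))
gather.add(0, EngineOutput(status=None, token_ids=[2], num_token=1))
first = gather.pop(0, EngineOutput(status=None, token_ids=[2], num_token=1))
print(first.token_ids)    # [1, 2] - both deltas gathered before the first read

gather.add(0, EngineOutput(status=None, token_ids=[3], num_token=1))
second = gather.pop(0, EngineOutput(status=None, token_ids=[3], num_token=1))
print(second.token_ids)   # [3] - only what arrived after the previous read
```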
5 changes: 5 additions & 0 deletions lmdeploy/pytorch/engine/mp_engine/zmq_rpc.py
@@ -11,6 +11,8 @@

from lmdeploy.utils import get_logger

from .base_worker import EngineOutputGather

logger = get_logger('lmdeploy')


@@ -44,6 +46,7 @@ def __init__(self):
# streaming
self.stream_output = dict()
self._stream_idx = 0
self._engine_output_gather = EngineOutputGather()

def get_port(self):
return self.port
@@ -98,6 +101,7 @@ async def _method_async_streaming_task(self, stream_id, method: Callable, args:
try:
generator = method(*args, **kwargs)
async for result in generator:
self._engine_output_gather.add(stream_id, result)
stream_out['result'] = result
stream_out['event'].set()
except Exception as e:
@@ -116,6 +120,7 @@ async def get_stream_output(self, stream_id: int):
event.clear()
result = stream_out['result']
stopped = stream_out['stopped']
result = self._engine_output_gather.pop(stream_id, result)
if stopped:
self.stream_output.pop(stream_id)
if 'error' in stream_out:
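The same gather is applied on the ZMQ RPC path. Because `add` and `pop` ignore anything that is not an `EngineOutput`, other streamed payloads pass through untouched; a tiny sketch:

```python
# Tiny sketch of the pass-through path for results that are not EngineOutput.
from lmdeploy.pytorch.engine.mp_engine.base_worker import EngineOutputGather

gather = EngineOutputGather()
payload = {'error': 'something went wrong'}   # fabricated non-EngineOutput result
gather.add(0, payload)                        # ignored, nothing is tracked
assert gather.pop(0, payload) is payload      # returned unchanged
```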
29 changes: 6 additions & 23 deletions lmdeploy/serve/async_engine.py
@@ -834,7 +834,6 @@ def is_error(status):
input_len = len(input_ids)
output_len, gen_len = 0, 0
state = DetokenizeState(len(input_ids))
start_ids_offset = state.ids_offset
response = ''
finish_reason = None
async with self.safe_run(inst,
@@ -846,7 +845,6 @@ def is_error(status):
sequence_start=sequence_start,
sequence_end=sequence_end,
step=history_len) as gen:
prev_len = 0
hit_stop_token = 0
req_state = RequestState(prompt_tokens=input_len) # per-requst state
async for outputs in gen:
@@ -857,23 +855,16 @@ def is_error(status):
break

output_len = outputs.num_token

if hit_stop_token or prev_len == output_len:
if hit_stop_token:
continue

# This assumes the engine will stop when stop token is hit
if output_len and outputs.token_ids[-1] in stop_ids:
hit_stop_token = 1
# one token and it's been skipped
if output_len == prev_len + 1:
continue

mask = slice(prev_len - output_len, output_len - hit_stop_token)
token_ids += outputs.token_ids[mask]
token_ids += outputs.token_ids
gen_len = len(token_ids) - input_len

prev_len = output_len

ids_offset = state.ids_offset
response, state = self.tokenizer.detokenize_incrementally(
token_ids,
@@ -889,21 +880,13 @@ def is_error(status):
finish_reason,
token_ids=res,
cache_block_ids=outputs.cache_block_ids)

if outputs.logprobs is not None:
log_offset = ids_offset - start_ids_offset
out.logprobs = outputs.logprobs[log_offset:]
if hit_stop_token:
out.logprobs = out.logprobs[:-hit_stop_token]
out.logprobs = (outputs.logprobs[:-hit_stop_token] if hit_stop_token else outputs.logprobs)
if outputs.last_hidden_state is not None:
out.last_hidden_state = outputs.last_hidden_state
if hit_stop_token:
out.last_hidden_state = out.last_hidden_state[:-hit_stop_token]
out.last_hidden_state = (outputs.last_hidden_state[:-hit_stop_token]
if hit_stop_token else outputs.last_hidden_state)
if outputs.logits is not None:
out.logits = outputs.logits
if hit_stop_token:
out.logits = out.logits[:-hit_stop_token]

out.logits = (outputs.logits[:-hit_stop_token] if hit_stop_token else outputs.logits)
yield out
# end of generator loop
metrics_processor.increment_finished_requests()
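With every `EngineOutput` now a delta, the serving loop drops the `prev_len` bookkeeping and the negative-index `slice`: it extends `token_ids` with each delta and, once a stop token appears, trims that token's entry off the per-iteration logprobs, hidden states and logits. A condensed sketch of the new stop-token handling with made-up values:

```python
# Condensed sketch of the per-delta stop-token handling; stop_ids, the token
# deltas and their logprobs are made up for illustration.
stop_ids = {2}
token_ids, hit_stop_token = [], 0

iterations = [
    ([5, 7], [{5: -0.1}, {7: -0.3}]),     # (outputs.token_ids, outputs.logprobs)
    ([9, 2], [{9: -0.2}, {2: -0.05}]),    # ends with the stop token 2
]

for ids, logprobs in iterations:
    if hit_stop_token:
        continue                          # engine already stopped; skip leftovers
    # assumes the engine stops right after emitting a stop token
    if ids and ids[-1] in stop_ids:
        hit_stop_token = 1
    token_ids += ids
    out_logprobs = logprobs[:-hit_stop_token] if hit_stop_token else logprobs

print(token_ids)       # [5, 7, 9, 2]  - all delta tokens accumulated
print(out_logprobs)    # [{9: -0.2}]   - the stop token's logprob is trimmed
```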
50 changes: 29 additions & 21 deletions lmdeploy/turbomind/turbomind.py
@@ -419,17 +419,24 @@ def _func(out: EngineOutput, step: int, **kwargs):
return _func


def _get_logprobs_impl(logprob_vals: torch.Tensor,
logprob_idxs: torch.Tensor,
logprob_nums: torch.Tensor,
output_ids: List[int],
logprobs: int,
out_logprobs: List[Dict[int, float]] = None):
length = len(output_ids)
offset = len(out_logprobs)
if length == offset:
return out_logprobs
for (pos, idx, val, n) in zip(range(offset, length), logprob_idxs[offset:length], logprob_vals[offset:length],
def _get_logprobs_impl(logprob_vals: torch.Tensor, logprob_idxs: torch.Tensor, logprob_nums: torch.Tensor,
output_ids: List[int], logprobs: int, offset: int):
"""Get logprob of each generated token.

Args:
logprob_vals (torch.Tensor): shape (max_new_tokens, 1024),
1024 is the max_logprobs that turbomind engine can output
logprob_idxs (torch.Tensor): shape (max_new_tokens, 1024)
logprob_nums (torch.Tensor): shape (max_new_tokens,)
output_ids (List[int]): newly generated token ids
logprobs (int): top n logprobs to return
offset (int): offset to index logprob_vals, logprob_idxs and logprob_nums.
It indicates where to start getting logprobs for the current generated tokens `output_ids`
"""
out_logprobs = []
# the total generated token number until now
length = len(output_ids) + offset
for (pos, idx, val, n) in zip(range(len(output_ids)), logprob_idxs[offset:length], logprob_vals[offset:length],
logprob_nums[offset:length]):
topn = min(n.item(), logprobs)
tok_res = {idx[i].item(): val[i].item() for i in range(topn)}
@@ -447,15 +454,16 @@ def _get_logprobs_impl(logprob_vals: torch.Tensor,


def _get_logprobs(outputs, output_logprobs: int):
logprob_vals = outputs['logprob_vals']
logprob_idxs = outputs['logprob_indexes']
logprob_nums = outputs['logprob_nums']

logprobs = []
logprob_vals = outputs['logprob_vals'] # shape {max_new_tokens, 1024}
logprob_idxs = outputs['logprob_indexes'] # shape {max_new_tokens, 1024}
logprob_nums = outputs['logprob_nums'] # shape {max_new_tokens,}
offset = 0 # offset to index logprob_vals, logprob_idxs and logprob_nums

def _func(out: EngineOutput, step: int, **kwargs):
_get_logprobs_impl(logprob_vals, logprob_idxs, logprob_nums, out.token_ids, output_logprobs, logprobs)
out.logprobs = logprobs
nonlocal offset
out.logprobs = _get_logprobs_impl(logprob_vals, logprob_idxs, logprob_nums, out.token_ids, output_logprobs,
offset)
offset += len(out.token_ids)

return _func
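`_get_logprobs_impl` now returns a fresh list covering only the current delta, indexing the shared `logprob_*` buffers with a running `offset` that `_get_logprobs` advances by `len(out.token_ids)` after each call. A toy sketch of the offset arithmetic with tiny stand-in tensors:

```python
# Toy sketch of the per-delta logprob slicing; the tensors are tiny stand-ins
# for the real (max_new_tokens, 1024) turbomind buffers, with logprobs=1 here.
import torch

logprob_vals = torch.tensor([[-0.1], [-0.2], [-0.3]], dtype=torch.float64)
logprob_idxs = torch.tensor([[11], [12], [13]])
logprob_nums = torch.tensor([1, 1, 1])

offset = 0
for delta_ids in ([11, 12], [13]):                   # out.token_ids per iteration
    length = len(delta_ids) + offset                 # total generated so far
    out_logprobs = []
    for idx, val, n in zip(logprob_idxs[offset:length], logprob_vals[offset:length],
                           logprob_nums[offset:length]):
        topn = min(n.item(), 1)
        out_logprobs.append({idx[i].item(): val[i].item() for i in range(topn)})
    offset += len(delta_ids)
    print(out_logprobs)          # [{11: -0.1}, {12: -0.2}] then [{13: -0.3}]
```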

@@ -467,7 +475,7 @@ def _get_metrics(metrics):

is_first = True

def _func(out: EngineOutput, step: int, is_first_token: bool = False, **kwargs):
def _func(out: EngineOutput, step: int, **kwargs):
nonlocal is_first
if not is_first:
out.req_metrics = RequestMetrics(token_timestamp=time.time())
@@ -773,8 +781,8 @@ async def async_stream_infer(self,
if seq_len == prev_len and not finish:
continue

output_ids += output_ids_buf[prev_len:seq_len].tolist()
output_len += seq_len - prev_len
output_ids = output_ids_buf[prev_len:seq_len].tolist()
output_len = seq_len - prev_len
output = EngineOutput(ret_status, output_ids, output_len)

for f in extra_fs:
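At the turbomind source, the same convention now applies: each step slices the shared output buffer between the previous and current sequence lengths instead of re-emitting the whole history. A minimal sketch with a fabricated buffer and step lengths:

```python
# Minimal sketch of the per-step slicing; output_ids_buf and seq_lens are
# fabricated stand-ins for the engine's real output buffer and step lengths.
import torch

output_ids_buf = torch.tensor([101, 102, 103, 104, 105])
seq_lens = [2, 4, 5]                  # cumulative generated length per step

prev_len = 0
for seq_len in seq_lens:
    output_ids = output_ids_buf[prev_len:seq_len].tolist()   # this step's delta
    output_len = seq_len - prev_len
    prev_len = seq_len
    print(output_ids, output_len)     # [101, 102] 2 / [103, 104] 2 / [105] 1
```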