Commit e8a55a2

alexm-redhat authored and sumitd2 committed
Add output streaming support to multi-step + async while ensuring RequestOutput obj reuse (vllm-project#8335)
Signed-off-by: Sumit Dubey <[email protected]>
1 parent 9b975db commit e8a55a2

7 files changed: +142 −42 lines


tests/entrypoints/openai/test_accuracy.py

+5-1
@@ -19,7 +19,11 @@
 RTOL = 0.03
 EXPECTED_VALUE = 0.58
 DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
-MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]]
+MORE_ARGS_LIST = [
+    ["--enable-chunked-prefill"],  # Chunked
+    ["--num-scheduler-steps", "8"],  # MS
+    ["--num-scheduler-steps", "8", "--multi-step-stream-outputs"]  # MS+Stream
+]
 
 
 @pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
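The third MORE_ARGS_LIST entry adds coverage for multi-step scheduling with per-step streaming enabled. As a rough, hypothetical illustration of how the parametrization fans out (the test body below is a stand-in, not the real GSM8K accuracy check):

import pytest

MORE_ARGS_LIST = [
    ["--enable-chunked-prefill"],                                   # Chunked
    ["--num-scheduler-steps", "8"],                                 # MS
    ["--num-scheduler-steps", "8", "--multi-step-stream-outputs"],  # MS+Stream
]

@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
def test_server_config(more_args):
    # Stand-in body: the real test starts a server with DEFAULT_ARGS + more_args
    # and checks the measured accuracy against EXPECTED_VALUE within RTOL.
    assert isinstance(more_args, list)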

vllm/config.py

+2
@@ -960,6 +960,7 @@ def __init__(self,
                  is_multimodal_model: bool = False,
                  preemption_mode: Optional[str] = None,
                  num_scheduler_steps: int = 1,
+                 multi_step_stream_outputs: bool = False,
                  send_delta_data: bool = False) -> None:
         if max_num_batched_tokens is None:
             if enable_chunked_prefill:
@@ -1000,6 +1001,7 @@ def __init__(self,
         self.embedding_mode = embedding_mode
         self.preemption_mode = preemption_mode
         self.num_scheduler_steps = num_scheduler_steps
+        self.multi_step_stream_outputs = multi_step_stream_outputs
         self.send_delta_data = send_delta_data
         self._verify_args()
 
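SchedulerConfig simply records the flag so that the engine (and each SchedulerContext) can consult it later. A minimal sketch of that consumer side, assuming an already-built config object; wants_per_step_streaming is a hypothetical helper, not part of the commit:

from types import SimpleNamespace

def wants_per_step_streaming(scheduler_config) -> bool:
    # Hypothetical helper: per-step streaming is only meaningful when
    # multi-step scheduling is active (num_scheduler_steps > 1).
    return (scheduler_config.num_scheduler_steps > 1
            and scheduler_config.multi_step_stream_outputs)

cfg = SimpleNamespace(num_scheduler_steps=8, multi_step_stream_outputs=True)
assert wants_per_step_streaming(cfg)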

vllm/engine/arg_utils.py

+6
@@ -145,6 +145,7 @@ class EngineArgs:
     max_cpu_loras: Optional[int] = None
     device: str = 'auto'
     num_scheduler_steps: int = 1
+    multi_step_stream_outputs: bool = False
     ray_workers_use_nsight: bool = False
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = 0
@@ -595,6 +596,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             help=('Maximum number of forward steps per '
                   'scheduler call.'))
 
+        parser.add_argument(
+            '--multi-step-stream-outputs',
+            action='store_true',
+            help='If True, then multi-step will stream outputs for every step')
         parser.add_argument(
             '--scheduler-delay-factor',
             type=float,
@@ -999,6 +1004,7 @@ def create_engine_config(self) -> EngineConfig:
             is_multimodal_model=model_config.is_multimodal_model,
             preemption_mode=self.preemption_mode,
             num_scheduler_steps=self.num_scheduler_steps,
+            multi_step_stream_outputs=self.multi_step_stream_outputs,
             send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                              and parallel_config.use_ray),
         )
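Since the new option is a store_true argument, it defaults to False and only becomes True when the flag is passed. A self-contained argparse sketch of the same behavior (plain argparse here, not vLLM's FlexibleArgumentParser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--multi-step-stream-outputs',
    action='store_true',
    help='If True, then multi-step will stream outputs for every step')

# Flag present -> True; flag absent -> False.
assert parser.parse_args(['--multi-step-stream-outputs']).multi_step_stream_outputs is True
assert parser.parse_args([]).multi_step_stream_outputs is False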

vllm/engine/llm_engine.py

+28-9
@@ -95,14 +95,16 @@ class OutputData(NamedTuple):
 
 class SchedulerContext:
 
-    def __init__(self):
+    def __init__(self, multi_step_stream_outputs: bool = False):
         self.output_queue: Deque[OutputData] = deque()
         self.request_outputs: List[Union[RequestOutput,
                                          EmbeddingRequestOutput]] = []
         self.seq_group_metadata_list: Optional[
             List[SequenceGroupMetadata]] = None
         self.scheduler_outputs: Optional[SchedulerOutputs] = None
 
+        self.multi_step_stream_outputs: bool = multi_step_stream_outputs
+
     def append_output(self, outputs: List[SamplerOutput],
                       seq_group_metadata_list: List[SequenceGroupMetadata],
                       scheduler_outputs: SchedulerOutputs, is_async: bool,
@@ -219,6 +221,7 @@ def __init__(
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
         input_registry: InputRegistry = INPUT_REGISTRY,
+        use_cached_outputs: bool = False,
     ) -> None:
         logger.info(
             "Initializing an LLM engine (v%s) with config: "
@@ -234,8 +237,9 @@ def __init__(
             "quantization_param_path=%s, device_config=%s, "
             "decoding_config=%r, observability_config=%r, "
             "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
-            "num_scheduler_steps=%d, enable_prefix_caching=%s, "
-            "use_async_output_proc=%s, mm_processor_kwargs=%s)",
+            "num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
+            "enable_prefix_caching=%s, use_async_output_proc=%s, "
+            "use_cached_outputs=%s, mm_processor_kwargs=%s)",
             VLLM_VERSION,
             model_config.model,
             speculative_config,
@@ -266,8 +270,10 @@ def __init__(
             model_config.served_model_name,
             scheduler_config.use_v2_block_manager,
             scheduler_config.num_scheduler_steps,
+            scheduler_config.multi_step_stream_outputs,
             cache_config.enable_prefix_caching,
             model_config.use_async_output_proc,
+            use_cached_outputs,
             model_config.mm_processor_kwargs,
         )
         # TODO(woosuk): Print more configs in debug mode.
@@ -287,6 +293,7 @@ def __init__(
         self.observability_config = observability_config or ObservabilityConfig(
         )
         self.log_stats = log_stats
+        self.use_cached_outputs = use_cached_outputs
 
         if not self.model_config.skip_tokenizer_init:
             self.tokenizer = self._init_tokenizer()
@@ -379,7 +386,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
         ]
 
         self.scheduler_contexts = [
-            SchedulerContext()
+            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
+                             multi_step_stream_outputs)
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]
 
@@ -998,7 +1006,8 @@ def _process_model_outputs(self,
 
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
-            request_output = RequestOutputFactory.create(seq_group)
+            request_output = RequestOutputFactory.create(
+                seq_group, use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)
 
@@ -1019,8 +1028,8 @@ def _process_model_outputs(self,
         for scheduler in self.scheduler:
             scheduler.free_finished_seq_groups()
 
-        # For multi-step, do not create outputs each iteration
-        if not is_last_step:
+        # For multi-step without streaming, don't create outputs each iteration
+        if not is_last_step and not ctx.multi_step_stream_outputs:
             # Immediately process request outputs here (if callback is given)
             if (finished_now
                     and self.process_request_outputs_callback is not None):
@@ -1037,17 +1046,27 @@ def _process_model_outputs(self,
 
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
-            request_output = RequestOutputFactory.create(seq_group)
+            request_output = RequestOutputFactory.create(
+                seq_group, use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)
 
+        # For multi-step with streaming, create outputs each iteration
+        if not is_last_step and ctx.multi_step_stream_outputs:
+            # Immediately process request outputs here (if callback is given)
+            if self.process_request_outputs_callback is not None:
+                self.process_request_outputs_callback(ctx.request_outputs)
+                ctx.request_outputs.clear()
+            return
+
         for seq_group in scheduler_outputs.ignored_seq_groups:
             params = seq_group.sampling_params
             if params is not None and params.output_kind == (
                     RequestOutputKind.DELTA) and not seq_group.is_finished():
                 continue
 
-            request_output = RequestOutputFactory.create(seq_group)
+            request_output = RequestOutputFactory.create(
+                seq_group, use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)

vllm/engine/multiprocessing/engine.py

+8-1
@@ -66,7 +66,14 @@ def __init__(self,
                  *args,
                  log_requests: bool = True,
                  **kwargs) -> None:
-        self.engine = LLMEngine(*args, **kwargs)
+        # For MQLLMEngine, we can use cached outputs, since each new request
+        # output is immediately pickled and send over the socket, which frees
+        # the python object to be reused again.
+        use_cached_outputs = True
+
+        self.engine = LLMEngine(*args,
+                                **kwargs,
+                                use_cached_outputs=use_cached_outputs)
         self.log_requests = log_requests
 
         self.use_async_sockets = use_async_sockets
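The comment's argument is that pickling the RequestOutput for the socket snapshots its state, so the engine can safely reuse (overwrite) the Python object on the next step. A tiny standalone illustration of that decoupling, using a generic class rather than vLLM types:

import pickle

class Box:
    def __init__(self, value):
        self.value = value

box = Box("step-1 output")
payload = pickle.dumps(box)   # snapshot of the object, as sent over the socket

box.value = "step-2 output"   # the engine reuses the same object afterwards

# The receiver still sees the original snapshot, unaffected by the reuse.
assert pickle.loads(payload).value == "step-1 output"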

vllm/outputs.py

+74-22
@@ -114,17 +114,28 @@ def __init__(
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
 
     @classmethod
-    def from_seq_group(cls,
-                       seq_group: SequenceGroup) -> Optional["RequestOutput"]:
+    def from_seq_group(cls, seq_group: SequenceGroup,
+                       use_cache: bool) -> Optional["RequestOutput"]:
         sampling_params = seq_group.sampling_params
         if sampling_params is None:
             raise ValueError(
                 "Sampling parameters are missing for a CompletionRequest.")
+
         finished = seq_group.is_finished()
         if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
                 not finished):
             return None
 
+        # Init cache (if needed)
+        if use_cache and seq_group.cached_request_output is None:
+            seq_group.cached_request_output = RequestOutput(  # type: ignore
+                request_id="",
+                prompt=None,
+                prompt_token_ids=[],
+                prompt_logprobs=None,
+                outputs=[],
+                finished=False)
+
         seqs = seq_group.get_seqs()
         if len(seqs) == 1:
             top_n_seqs = seqs
@@ -149,29 +160,66 @@ def from_seq_group(cls,
 
         outputs = []
         include_prompt = True
-        for seq in top_n_seqs:
+        for i, seq in enumerate(top_n_seqs):
             output_text = seq.get_output_text_to_return(
                 text_buffer_length, delta)
+
             output_token_ids = seq.get_output_token_ids_to_return(delta)
+            num_output_tokens = 1 if isinstance(output_token_ids,
                                                int) else len(output_token_ids)
+
             output_logprobs = seq.output_logprobs if include_logprobs else None
 
             if delta:
                 # Slice logprobs delta if applicable
                 if output_logprobs:
-                    output_logprobs = output_logprobs[-len(output_token_ids):]
+                    output_logprobs = output_logprobs[-num_output_tokens:]
                 # Don't include prompt if this is after the first output
                 # containing decode token ids
-                if include_prompt and seq.get_output_len() > len(
-                        output_token_ids):
+                if include_prompt and seq.get_output_len() > num_output_tokens:
                     include_prompt = False
 
-            outputs.append(
-                CompletionOutput(
-                    seqs.index(seq), output_text, output_token_ids,
+            if use_cache:
+                # Get cached output object
+                cached_outputs = seq_group.cached_request_output.outputs  # type: ignore
+                if i >= len(cached_outputs):
+                    cached_outputs.append(
+                        CompletionOutput(index=i,
+                                         text="",
+                                         token_ids=[],
+                                         cumulative_logprob=None,
+                                         logprobs=None,
+                                         finish_reason=None,
+                                         stop_reason=None))
+                output = cached_outputs[i]
+
+                # Init cached output object
+                assert output.index == i
+                output.text = output_text
+
+                if isinstance(output_token_ids, int):
+                    output.token_ids.clear()
+                    output.token_ids.append(output_token_ids)
+                else:
+                    output.token_ids = output_token_ids
+
+                output.cumulative_logprob = seq.get_cumulative_logprob() \
+                    if include_logprobs else None
+                output.logprobs = output_logprobs
+                output.finish_reason = SequenceStatus.get_finished_reason(
+                    seq.status)
+                output.stop_reason = seq.stop_reason
+
+            else:
+                output = CompletionOutput(
+                    seqs.index(seq), output_text, [output_token_ids]
+                    if isinstance(output_token_ids, int) else output_token_ids,
                     seq.get_cumulative_logprob() if include_logprobs else None,
                     output_logprobs,
                     SequenceStatus.get_finished_reason(seq.status),
-                    seq.stop_reason))
+                    seq.stop_reason)
+
+            outputs.append(output)
 
         # Every sequence in the sequence group should have the same prompt.
         if include_prompt:
@@ -188,16 +236,20 @@ def from_seq_group(cls,
             prompt_logprobs = None
         finished_time = time.time() if finished else None
         seq_group.set_finished_time(finished_time)
-        return cls(seq_group.request_id,
-                   prompt,
-                   prompt_token_ids,
-                   prompt_logprobs,
-                   outputs,
-                   finished,
-                   seq_group.metrics,
-                   lora_request=seq_group.lora_request,
-                   encoder_prompt=encoder_prompt,
-                   encoder_prompt_token_ids=encoder_prompt_token_ids)
+
+        init_args = (seq_group.request_id, prompt, prompt_token_ids,
+                     prompt_logprobs, outputs, finished, seq_group.metrics,
+                     seq_group.lora_request, encoder_prompt,
+                     encoder_prompt_token_ids)
+
+        if use_cache:
+            request_output = seq_group.cached_request_output
+            request_output.__init__(*init_args)  # type: ignore
+
+        else:
+            request_output = cls(*init_args)
+
+        return request_output
 
     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
@@ -261,10 +313,10 @@ def __repr__(self):
 class RequestOutputFactory:
 
     @staticmethod
-    def create(seq_group):
+    def create(seq_group: SequenceGroup, use_cache: bool = False):
         # Determine the type based on a condition, for example:
         if hasattr(seq_group,
                    'embeddings') and seq_group.embeddings is not None:
             return EmbeddingRequestOutput.from_seq_group(seq_group)
         else:
-            return RequestOutput.from_seq_group(seq_group)
+            return RequestOutput.from_seq_group(seq_group, use_cache)
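The core of the reuse scheme is keeping one RequestOutput (and its CompletionOutput entries) per sequence group and re-running __init__ on it each step instead of allocating new objects. A stripped-down sketch of that pattern with a generic class, not the real RequestOutput:

class Output:
    def __init__(self, request_id: str = "", text: str = ""):
        self.request_id = request_id
        self.text = text

_cache = {}  # request_id -> reused Output, analogous to cached_request_output

def make_output(request_id: str, text: str, use_cache: bool) -> Output:
    if not use_cache:
        return Output(request_id, text)
    if request_id not in _cache:
        _cache[request_id] = Output()
    out = _cache[request_id]
    out.__init__(request_id, text)  # re-initialize in place, no new allocation
    return out

first = make_output("req-0", "hello", use_cache=True)
second = make_output("req-0", "hello world", use_cache=True)
assert first is second            # the same object is reused across steps
assert second.text == "hello world"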

vllm/sequence.py

+19-9
@@ -436,7 +436,7 @@ def __init__(
         self.stop_reason: Union[int, str, None] = None
 
         # These are used to keep track of delta outputs
-        self._last_token_ids_offset: int = 0
+        self._last_output_token_ids_offset: int = 0
         self._last_output_text_offset: int = 0
 
         # Used for incremental detokenization
@@ -499,18 +499,26 @@ def get_output_text_to_return(self, buffer_length: int,
                 return self.output_text[last_offset:length]
         return ""
 
-    def get_output_token_ids_to_return(self,
-                                       delta: bool) -> GenericSequence[int]:
+    def get_output_token_ids_to_return(
+            self, delta: bool) -> Union[GenericSequence[int], int]:
         """If delta is True, only new tokens since the last call to
         this method are returned"""
         if not delta:
             return self.get_output_token_ids()
-        length = self.get_output_len()
-        last_offset = self._last_token_ids_offset
-        if last_offset < length:
-            self._last_token_ids_offset = length
-            return self.data._output_token_ids[last_offset:]
-        return ()
+
+        output_len = self.get_output_len()
+
+        # Get the number of new tokens
+        num_new_tokens = output_len - self._last_output_token_ids_offset
+        self._last_output_token_ids_offset = output_len
+
+        # Return new tokens
+        if num_new_tokens == 1:
+            # Optimization for single decode token case
+            # (which is what we have most of the time)
+            return self.data._cached_all_token_ids[-1]
+
+        return self.data._cached_all_token_ids[-num_new_tokens:]
 
     def hash_of_block(self, logical_idx: int) -> int:
         # TODO This can produce incorrect hash when block size > prompt size
@@ -671,6 +679,8 @@ def __init__(
         self.encoder_seq = encoder_seq
         self.trace_headers = trace_headers
 
+        self.cached_request_output = None
+
     @property
     def prompt(self) -> Optional[str]:
         # All sequences in the group should have the same prompt.
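get_output_token_ids_to_return now returns a bare int when exactly one new token was produced (the common case during decode), avoiding a per-step slice; callers such as RequestOutput.from_seq_group check isinstance(..., int) accordingly. A small standalone sketch of the offset bookkeeping, not the real Sequence class:

from typing import List, Sequence as GenericSequence, Union

class DeltaTracker:
    """Tracks which output tokens have already been returned as deltas."""

    def __init__(self) -> None:
        self.token_ids: List[int] = []
        self._last_offset = 0

    def append(self, token_id: int) -> None:
        self.token_ids.append(token_id)

    def new_tokens(self) -> Union[GenericSequence[int], int]:
        num_new = len(self.token_ids) - self._last_offset
        self._last_offset = len(self.token_ids)
        if num_new == 0:
            return []
        if num_new == 1:
            # Single decode token: return the int itself, no list allocation.
            return self.token_ids[-1]
        return self.token_ids[-num_new:]

tracker = DeltaTracker()
tracker.append(7)
assert tracker.new_tokens() == 7          # int, not [7]
tracker.append(8)
tracker.append(9)
assert tracker.new_tokens() == [8, 9]     # list for multi-token deltas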
