@@ -95,14 +95,16 @@ class OutputData(NamedTuple):
 
 class SchedulerContext:
 
-    def __init__(self):
+    def __init__(self, multi_step_stream_outputs: bool = False):
         self.output_queue: Deque[OutputData] = deque()
         self.request_outputs: List[Union[RequestOutput,
                                          EmbeddingRequestOutput]] = []
         self.seq_group_metadata_list: Optional[
             List[SequenceGroupMetadata]] = None
         self.scheduler_outputs: Optional[SchedulerOutputs] = None
 
+        self.multi_step_stream_outputs: bool = multi_step_stream_outputs
+
     def append_output(self, outputs: List[SamplerOutput],
                       seq_group_metadata_list: List[SequenceGroupMetadata],
                       scheduler_outputs: SchedulerOutputs, is_async: bool,
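The new keyword defaults to `False`, so existing call sites that construct `SchedulerContext()` keep the previous buffered behavior, and only callers that pass the flag opt in to per-step streaming. A minimal illustrative sketch (the import path is assumed to be the file this diff patches; nothing beyond the constructor shown above is relied on):

```python
# Illustrative only: the default preserves prior behavior, the keyword opts in.
# Assumed module path for the class defined in this diff.
from vllm.engine.llm_engine import SchedulerContext

buffered_ctx = SchedulerContext()  # multi_step_stream_outputs == False
streaming_ctx = SchedulerContext(multi_step_stream_outputs=True)
assert not buffered_ctx.multi_step_stream_outputs
assert streaming_ctx.multi_step_stream_outputs
```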
@@ -219,6 +221,7 @@ def __init__(
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
         input_registry: InputRegistry = INPUT_REGISTRY,
+        use_cached_outputs: bool = False,
     ) -> None:
         logger.info(
             "Initializing an LLM engine (v%s) with config: "
@@ -234,8 +237,9 @@ def __init__(
             "quantization_param_path=%s, device_config=%s, "
             "decoding_config=%r, observability_config=%r, "
             "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
-            "num_scheduler_steps=%d, enable_prefix_caching=%s, "
-            "use_async_output_proc=%s, mm_processor_kwargs=%s)",
+            "num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
+            "enable_prefix_caching=%s, use_async_output_proc=%s, "
+            "use_cached_outputs=%s, mm_processor_kwargs=%s)",
             VLLM_VERSION,
             model_config.model,
             speculative_config,
@@ -266,8 +270,10 @@ def __init__(
             model_config.served_model_name,
             scheduler_config.use_v2_block_manager,
             scheduler_config.num_scheduler_steps,
+            scheduler_config.multi_step_stream_outputs,
             cache_config.enable_prefix_caching,
             model_config.use_async_output_proc,
+            use_cached_outputs,
             model_config.mm_processor_kwargs,
         )
         # TODO(woosuk): Print more configs in debug mode.
@@ -287,6 +293,7 @@ def __init__(
         self.observability_config = observability_config or ObservabilityConfig(
         )
         self.log_stats = log_stats
+        self.use_cached_outputs = use_cached_outputs
 
         if not self.model_config.skip_tokenizer_init:
             self.tokenizer = self._init_tokenizer()
@@ -379,7 +386,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
         ]
 
         self.scheduler_contexts = [
-            SchedulerContext()
+            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
+                             multi_step_stream_outputs)
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]
 
@@ -998,7 +1006,8 @@ def _process_model_outputs(self,
 
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
-            request_output = RequestOutputFactory.create(seq_group)
+            request_output = RequestOutputFactory.create(
+                seq_group, use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)
 
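Each creation site now forwards `use_cache=self.use_cached_outputs` to `RequestOutputFactory.create`. The factory internals are not part of this diff; presumably the flag lets the factory reuse a per-request output object across the many iterations of a multi-step schedule instead of rebuilding it each step. A hypothetical, simplified sketch of that caching pattern follows; `CachingOutputFactory` and its fields are invented for illustration and are not vLLM's `RequestOutputFactory`:

```python
from typing import Any, Dict


class CachingOutputFactory:
    """Hypothetical stand-in for an output factory with an opt-in cache."""

    def __init__(self) -> None:
        # request_id -> previously built output object
        self._cache: Dict[str, Dict[str, Any]] = {}

    def create(self, request_id: str, text: str,
               use_cache: bool = False) -> Dict[str, Any]:
        if use_cache and request_id in self._cache:
            cached = self._cache[request_id]
            cached["text"] = text  # update in place instead of reallocating
            return cached
        output = {"request_id": request_id, "text": text}
        if use_cache:
            self._cache[request_id] = output
        return output
```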
@@ -1019,8 +1028,8 @@ def _process_model_outputs(self,
         for scheduler in self.scheduler:
             scheduler.free_finished_seq_groups()
 
-        # For multi-step, do not create outputs each iteration
-        if not is_last_step:
+        # For multi-step without streaming, don't create outputs each iteration
+        if not is_last_step and not ctx.multi_step_stream_outputs:
             # Immediately process request outputs here (if callback is given)
             if (finished_now
                     and self.process_request_outputs_callback is not None):
@@ -1037,17 +1046,27 @@ def _process_model_outputs(self,
 
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
-            request_output = RequestOutputFactory.create(seq_group)
+            request_output = RequestOutputFactory.create(
+                seq_group, use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)
 
+        # For multi-step with streaming, create outputs each iteration
+        if not is_last_step and ctx.multi_step_stream_outputs:
+            # Immediately process request outputs here (if callback is given)
+            if self.process_request_outputs_callback is not None:
+                self.process_request_outputs_callback(ctx.request_outputs)
+                ctx.request_outputs.clear()
+            return
+
         for seq_group in scheduler_outputs.ignored_seq_groups:
             params = seq_group.sampling_params
             if params is not None and params.output_kind == (
                     RequestOutputKind.DELTA) and not seq_group.is_finished():
                 continue
 
-            request_output = RequestOutputFactory.create(seq_group)
+            request_output = RequestOutputFactory.create(
+                seq_group, use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)
 
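Taken together, the two branches give `_process_model_outputs` two modes for intermediate (non-last) steps of a multi-step schedule: without streaming, only the outputs of requests that finished on this step are flushed to the callback; with streaming, every step's outputs are flushed and the buffer is cleared. A condensed, standalone sketch of that control flow (the function and parameter names here are illustrative, not vLLM API):

```python
from typing import Callable, List, Optional


def flush_intermediate_step(
        request_outputs: List[object],
        multi_step_stream_outputs: bool,
        finished_now: List[int],
        callback: Optional[Callable[[List[object]], None]]) -> None:
    """Illustrative mirror of the intermediate-step handling added above."""
    if not multi_step_stream_outputs:
        # Buffered mode: only push outputs for requests that just finished.
        if finished_now and callback is not None:
            callback(request_outputs)
            request_outputs.clear()
        return
    # Streaming mode: push this step's outputs unconditionally.
    if callback is not None:
        callback(request_outputs)
        request_outputs.clear()
```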