22
22
from vllm .usage .usage_lib import UsageContext
23
23
from vllm .v1 .engine .core_client import EngineCoreClient
24
24
from vllm .v1 .engine .output_processor import OutputProcessor
25
- from vllm .v1 .engine .parallel_sampling import SyncParallelSamplingManager
25
+ from vllm .v1 .engine .parallel_sampling import ParentRequest
26
26
from vllm .v1 .engine .processor import Processor
27
27
from vllm .v1 .executor .abstract import Executor
28
28
@@ -50,9 +50,6 @@ def __init__(
50
50
self .model_config = vllm_config .model_config
51
51
self .cache_config = vllm_config .cache_config
52
52
53
- # Bookkeeping for parallel sampling requests
54
- self .parallel_manager = SyncParallelSamplingManager ()
55
-
56
53
# important: init dp group before init the engine_core
57
54
self .parallel_config = vllm_config .parallel_config
58
55
self .dp_enabled = self .parallel_config .data_parallel_size > 1 # noqa
@@ -120,8 +117,7 @@ def from_engine_args(
120
117
multiprocess_mode = enable_multiprocessing )
121
118
122
119
def get_num_unfinished_requests (self ) -> int :
123
- return self .parallel_manager .get_num_unfinished_requests (
124
- self .output_processor .get_num_unfinished_requests ())
120
+ return self .output_processor .get_num_unfinished_requests ()
125
121
126
122
def has_unfinished_requests (self ) -> bool :
127
123
has_unfinished = self .output_processor .has_unfinished_requests ()
@@ -157,48 +153,25 @@ def add_request(
157
153
prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
158
154
priority : int = 0 ,
159
155
) -> None :
160
- """Add request."""
161
- kwargs = dict (request_id = request_id ,
162
- prompt = prompt ,
163
- params = params ,
164
- arrival_time = arrival_time ,
165
- lora_request = lora_request ,
166
- trace_headers = trace_headers ,
167
- prompt_adapter_request = prompt_adapter_request ,
168
- priority = priority )
169
- # Handle parallel sampling requests differently.
170
- if params is None or isinstance (params ,
171
- PoolingParams ) or params .n == 1 :
172
- self ._add_request (** kwargs )
173
- else :
174
- # Special handling for parallel sampling requests
175
- self .parallel_manager .add_request_parallel_sampling (
176
- add_request = self ._add_request , ** kwargs )
177
-
178
- def _add_request (
179
- self ,
180
- request_id : str ,
181
- prompt : PromptType ,
182
- params : Union [SamplingParams , PoolingParams ],
183
- arrival_time : Optional [float ] = None ,
184
- lora_request : Optional [LoRARequest ] = None ,
185
- trace_headers : Optional [Mapping [str , str ]] = None ,
186
- prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
187
- priority : int = 0 ,
188
- ) -> None :
189
- """Add request, `n=1`"""
190
- # 1) Process raw inputs into the request.
191
- request = self .processor .process_inputs (request_id , prompt , params ,
192
- arrival_time , lora_request ,
193
- trace_headers ,
194
- prompt_adapter_request ,
195
- priority )
196
-
197
- # 2) Make a new RequestState and queue.
198
- self .output_processor .add_request (request )
199
-
200
- # 3) Add the request to EngineCore.
201
- self .engine_core .add_request (request )
156
+ # 1) Fan out child requests (for n>1)
157
+ parent_req = ParentRequest .from_params (request_id , params )
158
+ n = params .n if isinstance (params , SamplingParams ) else 1
159
+ for idx in range (n ):
160
+ if parent_req is not None :
161
+ request_id , params = parent_req .get_child_info (idx )
162
+
163
+ # 2) Process raw inputs into the request.
164
+ request = self .processor .process_inputs (request_id , prompt , params ,
165
+ arrival_time , lora_request ,
166
+ trace_headers ,
167
+ prompt_adapter_request ,
168
+ priority )
169
+
170
+ # 3) Make a new RequestState and queue.
171
+ self .output_processor .add_request (request , parent_req , idx )
172
+
173
+ # 3) Add the request to EngineCore.
174
+ self .engine_core .add_request (request )
202
175
203
176
def step (self ) -> list [RequestOutput ]:
204
177
@@ -217,10 +190,7 @@ def step(self) -> list[RequestOutput]:
217
190
# 3) Abort any reqs that finished due to stop strings.
218
191
self .engine_core .abort_requests (processed_outputs .reqs_to_abort )
219
192
220
- request_outputs = processed_outputs .request_outputs
221
-
222
- # 4) Process unfinished parallel sampling requests
223
- return self .parallel_manager .step (request_outputs )
193
+ return processed_outputs .request_outputs
224
194
225
195
def get_model_config (self ):
226
196
return self .model_config
0 commit comments