
Commit 4167252

[V1] Refactor parallel sampling support (vllm-project#13774)
Signed-off-by: Mark McLoughlin <[email protected]>
1 parent f35f8e2 commit 4167252

File tree

5 files changed: +201 -464 lines changed

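This commit replaces the wrapper-based handling of parallel sampling (n>1) with inline fan-out: ParentRequest.from_params() splits an n>1 request into n child requests inside add_request() itself, and the OutputProcessor is handed the parent request plus the child index so it can reassemble the results. The SyncParallelSamplingManager and generate_parallel_sampling_async wrappers are removed. Only two of the five changed files, the two engine front ends, are excerpted below; the caller-facing behaviour is intended to stay the same. As a reminder of that behaviour, here is a minimal offline-inference sketch (the model name and prompt are placeholders, not part of the commit):

from vllm import LLM, SamplingParams

# Illustrative only: with n=3, one prompt still yields a single
# RequestOutput carrying three completions, whichever engine path runs it.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(n=3, temperature=0.8, max_tokens=32)

for output in llm.generate(["The capital of France is"], params):
    for completion in output.outputs:  # len(output.outputs) == 3
        print(completion.text)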

vllm/v1/engine/async_llm.py (+21 -40)

@@ -25,7 +25,7 @@
 from vllm.utils import cdiv, kill_process_tree
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
-from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async
+from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
@@ -145,25 +145,30 @@ async def add_request(
         """Add new request to the AsyncLLM."""
 
         # 1) Create a new output queue for the request.
-        if self.output_processor.is_request_active(request_id):
-            raise ValueError(f"Request id {request_id} already running.")
         queue: asyncio.Queue[RequestOutput] = asyncio.Queue()
 
-        # 2) Convert Input --> Request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                trace_headers,
-                                                prompt_adapter_request,
-                                                priority)
+        # 2) Fan out child requests (for n>1)
+        parent_req = ParentRequest.from_params(request_id, params)
+        n = params.n if isinstance(params, SamplingParams) else 1
+        for idx in range(n):
+            if parent_req is not None:
+                request_id, params = parent_req.get_child_info(idx)
 
-        # 3) Add the request to OutputProcessor (this process).
-        self.output_processor.add_request(request, queue)
+            # 3) Convert Input --> Request.
+            request = self.processor.process_inputs(request_id, prompt, params,
+                                                    arrival_time, lora_request,
+                                                    trace_headers,
+                                                    prompt_adapter_request,
+                                                    priority)
 
-        # 4) Add the EngineCoreRequest to EngineCore (separate process).
-        await self.engine_core.add_request_async(request)
+            # 4) Add the request to OutputProcessor (this process).
+            self.output_processor.add_request(request, parent_req, idx, queue)
 
-        if self.log_requests:
-            logger.info("Added request %s.", request_id)
+            # 5) Add the EngineCoreRequest to EngineCore (separate process).
+            await self.engine_core.add_request_async(request)
+
+            if self.log_requests:
+                logger.info("Added request %s.", request_id)
 
         return queue
 
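ParentRequest itself is defined in vllm/v1/engine/parallel_sampling.py, which is not part of this excerpt. Judging only from the call sites above, its interface could be sketched roughly as follows; the child request-id scheme and the single-sample child params are assumptions for illustration, not the actual implementation:

from copy import copy
from typing import Optional

from vllm.sampling_params import SamplingParams

class ParentRequest:
    """Rough sketch inferred from the call sites; not the real class."""

    def __init__(self, request_id: str, params: SamplingParams) -> None:
        self.request_id = request_id
        self.params = params

    @classmethod
    def from_params(cls, request_id: str, params) -> Optional["ParentRequest"]:
        # Only n>1 sampling requests need a parent; n==1 requests and
        # pooling requests return None and take the plain path.
        if not isinstance(params, SamplingParams) or params.n == 1:
            return None
        return cls(request_id, params)

    def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
        # Hypothetical naming: give each child a unique id derived from
        # the parent id, and a single-sample copy of the parent's params.
        child_params = copy(self.params)
        child_params.n = 1
        return f"{index}_{self.request_id}", child_params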
@@ -172,7 +177,7 @@ async def add_request(
     # requests we don't need to send multiple messages to core proc,
     # and so we don't need multiple streams which then get
     # re-multiplexed in the API server anyhow.
-    async def _generate(
+    async def generate(
         self,
         prompt: PromptType,
         sampling_params: SamplingParams,
@@ -243,30 +248,6 @@ async def _generate(
             await self.abort(request_id)
             raise
 
-    def generate(
-        self,
-        prompt: PromptType,
-        sampling_params: SamplingParams,
-        request_id: str,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-        priority: int = 0,
-    ) -> AsyncGenerator[RequestOutput, None]:
-        kwargs = dict(prompt=prompt,
-                      sampling_params=sampling_params,
-                      request_id=request_id,
-                      lora_request=lora_request,
-                      trace_headers=trace_headers,
-                      prompt_adapter_request=prompt_adapter_request,
-                      priority=priority)
-        if sampling_params.n is None or sampling_params.n == 1:
-            return self._generate(**kwargs)
-        else:
-            # Special handling for parallel sampling requests
-            return generate_parallel_sampling_async(generate=self._generate,
-                                                    **kwargs)
-
     async def _run_output_handler(self):
         """Background loop: pulls from EngineCore and pushes to AsyncStreams."""
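With the fan-out handled inside add_request(), the public generate() no longer needs to dispatch n>1 requests to generate_parallel_sampling_async(); the former _generate() simply becomes generate(). A minimal async usage sketch, assuming an already-constructed AsyncLLM instance named engine (the prompt and request id are placeholders):

from vllm import SamplingParams

async def sample_twice(engine) -> None:
    params = SamplingParams(n=2, max_tokens=16)
    # generate() is an async generator of RequestOutput objects.
    async for output in engine.generate("Once upon a time",
                                        params,
                                        request_id="demo-0"):
        if output.finished:
            # Child outputs are folded back into a single stream for the
            # parent request, so one request id serves both samples.
            for completion in output.outputs:
                print(completion.text)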

vllm/v1/engine/llm_engine.py (+22 -52)

@@ -22,7 +22,7 @@
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
-from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager
+from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor

@@ -50,9 +50,6 @@ def __init__(
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
 
-        # Bookkeeping for parallel sampling requests
-        self.parallel_manager = SyncParallelSamplingManager()
-
         # important: init dp group before init the engine_core
         self.parallel_config = vllm_config.parallel_config
         self.dp_enabled = self.parallel_config.data_parallel_size > 1  # noqa
@@ -120,8 +117,7 @@ def from_engine_args(
                    multiprocess_mode=enable_multiprocessing)
 
     def get_num_unfinished_requests(self) -> int:
-        return self.parallel_manager.get_num_unfinished_requests(
-            self.output_processor.get_num_unfinished_requests())
+        return self.output_processor.get_num_unfinished_requests()
 
     def has_unfinished_requests(self) -> bool:
         has_unfinished = self.output_processor.has_unfinished_requests()
@@ -157,48 +153,25 @@ def add_request(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> None:
-        """Add request."""
-        kwargs = dict(request_id=request_id,
-                      prompt=prompt,
-                      params=params,
-                      arrival_time=arrival_time,
-                      lora_request=lora_request,
-                      trace_headers=trace_headers,
-                      prompt_adapter_request=prompt_adapter_request,
-                      priority=priority)
-        # Handle parallel sampling requests differently.
-        if params is None or isinstance(params,
-                                        PoolingParams) or params.n == 1:
-            self._add_request(**kwargs)
-        else:
-            # Special handling for parallel sampling requests
-            self.parallel_manager.add_request_parallel_sampling(
-                add_request=self._add_request, **kwargs)
-
-    def _add_request(
-        self,
-        request_id: str,
-        prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-        priority: int = 0,
-    ) -> None:
-        """Add request, `n=1`"""
-        # 1) Process raw inputs into the request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                trace_headers,
-                                                prompt_adapter_request,
-                                                priority)
-
-        # 2) Make a new RequestState and queue.
-        self.output_processor.add_request(request)
-
-        # 3) Add the request to EngineCore.
-        self.engine_core.add_request(request)
+        # 1) Fan out child requests (for n>1)
+        parent_req = ParentRequest.from_params(request_id, params)
+        n = params.n if isinstance(params, SamplingParams) else 1
+        for idx in range(n):
+            if parent_req is not None:
+                request_id, params = parent_req.get_child_info(idx)
+
+            # 2) Process raw inputs into the request.
+            request = self.processor.process_inputs(request_id, prompt, params,
+                                                    arrival_time, lora_request,
+                                                    trace_headers,
+                                                    prompt_adapter_request,
+                                                    priority)
+
+            # 3) Make a new RequestState and queue.
+            self.output_processor.add_request(request, parent_req, idx)
+
+            # 3) Add the request to EngineCore.
+            self.engine_core.add_request(request)
 
     def step(self) -> list[RequestOutput]:
 
@@ -217,10 +190,7 @@ def step(self) -> list[RequestOutput]:
         # 3) Abort any reqs that finished due to stop strings.
         self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
 
-        request_outputs = processed_outputs.request_outputs
-
-        # 4) Process unfinished parallel sampling requests
-        return self.parallel_manager.step(request_outputs)
+        return processed_outputs.request_outputs
 
     def get_model_config(self):
         return self.model_config
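The synchronous LLMEngine gets the same treatment: add_request() fans out inline, and step() now just returns the processed outputs, with the OutputProcessor receiving the parent request and child index in place of the old post-processing pass over parallel_manager. A rough sketch of driving this engine directly (model name, prompt and request id are placeholders):

from vllm import EngineArgs, SamplingParams
from vllm.v1.engine.llm_engine import LLMEngine

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
engine.add_request("req-0", "Hello, my name is",
                   SamplingParams(n=2, max_tokens=16))

# step() drains finished outputs; with n=2 the final RequestOutput is
# expected to carry both completions.
while engine.has_unfinished_requests():
    for output in engine.step():
        if output.finished:
            for completion in output.outputs:
                print(completion.text)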
