Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions tests/mq_llm_engine/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind, SamplingParams

MODEL = "google/gemma-1.1-2b-it"
NUM_EXPECTED_TOKENS = 10
Expand Down Expand Up @@ -55,3 +56,49 @@ async def test_load(tmp_socket):

# Shutdown.
client.close()


@pytest.mark.asyncio
@pytest.mark.parametrize("min_chunk_size", [512, 64])
async def test_chunked_prefill(tmp_socket, min_chunk_size):
    """Check min_chunk_size scheduling behavior with chunked prefill.

    With min_chunk_size == max_num_batched_tokens (512), the large prompt's
    prefill consumes the whole token budget, so both requests complete in
    the same engine step. With a smaller min_chunk_size (64), the budget is
    shared and the small request finishes first.
    """
    engine_args = AsyncEngineArgs(
        model=MODEL,
        disable_log_requests=True,
        enable_chunked_prefill=True,
        max_num_batched_tokens=512,
        min_chunk_size=min_chunk_size,
    )
    with RemoteMQLLMEngine(engine_args=engine_args,
                           ipc_path=tmp_socket) as engine:

        client = await engine.make_client()

        large_request = "hello" * 1000
        small_request = "hello"

        # Named _generate_one (not `generate`) to avoid shadowing the
        # module-level helper imported from tests.mq_llm_engine.utils.
        async def _generate_one(prompt, req_id):
            async for out in client.generate(
                    request_id=req_id,
                    prompt=prompt,
                    sampling_params=SamplingParams(
                        max_tokens=1,
                        output_kind=RequestOutputKind.FINAL_ONLY),
            ):
                # FINAL_ONLY yields a single output; return it directly.
                return out

        large_task = asyncio.create_task(_generate_one(large_request, "one"))
        small_task = asyncio.create_task(_generate_one(small_request, "two"))

        done, _ = await asyncio.wait((large_task, small_task),
                                     return_when=asyncio.FIRST_COMPLETED)

        # These conditions describe the `done` set as a whole, so assert
        # them once instead of re-checking on every completed task.
        if min_chunk_size == 512:
            # Large prefill takes the whole budget; both finish together.
            assert large_task in done
            assert len(done) == 2
        else:
            # Budget is shared, so the small request completes first.
            assert small_task in done
            assert len(done) == 1

        for task in done:
            assert task.exception() is None

        # Shutdown.
        client.close()
13 changes: 12 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,8 @@ def __init__(self,
num_scheduler_steps: int = 1,
multi_step_stream_outputs: bool = False,
send_delta_data: bool = False,
policy: str = "fcfs") -> None:
policy: str = "fcfs",
min_chunk_size: Optional[int] = None) -> None:
if max_num_batched_tokens is None:
if enable_chunked_prefill:
if num_scheduler_steps > 1:
Expand Down Expand Up @@ -1111,6 +1112,16 @@ def __init__(self,

self.max_num_batched_tokens = max_num_batched_tokens

if min_chunk_size is None:
min_chunk_size = self.max_num_batched_tokens
else:
assert min_chunk_size <= self.max_num_batched_tokens, \
f"Max chunk size {min_chunk_size} must be less than or equal to "
"the maximum number of batched tokens "
f"{self.max_num_batched_tokens}"

self.min_chunk_size = min_chunk_size

if enable_chunked_prefill:
logger.info(
"Chunked prefill is enabled with max_num_batched_tokens=%d.",
Expand Down
121 changes: 79 additions & 42 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def __init__(
# simple and NOT fair. It can lead to starvation of some
# LoRAs. This should be improved in the future.
self.lora_config = lora_config

self.num_prefill_groups = 0
version = "selfattn"
if (self.scheduler_config.task == "embedding"
or self.cache_config.is_attention_free):
Expand Down Expand Up @@ -807,17 +807,17 @@ def _schedule_priority_preemption(
SequenceStatus.WAITING,
False, budget)

#Only preempt if priority inversion exists
# Only preempt if priority inversion exists
while running_queue and self._get_priority(
running_queue[-1]) > self._get_priority(seq_group):
#Only preempt if waiting sequence cannot be allocated
# Only preempt if waiting sequence cannot be allocated
can_allocate = self.block_manager.can_allocate(seq_group)
if (num_new_tokens and can_allocate == AllocStatus.OK
and budget.can_schedule(num_new_tokens=num_new_tokens,
num_new_seqs=num_new_seqs)):
break

#Adjust budget to remove the victim sequence group
# Adjust budget to remove the victim sequence group
vseq_group = running_queue.pop()
num_running_tokens = self._get_num_new_tokens(
vseq_group, SequenceStatus.RUNNING, False, budget)
Expand All @@ -827,11 +827,11 @@ def _schedule_priority_preemption(
budget.subtract_num_seqs(vseq_group.request_id,
num_running_seqs)

#Preempt out the victim sequence group
# Preempt out the victim sequence group
self._preempt(vseq_group, blocks_to_swap_out)
waiting_queue.appendleft(vseq_group)
force_preemption_count += 1
#Put the sequence back into the waiting queue
# Put the sequence back into the waiting queue
waiting_queue.appendleft(seq_group)

waiting_queue = deque(sorted(waiting_queue, key=self._get_priority))
Expand Down Expand Up @@ -1216,6 +1216,8 @@ def schedule(
scheduler_start_time = time.perf_counter()

scheduler_outputs: SchedulerOutputs = self._schedule()
self.num_prefill_groups = scheduler_outputs.num_prefill_groups

now = time.time()

if not self.cache_config.enable_prefix_caching:
Expand Down Expand Up @@ -1583,6 +1585,44 @@ def _get_num_lookahead_slots(self, is_prefill: bool,

return self.scheduler_config.num_lookahead_slots

def _get_token_budget_for_request(self, budget: "SchedulingBudget") -> int:
    """When doing chunked prefill, calculate the token budget for a single
    chunk. This dynamically scales the chunk size down as the number of
    sequences that require prefilling increases. This ensures that a single
    sequence with a very large prompt to prefill doesn't take the entire
    remaining token budget, allowing other sequences to prefill and decode
    concurrently.

    Args:
        budget: The scheduling budget for the current step.

    Returns:
        The number of tokens a single request may consume this step;
        always <= budget.remaining_token_budget().
    """
    remaining_token_budget = budget.remaining_token_budget()
    if remaining_token_budget < self.scheduler_config.min_chunk_size:
        # The budget is already below the configured floor, so no
        # scaling can reduce it further; skip the calculation.
        return remaining_token_budget

    # Sequences that require prefill: the groups scheduled for prefill
    # in the last step plus everything still in the waiting queue.
    prefilling_seqs = self.num_prefill_groups + len(self.waiting)
    if prefilling_seqs == 0:
        # Nothing is prefilling; the request may use the whole budget.
        return remaining_token_budget

    # Share the remaining budget evenly across prefilling sequences
    # (floor division — no float round-trip needed).
    chunk_size = remaining_token_budget // prefilling_seqs
    # Ensure the chunk size is at least the minimum configured by the
    # user, to limit the number of requests doing prefill concurrently.
    chunk_size = max(chunk_size, self.scheduler_config.min_chunk_size)
    # And cap it at the actual budget so we don't spend tokens we
    # don't have.
    return min(chunk_size, remaining_token_budget)

def _get_num_new_tokens(self, seq_group: SequenceGroup,
status: SequenceStatus, enable_chunking: bool,
budget: SchedulingBudget) -> int:
Expand All @@ -1601,46 +1641,43 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup,
for seq in seqs:
num_new_tokens += seq.get_num_new_tokens()
assert num_new_tokens > 0
# Chunk if a running request cannot fit in the given budget.

if self.scheduler_config.is_multi_step:
# The current multi-step + chunked prefill capability does
# not actually support chunking prompts.
#
# Therefore, `num_new_tokens` is computed in the same fashion
# for both multi-step+chunked-prefill &
# multi-step+chunked-prefill+APC
#
# Prompts with more tokens than the current remaining budget
# are postponed to future scheduler steps
if num_new_tokens > self._get_prompt_limit(seq_group):
# If the seq_group is in prompt-stage, pass the
# num_new_tokens as-is so the caller can ignore
# the sequence.
pass
else:
num_new_tokens = 0 \
if num_new_tokens > budget.remaining_token_budget() \
else num_new_tokens

# If number of seq > 1, it means it is doing beam search
# in a decode phase. Do not chunk.
if enable_chunking and len(seqs) == 1:
remaining_token_budget = budget.remaining_token_budget()
if self.scheduler_config.is_multi_step:
# The current multi-step + chunked prefill capability does
# not actually support chunking prompts.
#
# Therefore, `num_new_tokens` is computed in the same fashion
# for both multi-step+chunked-prefill &
# multi-step+chunked-prefill+APC
#
# Prompts with more tokens than the current remaining budget
# are postponed to future scheduler steps
if num_new_tokens > self._get_prompt_limit(seq_group):
# If the seq_group is in prompt-stage, pass the
# num_new_tokens as-is so the caller can ignore
# the sequence.
pass
else:
num_new_tokens = 0 \
if num_new_tokens > remaining_token_budget \
else num_new_tokens
elif self.cache_config.enable_prefix_caching:
elif enable_chunking and len(seqs) == 1:
# Get the budget for this chunk
chunk_size = self._get_token_budget_for_request(budget=budget)

if self.cache_config.enable_prefix_caching:
# When prefix caching is enabled, we always allocate
# the number of new tokens that is dividable by the block
# size to avoid partial block matching.
block_size = self.cache_config.block_size
remainder = budget.token_budget % block_size
if remainder != 0:
raise ValueError("When enabling chunked prefill and "
"prefix caching, max_num_batched_tokens "
"(chunk size) must be dividable by "
"block size, but got chunk_size "
f"({budget.token_budget}) % block_size "
f"({block_size}) = {remainder}")
if remaining_token_budget < num_new_tokens:
num_new_tokens = (remaining_token_budget //
block_size) * block_size
else:
num_new_tokens = min(num_new_tokens, remaining_token_budget)
# Set chunk size to the next lowest multiple of block size
# so we don't exceed our budget
chunk_size = (chunk_size // block_size) * block_size
# NB: In the case where num_new_tokens < chunk_size, this does
# not allocate a multiple of `block_size` tokens.

num_new_tokens = min(num_new_tokens, chunk_size)
return num_new_tokens
13 changes: 12 additions & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class EngineArgs:
cpu_offload_gb: float = 0 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
min_chunk_size: Optional[int] = None
max_num_seqs: int = 256
max_logprobs: int = 20 # Default value for OpenAI Chat Completions API
disable_log_stats: bool = False
Expand Down Expand Up @@ -474,6 +475,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=EngineArgs.max_num_batched_tokens,
help='Maximum number of batched tokens per '
'iteration.')
parser.add_argument(
'--min-chunk-size',
type=int,
default=EngineArgs.min_chunk_size,
help=
'For chunked prefill, the minimum number of tokens from a single '
'prompt to process in a single iteration. Must be less than or '
'equal to --max-num-batched-tokens.')
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
Expand Down Expand Up @@ -1136,7 +1145,9 @@ def create_engine_config(self) -> VllmConfig:
multi_step_stream_outputs=self.multi_step_stream_outputs,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray),
policy=self.scheduling_policy)
policy=self.scheduling_policy,
min_chunk_size=self.min_chunk_size)

lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,9 @@ def get_logprobs(
if len(query_indices) == 0:
empty_sampled_logprob: SampleLogprobs = []
empty_prompt_logprob: Optional[PromptLogprobs] = None
return [empty_prompt_logprob], [empty_sampled_logprob]
num_seq_groups = len(sampling_metadata.seq_groups)
return [empty_prompt_logprob
] * num_seq_groups, [empty_sampled_logprob] * num_seq_groups

selected_logprobs, ranks = None, None
top_logprobs, top_token_ids = None, None
Expand Down