@pytest.mark.asyncio
@pytest.mark.parametrize("min_chunk_size", [512, 64])
async def test_chunked_prefill(tmp_socket, min_chunk_size):
    """Verify dynamic chunk sizing for chunked prefill.

    With min_chunk_size == max_num_batched_tokens (512) the large prompt
    is prefetched in full-budget chunks, so both requests complete in the
    same scheduler step. With a smaller floor (64) the budget is shared,
    letting the small request finish first.
    """
    engine_args = AsyncEngineArgs(
        model=MODEL,
        disable_log_requests=True,
        enable_chunked_prefill=True,
        max_num_batched_tokens=512,
        min_chunk_size=min_chunk_size,
    )
    with RemoteMQLLMEngine(engine_args=engine_args,
                           ipc_path=tmp_socket) as engine:

        client = await engine.make_client()

        large_prompt = "hello" * 1000
        small_prompt = "hello"

        # NOTE: named generate_one so it does not shadow the `generate`
        # helper imported from tests.mq_llm_engine.utils.
        async def generate_one(prompt, req_id):
            async for out in client.generate(
                    request_id=req_id,
                    prompt=prompt,
                    sampling_params=SamplingParams(
                        max_tokens=1,
                        output_kind=RequestOutputKind.FINAL_ONLY)):
                return out

        large_task = asyncio.create_task(generate_one(large_prompt, "one"))
        small_task = asyncio.create_task(generate_one(small_prompt, "two"))

        done, pending = await asyncio.wait(
            (large_task, small_task),
            return_when=asyncio.FIRST_COMPLETED)

        # Assert once on the wait() result instead of repeating the same
        # set-level assertions for every element of `done`.
        if min_chunk_size == 512:
            # Full-budget chunks: both requests finish together.
            assert done == {large_task, small_task}
        else:
            # Shared budget: the small request finishes first.
            assert done == {small_task}
            assert pending == {large_task}
        for task in done:
            assert task.exception() is None

        # Drain any still-running request before shutdown so no task is
        # left pending when the engine goes away.
        for task in pending:
            await task

        # Shutdown.
        client.close()
# Validate / default the chunked-prefill minimum chunk size.
if min_chunk_size is None:
    # Default: a single chunk may use the whole batched-token budget,
    # which preserves the pre-existing (non-scaled) behavior.
    min_chunk_size = self.max_num_batched_tokens
elif min_chunk_size > self.max_num_batched_tokens:
    # BUG FIX: the original used `assert cond, \` where the backslash
    # continued only onto the first f-string; the following bare string
    # literals sat on their own indented lines and were dangling
    # statements (IndentationError at import time). The message also
    # said "Max chunk size" for a *min* chunk size. Raise ValueError
    # instead of assert so the check also survives `python -O`, with
    # the message as one parenthesized implicit-concatenation literal.
    raise ValueError(
        f"min_chunk_size ({min_chunk_size}) must be less than or "
        "equal to the maximum number of batched tokens "
        f"({self.max_num_batched_tokens})")

self.min_chunk_size = min_chunk_size
self.lora_config = lora_config - + self.num_prefill_groups = 0 version = "selfattn" if (self.scheduler_config.task == "embedding" or self.cache_config.is_attention_free): @@ -807,17 +807,17 @@ def _schedule_priority_preemption( SequenceStatus.WAITING, False, budget) - #Only preempt if priority inversion exists + # Only preempt if priority inversion exists while running_queue and self._get_priority( running_queue[-1]) > self._get_priority(seq_group): - #Only preempt if waiting sequence cannot be allocated + # Only preempt if waiting sequence cannot be allocated can_allocate = self.block_manager.can_allocate(seq_group) if (num_new_tokens and can_allocate == AllocStatus.OK and budget.can_schedule(num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs)): break - #Adjust budget to remove the victim sequence group + # Adjust budget to remove the victim sequence group vseq_group = running_queue.pop() num_running_tokens = self._get_num_new_tokens( vseq_group, SequenceStatus.RUNNING, False, budget) @@ -827,11 +827,11 @@ def _schedule_priority_preemption( budget.subtract_num_seqs(vseq_group.request_id, num_running_seqs) - #Preempt out the victim sequence group + # Preempt out the victim sequence group self._preempt(vseq_group, blocks_to_swap_out) waiting_queue.appendleft(vseq_group) force_preemption_count += 1 - #Put the sequence back into the waiting queue + # Put the sequence back into the waiting queue waiting_queue.appendleft(seq_group) waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) @@ -1216,6 +1216,8 @@ def schedule( scheduler_start_time = time.perf_counter() scheduler_outputs: SchedulerOutputs = self._schedule() + self.num_prefill_groups = scheduler_outputs.num_prefill_groups + now = time.time() if not self.cache_config.enable_prefix_caching: @@ -1583,6 +1585,44 @@ def _get_num_lookahead_slots(self, is_prefill: bool, return self.scheduler_config.num_lookahead_slots + def _get_token_budget_for_request(self, budget: SchedulingBudget) -> int: + 
"""When doing chunked prefill, calculate the token budget for a single + chunk. This dynamically scales the chunk size down as the number of + sequences that require prefilling increases. This ensures that a single + sequence with a very large prompt to prefill doesn't take the entire + remaining token budget, allowing other sequences to prefill and decode + concurrently.""" + + # Get the current remaining token budget + remaining_token_budget = budget.remaining_token_budget() + if remaining_token_budget < self.scheduler_config.min_chunk_size: + # Skip all calculations if there's no way to reduce the budget any + # further anyway + return remaining_token_budget + + # First get the number of sequences that require prefill + prefilling_seqs = self.num_prefill_groups + prefilling_seqs += len(self.waiting) + + # Get the current remaining token budget + remaining_token_budget = budget.remaining_token_budget() + + if prefilling_seqs == 0: + # Return immediately if there are no sequences that require prefill + return remaining_token_budget + + # calculate a chunk size that shares it evenly across sequences that + # need to prefill + chunk_size = int(remaining_token_budget / prefilling_seqs) + # Ensure the chunk size is at least the minimum configured by the + # user, to limit the number of requests doing prefill + chunk_size = max(chunk_size, self.scheduler_config.min_chunk_size) + # And cap that at our actual budget so we don't spend tokens we + # don't have. + chunk_size = min(remaining_token_budget, chunk_size) + + return chunk_size + def _get_num_new_tokens(self, seq_group: SequenceGroup, status: SequenceStatus, enable_chunking: bool, budget: SchedulingBudget) -> int: @@ -1601,46 +1641,43 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, for seq in seqs: num_new_tokens += seq.get_num_new_tokens() assert num_new_tokens > 0 - # Chunk if a running request cannot fit in the given budget. 
+ + if self.scheduler_config.is_multi_step: + # The current multi-step + chunked prefill capability does + # not actually support chunking prompts. + # + # Therefore, `num_new_tokens` is computed in the same fashion + # for both multi-step+chunked-prefill & + # multi-step+chunked-prefill+APC + # + # Prompts with more tokens than the current remaining budget + # are postponed to future scheduler steps + if num_new_tokens > self._get_prompt_limit(seq_group): + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. + pass + else: + num_new_tokens = 0 \ + if num_new_tokens > budget.remaining_token_budget() \ + else num_new_tokens + # If number of seq > 1, it means it is doing beam search # in a decode phase. Do not chunk. - if enable_chunking and len(seqs) == 1: - remaining_token_budget = budget.remaining_token_budget() - if self.scheduler_config.is_multi_step: - # The current multi-step + chunked prefill capability does - # not actually support chunking prompts. - # - # Therefore, `num_new_tokens` is computed in the same fashion - # for both multi-step+chunked-prefill & - # multi-step+chunked-prefill+APC - # - # Prompts with more tokens than the current remaining budget - # are postponed to future scheduler steps - if num_new_tokens > self._get_prompt_limit(seq_group): - # If the seq_group is in prompt-stage, pass the - # num_new_tokens as-is so the caller can ignore - # the sequence. - pass - else: - num_new_tokens = 0 \ - if num_new_tokens > remaining_token_budget \ - else num_new_tokens - elif self.cache_config.enable_prefix_caching: + elif enable_chunking and len(seqs) == 1: + # Get the budget for this chunk + chunk_size = self._get_token_budget_for_request(budget=budget) + + if self.cache_config.enable_prefix_caching: # When prefix caching is enabled, we always allocate # the number of new tokens that is dividable by the block # size to avoid partial block matching. 
block_size = self.cache_config.block_size - remainder = budget.token_budget % block_size - if remainder != 0: - raise ValueError("When enabling chunked prefill and " - "prefix caching, max_num_batched_tokens " - "(chunk size) must be dividable by " - "block size, but got chunk_size " - f"({budget.token_budget}) % block_size " - f"({block_size}) = {remainder}") - if remaining_token_budget < num_new_tokens: - num_new_tokens = (remaining_token_budget // - block_size) * block_size - else: - num_new_tokens = min(num_new_tokens, remaining_token_budget) + # Set chunk size to the next lowest multiple of block size + # so we don't exceed our budget + chunk_size = (chunk_size // block_size) * block_size + # NB: In the case where num_new_tokens < chunk_size, this does + # not allocate a multiple of `block_size` tokens. + + num_new_tokens = min(num_new_tokens, chunk_size) return num_new_tokens diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b556c0eed377..643bc5f567f6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -122,6 +122,7 @@ class EngineArgs: cpu_offload_gb: float = 0 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None + min_chunk_size: Optional[int] = None max_num_seqs: int = 256 max_logprobs: int = 20 # Default value for OpenAI Chat Completions API disable_log_stats: bool = False @@ -474,6 +475,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.max_num_batched_tokens, help='Maximum number of batched tokens per ' 'iteration.') + parser.add_argument( + '--min-chunk-size', + type=int, + default=EngineArgs.min_chunk_size, + help= + 'For chunked prefill, the minimum number of tokens from a single ' + 'prompt to process in a single iteration. 
Must be less than or ' + 'equal to --max-num-batched-tokens.') parser.add_argument('--max-num-seqs', type=int, default=EngineArgs.max_num_seqs, @@ -1136,7 +1145,9 @@ def create_engine_config(self) -> VllmConfig: multi_step_stream_outputs=self.multi_step_stream_outputs, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), - policy=self.scheduling_policy) + policy=self.scheduling_policy, + min_chunk_size=self.min_chunk_size) + lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index f86c6ec362eb..644c94b6affd 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -986,7 +986,9 @@ def get_logprobs( if len(query_indices) == 0: empty_sampled_logprob: SampleLogprobs = [] empty_prompt_logprob: Optional[PromptLogprobs] = None - return [empty_prompt_logprob], [empty_sampled_logprob] + num_seq_groups = len(sampling_metadata.seq_groups) + return [empty_prompt_logprob + ] * num_seq_groups, [empty_sampled_logprob] * num_seq_groups selected_logprobs, ranks = None, None top_logprobs, top_token_ids = None, None