[Core] Multi-Step + Single Step Prefills via Chunked Prefill code path #8378

Merged
2 changes: 1 addition & 1 deletion csrc/prepare_inputs/advance_step.cu
@@ -52,7 +52,7 @@ __global__ void advance_step_flashattn_kernel(
slot_mapping_ptr[cur_query_id] = slot_num;
}

inline void verify_tensor(std::string const& name, torch::Tensor& t,
inline void verify_tensor(std::string const& name, torch::Tensor const& t,
int64_t const size_0, int64_t const size_1,
c10::ScalarType const type) {
bool size_0_cond = true;
9 changes: 9 additions & 0 deletions tests/multi_step/test_correctness_async_llm.py
@@ -37,6 +37,7 @@
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("is_async", [True])
@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
@pytest.mark.parametrize("enable_chunked_prefill", [True, False])
@pytest.mark.asyncio
async def test_multi_step(
example_prompts,
@@ -49,6 +50,7 @@ async def test_multi_step(
is_async: bool,
num_logprobs: Optional[int],
attention_backend: str,
enable_chunked_prefill: bool,
monkeypatch,
) -> None:
"""Test vLLM engine with multi-step scheduling in an OpenAI-protocol
@@ -74,6 +76,10 @@ async def test_multi_step(
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> no logprobs
"""
if enable_chunked_prefill and \
(pp_size > 1 or attention_backend != "FLASH_ATTN"):
pytest.skip("Multi-step with Chunked-Prefill only supports"
"PP=1 and FLASH_ATTN backend")

override_backend_env_variable(monkeypatch, attention_backend)

@@ -93,6 +99,9 @@ async def test_multi_step(
if eager_mode:
ms_server_args.append("--enforce-eager")

if enable_chunked_prefill:
ms_server_args.append("--enable-chunked-prefill")

distributed_args = [
"--tensor-parallel-size",
str(tp_size),
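For reference, the sketch below shows how the server under test would be launched with both features enabled, mirroring the flags assembled above. The model name, step count, and direct api_server invocation are illustrative assumptions, not something this diff prescribes.

    # Hedged sketch: start the OpenAI-compatible server with multi-step
    # scheduling and chunked prefill together (all values are placeholders).
    import subprocess

    subprocess.run([
        "python", "-m", "vllm.entrypoints.openai.api_server",
        "--model", "JackFram/llama-160m",   # placeholder model
        "--use-v2-block-manager",
        "--num-scheduler-steps", "8",       # multi-step scheduling
        "--enable-chunked-prefill",         # route prefills via the chunked-prefill path
        "--enforce-eager",
    ])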
4 changes: 4 additions & 0 deletions tests/multi_step/test_correctness_llm.py
@@ -16,6 +16,7 @@
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@@ -28,6 +29,7 @@ def test_multi_step_llm(
model: str,
dtype: str,
tp_size: int,
enable_chunked_prefill: bool,
max_tokens: int,
enforce_eager: int,
num_scheduler_steps: int,
@@ -51,6 +53,7 @@ def test_multi_step_llm(
model: model under test (same for single- and multi-step engines)
dtype: tensor datatype for engine to utilize
tp_size: degree of tensor-parallelism
enable_chunked_prefill: chunked-prefill on/off
max_tokens: the maximum number of tokens to generate
enforce_eager: whether to enforce eager-mode execution (disable CUDA graphs)
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
@@ -73,6 +76,7 @@ def test_multi_step_llm(
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
use_v2_block_manager=True,
enable_chunked_prefill=enable_chunked_prefill,
num_scheduler_steps=num_scheduler_steps,
) as vllm_model:
vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens)
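The offline-inference analogue of this test can be sketched as follows; it simply combines the two engine arguments this PR allows together (num_scheduler_steps > 1 with enable_chunked_prefill=True). The model name and values are illustrative placeholders, not the test's exact configuration.

    # Hedged sketch: multi-step scheduling plus chunked prefill on the LLM
    # entrypoint (placeholder model and values).
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="JackFram/llama-160m",     # placeholder model
        dtype="half",
        use_v2_block_manager=True,
        num_scheduler_steps=8,           # GPU-side steps per scheduler invocation
        enable_chunked_prefill=True,     # prefills scheduled on the chunked-prefill path
        enforce_eager=True,
        gpu_memory_utilization=0.7,
    )
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=5))
    print(outputs[0].outputs[0].text)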
32 changes: 27 additions & 5 deletions vllm/attention/backends/flash_attn.py
@@ -342,9 +342,13 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
)
return self._cached_decode_metadata

def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int):
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
@@ -355,6 +359,23 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
assert num_seqs > num_queries
assert self.use_cuda_graph

if turn_prefills_into_decodes:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1

self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)

assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
@@ -366,7 +387,6 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.max_decode_seq_len == max(self.seq_lens)

assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
@@ -704,8 +724,10 @@ def forward(

num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa

# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
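The turn_prefills_into_decodes branch above re-labels every scheduled prefill as a decode once the first of the multi-step iterations has run. The self-contained sketch below replays that metadata update; Meta is a simplified, hypothetical stand-in for FlashAttentionMetadata and models only the fields the diff touches.

    # Standalone sketch of the conversion performed when
    # turn_prefills_into_decodes=True. Meta is a stand-in, not the real class.
    from dataclasses import dataclass

    import torch


    @dataclass
    class Meta:
        num_prefills: int
        num_prefill_tokens: int
        num_decode_tokens: int
        max_prefill_seq_len: int
        max_query_len: int
        slot_mapping: torch.Tensor


    def turn_prefills_into_decodes(meta: Meta, num_seqs: int) -> None:
        # After the first step every scheduled prefill has produced its first
        # token, so from now on it is a decode with query length 1.
        assert meta.num_decode_tokens + meta.num_prefills == num_seqs
        meta.num_decode_tokens += meta.num_prefills
        meta.num_prefills = 0
        meta.num_prefill_tokens = 0
        meta.max_prefill_seq_len = 0
        meta.max_query_len = 1
        # One slot per sequence suffices once everything is a decode.
        meta.slot_mapping = meta.slot_mapping[:num_seqs]


    meta = Meta(num_prefills=2, num_prefill_tokens=24, num_decode_tokens=3,
                max_prefill_seq_len=16, max_query_len=16,
                slot_mapping=torch.arange(27))
    turn_prefills_into_decodes(meta, num_seqs=5)
    assert meta.num_prefills == 0 and meta.num_decode_tokens == 5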
20 changes: 12 additions & 8 deletions vllm/attention/backends/flashinfer.py
@@ -410,18 +410,22 @@ def decode_metadata(self) -> Optional["FlashInferMetadata"]:

return self

def advance_step(
self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
):
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""

assert not turn_prefills_into_decodes, \
("Chunked prefill is not supported with flashinfer yet."
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
"specific parameter.")

assert num_seqs > 0
assert num_queries > 0
assert model_input.attn_metadata is not None
13 changes: 10 additions & 3 deletions vllm/config.py
@@ -983,9 +983,16 @@ def __init__(self,
policy: str = "fcfs") -> None:
if max_num_batched_tokens is None:
if enable_chunked_prefill:
# It is the values that have the best balance between ITL
# and TTFT on A100. Note it is not optimized for throughput.
max_num_batched_tokens = 512
if num_scheduler_steps > 1:
# Multi-step Chunked-Prefill doesn't allow prompt-chunking
# for now. Set max_num_batched_tokens to at least
# max_model_len so we don't reject sequences on account of a
# short max_num_batched_tokens.
max_num_batched_tokens = max(max_model_len, 2048)
Collaborator:

Why not simply set max_num_batched_tokens = max_model_len? What's the reason for 2048 here?

Contributor (Author):

There is an else branch to if enable_chunked_prefill - there we do,

                # If max_model_len is too short, use 2048 as the default value
                # for higher throughput.
                max_num_batched_tokens = max(max_model_len, 2048)

I replicated the same. This argument is the token budget in the Scheduler. I believe it is there so we can schedule more prefills and not be limited by a small max_model_len value.

Collaborator:

I see, ok.

else:
# These are the values that give the best balance between ITL
# and TTFT on A100. Note they are not optimized for throughput.
max_num_batched_tokens = 512
else:
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
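Condensed, the default token-budget selection after this change reads as in the sketch below; it is a stand-in for the relevant branch of SchedulerConfig.__init__, not the full constructor.

    # Sketch of the new default for max_num_batched_tokens (values mirror the
    # diff above; this is not the real SchedulerConfig code).
    def default_max_num_batched_tokens(max_model_len: int,
                                       enable_chunked_prefill: bool,
                                       num_scheduler_steps: int) -> int:
        if enable_chunked_prefill:
            if num_scheduler_steps > 1:
                # Multi-step + chunked-prefill cannot chunk prompts yet, so the
                # budget must be able to hold a whole prompt.
                return max(max_model_len, 2048)
            # Single-step chunked prefill: 512 balances ITL and TTFT on A100.
            return 512
        # Chunked prefill disabled: keep a reasonably large budget.
        return max(max_model_len, 2048)


    assert default_max_num_batched_tokens(4096, True, 8) == 4096
    assert default_max_num_batched_tokens(4096, True, 1) == 512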
13 changes: 9 additions & 4 deletions vllm/core/block/block_table.py
@@ -55,9 +55,12 @@ def __init__(
self._num_full_slots = self._get_num_token_ids()

@staticmethod
def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
def get_num_required_blocks(token_ids: List[int],
block_size: int,
num_lookahead_slots: int = 0) -> int:
"""Calculates the minimum number of blocks required to store a given
sequence of token IDs.
sequence of token IDs along with any look-ahead slots that may be
required (like in multi-step + chunked-prefill).

This assumes worst-case scenario, where every block requires a new
allocation (e.g. ignoring prefix caching).
@@ -66,12 +69,14 @@ def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
token_ids (List[int]): The sequence of token IDs to be stored.
block_size (int): The maximum number of tokens that can be stored in
a single block.
num_lookahead_slots (int): look-ahead slots that the sequence may
require.

Returns:
int: The minimum number of blocks required to store the given
sequence of token IDs.
sequence of token IDs along with any required look-ahead slots.
"""
return cdiv(len(token_ids), block_size)
return cdiv(len(token_ids) + num_lookahead_slots, block_size)

def allocate(self,
token_ids: List[int],
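A worked example of the updated calculation: with a block size of 16, a 30-token sequence alone needs 2 blocks, but with 8 look-ahead slots it needs ceil(38 / 16) = 3. The helper below takes a token count rather than the token list purely for brevity, and cdiv mirrors vllm.utils.cdiv.

    # Worked example of block accounting with look-ahead slots.
    def cdiv(a: int, b: int) -> int:
        # Ceiling division, as in vllm.utils.cdiv.
        return -(a // -b)


    def get_num_required_blocks(num_token_ids: int,
                                block_size: int,
                                num_lookahead_slots: int = 0) -> int:
        return cdiv(num_token_ids + num_lookahead_slots, block_size)


    assert get_num_required_blocks(30, 16) == 2
    assert get_num_required_blocks(30, 16, num_lookahead_slots=8) == 3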
7 changes: 6 additions & 1 deletion vllm/core/block_manager_v1.py
@@ -281,10 +281,15 @@ def __init__(
def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int:
return 0 if seq is None else seq.n_blocks

def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.

assert (num_lookahead_slots == 0
), "lookahead allocation not supported in BlockSpaceManagerV1"

check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)

self_num_required_blocks = self._get_seq_num_required_blocks(
5 changes: 4 additions & 1 deletion vllm/core/block_manager_v2.py
@@ -107,7 +107,9 @@ def __init__(
self._last_access_blocks_tracker = LastAccessBlocksTracker(
self.block_allocator)

def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.

@@ -117,6 +119,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
num_required_blocks = BlockTable.get_num_required_blocks(
seq.get_token_ids(),
block_size=self.block_size,
num_lookahead_slots=num_lookahead_slots,
)

if seq_group.is_encoder_decoder():
4 changes: 3 additions & 1 deletion vllm/core/embedding_model_block_manager.py
@@ -21,7 +21,9 @@
) -> None:
pass

def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
# Always return OK for dummy purposes
return AllocStatus.OK

4 changes: 3 additions & 1 deletion vllm/core/interfaces.py
@@ -44,7 +44,9 @@ def get_block_space_manager_class(version: str):
raise ValueError(f"Unknown version {version=}")

@abstractmethod
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
pass

@abstractmethod
Expand Down
Loading
Loading