4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py

@@ -280,8 +280,8 @@ def __init__(self, runner: "GPUModelRunner"):
         self.runner = runner
 
     def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
-        return False
+                      scheduler_output: "SchedulerOutput") -> None:
+        pass
 
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int):
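With the boolean return gone, reorder_batch becomes a pure side-effecting hook. As a rough illustration only (a hypothetical Protocol, not vLLM's actual class hierarchy), the contract every metadata builder now satisfies looks like this:

```python
from typing import Protocol


class SupportsReorderBatch(Protocol):
    """Hypothetical protocol sketching the reorder hook after this change."""

    def reorder_batch(self, input_batch, scheduler_output) -> None:
        """Reorder requests in input_batch in place; nothing is returned."""
        ...
```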
16 changes: 5 additions & 11 deletions vllm/v1/attention/backends/mla/common.py

@@ -378,7 +378,7 @@ def __init__(self,
         self.page_size = self.runner.block_size
 
     def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput") -> bool:
+                      scheduler_output: "SchedulerOutput") -> None:
         # We now want to reorder the batch so that the "decode" requests are at
         # the front and the "prefill" requests are at the back using the least
         # amount of swaps possible. (NOTE for now we loosely use "decode" to
@@ -415,20 +415,16 @@ def reorder_batch(self, input_batch: "InputBatch",
         # the above loop
         num_decodes = len(decodes)
         num_prefills = len(prefills)
-        first_prefill = 0
-        modified_batch = False
 
         for i in range(1, min(num_decodes, num_prefills) + 1):
             # If the decode is at the "back" of the batch, i, we can swap it
             # with the prefill closest to the front of the batch
-            if decodes[num_decodes - i] >= num_decodes:
-                input_batch.swap_states(prefills[first_prefill],
-                                        decodes[num_decodes - i])
-                first_prefill += 1
-                modified_batch = True
-            else:
+            decode_idx = decodes[num_decodes - i]
+            if decode_idx < num_decodes:
                 break
 
+            input_batch.swap_states(prefills[i - 1], decode_idx)
+
         # Save for next `build` call
         # TODO(lucas): this is a bit of a hack, we should probably have a
         # better way of doing this
@@ -437,8 +433,6 @@ def reorder_batch(self, input_batch: "InputBatch",
         self._num_decode_tokens = num_decode_tokens
         self._num_prefill_tokens = num_prefill_tokens
 
-        return modified_batch
-
     def _build_decode(self, input_positions: torch.Tensor,
                       block_table: torch.Tensor, seq_lens: torch.Tensor):
         return MLACommonDecodeMetadata(
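The simplified loop pairs the i-th decode from the back of the batch with the i-th prefill from the front, so the first_prefill cursor and the modified_batch flag are no longer needed. A toy sketch of that pairing, using a plain Python list in place of InputBatch (the swap here is just a list swap, standing in for swap_states):

```python
def reorder(order: list, decodes: list, prefills: list) -> None:
    """Move decode requests to the front using the same pairing as above."""
    num_decodes = len(decodes)
    num_prefills = len(prefills)
    for i in range(1, min(num_decodes, num_prefills) + 1):
        # Walk the decode indices from the back of the batch; once a decode
        # already sits inside the first num_decodes slots, all remaining
        # decodes do too, so we can stop.
        decode_idx = decodes[num_decodes - i]
        if decode_idx < num_decodes:
            break
        # Swap the misplaced decode with the i-th prefill from the front.
        order[prefills[i - 1]], order[decode_idx] = (order[decode_idx],
                                                     order[prefills[i - 1]])


batch = ["P0", "D0", "P1", "D1"]          # decodes at indices 1 and 3
reorder(batch, decodes=[1, 3], prefills=[0, 2])
print(batch)                              # ['D1', 'D0', 'P1', 'P0']
```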
14 changes: 6 additions & 8 deletions vllm/v1/worker/gpu_model_runner.py

@@ -459,6 +459,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         self.input_batch.condense(removed_req_indices)
 
         if batch_changed:
+            # Some attention backends (namely MLA) may want to separate
+            # requests based on if the attention computation will be
+            # compute-bound or memory-bound. This gives them a hook to do that.
+            self.attn_metadata_builder.reorder_batch(self.input_batch,
+                                                     scheduler_output)
+
             self.input_batch.refresh_sampling_metadata()
 
     def _prepare_inputs(
@@ -471,14 +477,6 @@ def _prepare_inputs(
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
 
-        # Some attention backends (namely MLA) may want to separate requests
-        # based on if the attention computation will be compute-bound or
-        # memory-bound. This gives them a hook to do that.
-        modified_batch = self.attn_metadata_builder.reorder_batch(
-            self.input_batch, scheduler_output)
-        if modified_batch:
-            self.input_batch.refresh_sampling_metadata()
-
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
         self.input_batch.block_table.commit(num_reqs)
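Since the hook now runs inside _update_states whenever the batch composition changes, sampling metadata is refreshed exactly once after any mutation, and _prepare_inputs no longer has to thread a modified_batch flag back from the builder. A condensed sketch of that flow, using a stand-in runner class rather than vLLM's full GPUModelRunner (the helper _apply_scheduler_updates is illustrative only):

```python
class ToyRunner:
    """Stand-in for GPUModelRunner; only the reorder/refresh flow is shown."""

    def __init__(self, input_batch, attn_metadata_builder):
        self.input_batch = input_batch
        self.attn_metadata_builder = attn_metadata_builder

    def _update_states(self, scheduler_output) -> None:
        # Hypothetical helper standing in for the add/remove/condense logic.
        batch_changed = self._apply_scheduler_updates(scheduler_output)
        if batch_changed:
            # Backend-specific reordering (e.g. MLA moving decodes to the
            # front) happens before sampling metadata is rebuilt, so the
            # refresh below always sees the final request order.
            self.attn_metadata_builder.reorder_batch(self.input_batch,
                                                     scheduler_output)
            self.input_batch.refresh_sampling_metadata()

    def _apply_scheduler_updates(self, scheduler_output) -> bool:
        # Placeholder: return True when requests were added, removed, or moved.
        return True
```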