Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions vllm/v1/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,9 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],

self.block_table_arange = torch.arange(max_num_pages_per_req,
dtype=torch.int32,
device=self.device)
device="cpu")

self.sliding_window = getattr(kv_cache_spec, "sliding_window", None)

def _get_workspace_buffer(self):
if self._workspace_buffer is None:
Expand Down Expand Up @@ -487,16 +489,30 @@ def build(self,
shared_kv_page_indices_cpu = None
shared_kv_last_page_len_cpu = None

max_num_blocks = block_table_bounds_cpu.max()
block_table_bounds = block_table_bounds_cpu.to(self.device,
non_blocking=True)
mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
< block_table_bounds.unsqueeze(1))
max_num_blocks = block_table_bounds_cpu.max().item()
arange = self.block_table_arange[:max_num_blocks].unsqueeze(0)
mask = arange < block_table_bounds_cpu.unsqueeze(1)
if (self.sliding_window is not None and not use_cascade
and num_decodes > 0 and
max_num_blocks > self.sliding_window // page_size):
# NOTE(woosuk): Since FlashInfer's decode kernel doesn't skip the kv
# outside the sliding window and only do masking, we manually
# manipulate the seq_lens and block table for skipping.
# NOTE: Don't apply this optimization to prefill requests.
decode_seq_lens_cpu = seq_lens_cpu[:num_decodes]
num_skipped_pages = (
torch.relu(decode_seq_lens_cpu - self.sliding_window) //
page_size)

block_table_bounds_cpu[:num_decodes] -= num_skipped_pages
mask[:num_decodes] &= (arange[:num_decodes]
>= num_skipped_pages.unsqueeze(1))
Comment on lines +502 to +509
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The sequence lengths (seq_lens and seq_lens_cpu) are not updated to reflect the truncated block tables for sliding window attention. This can lead to incorrect attention computation or out-of-bounds memory access in the FlashInfer kernel because the kernel will receive mismatched sequence lengths and KV blocks.

The sequence lengths should be reduced by the number of tokens in the skipped pages. Note that if common_attn_metadata is shared across layers, seq_lens and seq_lens_cpu should be cloned before modification to avoid side effects.

            decode_seq_lens_cpu = seq_lens_cpu[:num_decodes]
            num_skipped_pages = (
                torch.relu(decode_seq_lens_cpu - self.sliding_window) //
                page_size)

            skipped_tokens = num_skipped_pages * page_size
            seq_lens_cpu[:num_decodes] -= skipped_tokens
            seq_lens[:num_decodes] -= skipped_tokens.to(self.device,
                                                        non_blocking=True)

            block_table_bounds_cpu[:num_decodes] -= num_skipped_pages
            mask[:num_decodes] &= (arange[:num_decodes]
                                   >= num_skipped_pages.unsqueeze(1))

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. To my understanding, it is not necessary to update seq_lens_cpu in this case.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To my understanding, it is not necessary to update seq_lens_cpu in this case.

This is true for the "standard" FlashInfer code path, which infers seqlen from paged_kv_last_page_len and paged_kv_indptr (so manipulating block_table_bounds_cpu is sufficient). I'm not sure if this is true for the new trtllm code paths (i.e. trtllm_batch_decode_with_kv_cache) — maybe someone from Nvidia can weigh in, @kushanam @pavanimajety


# write self.paged_kv_indices inplace
num_actual_pages = torch.sum(mask)
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
torch.masked_select(block_table_tensor[:, :max_num_blocks],
mask,
mask.to(self.device, non_blocking=True),
out=paged_kv_indices)

# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
Expand Down