[Intel GPU] Fix incorrect KV-cache page table for local attention when page_size > 1 #23757
Code Review
This pull request introduces logic to normalize the page table within the _init_local_attn_metadata function of the XPU backend, ensuring it is at page-granularity when page_size is greater than one. Feedback suggests refactoring this logic into a shared helper method to eliminate duplication with similar code found in init_forward_metadata.
    if self.page_size > 1:
        strided_indices = torch.arange(
            0, page_table.shape[1], self.page_size, device=page_table.device
        )
        page_table = page_table[:, strided_indices] // self.page_size
The normalization logic added here (striding and floor-dividing the page table) is identical to the logic at lines 371-376 in init_forward_metadata. While necessary here because _init_local_attn_metadata is called before that main normalization block, it would be cleaner to encapsulate this logic into a helper method or move the main normalization earlier in init_forward_metadata to avoid duplication and ensure consistency across the backend.
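A minimal sketch of what such a shared helper could look like, written with plain Python lists instead of torch tensors so it runs without a GPU stack. The name `to_page_granularity` and the toy slot values are hypothetical and not taken from the PR. The sketch also assumes that KV slots inside a page are contiguous and page-aligned (slot = physical page * page_size + offset), which is what makes the stride-then-divide conversion valid.

```python
def to_page_granularity(token_table, page_size):
    """Hypothetical helper: convert a token-granularity table to page granularity.

    token_table[r][i] is the KV slot of token i in request r. The result keeps
    every page_size-th column and floor-divides it by page_size, the list
    equivalent of page_table[:, strided_indices] // page_size.
    """
    if page_size <= 1:
        # Token granularity and page granularity coincide; return a copy.
        return [list(row) for row in token_table]
    return [[slot // page_size for slot in row[::page_size]] for row in token_table]


# Toy data (assumed): one request, page_size = 4, whose 8 tokens live in
# physical pages 5 and 2, so the KV slots are 20..23 followed by 8..11.
page_size = 4
token_table = [[20, 21, 22, 23, 8, 9, 10, 11]]

page_table = to_page_granularity(token_table, page_size)

# Logical page 1 should resolve to physical page 2.
logical_page = 1
wrong = token_table[0][logical_page]  # raw table indexed by page: a KV slot, not a page
right = page_table[0][logical_page]   # normalized table: the physical page index
print(page_table, wrong, right)
```

Calling one helper from both `_init_local_attn_metadata` and `init_forward_metadata` would keep the two code paths from drifting apart, which is the consistency concern raised in the comment above.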
@sunjiweiswift please help review this one!
We'd better add a UT to cover radix-attention-level behavior, like this one:
@ckvermaAI could you please take this job?
@mingfeima sure, let me add a similar unit test.
Let me test this on more models.
@sunjiweiswift any updates?
Motivation
Fixes incorrect KV-cache page table values passed to `make_local_attention_virtual_batches` in the XPU (Intel GPU) attention backend when `page_size > 1` and local (chunked) attention is enabled. This bug caused incorrect/zeroed outputs from `flash_attn_with_kvcache`.
Root Cause
`make_local_attention_virtual_batches` expects a page-granularity block table where each column `p` stores the physical page index for logical page `p`. However, the raw `req_to_token` table is token-granularity (column `i` = KV slot for token `i`).
When `page_size > 1`, the un-strided token-granularity table was passed directly, causing `block_starts = k_seqstarts_absolute // page_size` to index incorrect physical page values.
Modifications
When `page_size > 1`, the page table is first converted to page-granularity by:
- Taking every `page_size`-th column (`torch.arange(0, ..., page_size)`)
- Dividing by `page_size` to convert KV slot indices to physical page indices
This mirrors the normalization already applied to `metadata.page_table` elsewhere in the backend.
Changes
`python/sglang/srt/layers/attention/xpu_backend.py`: Add page table stride+divide normalization before passing to `make_local_attention_virtual_batches` when `page_size > 1`
Accuracy Tests
GSM8k benchmark on XPU with `page_size > 1` and local attention enabled:
Speed Tests and Profiling
Checklist
Review and Merge Process
`/tag-and-rerun-ci`, `/tag-run-ci-label`, `/rerun-failed-ci`
CI States
Latest PR Test (Base): Run #26011291231
Latest PR Test (Extra): ⚠️ Not enabled; add the `run-ci-extra` label to opt in.