[XPU] Enable Gemma 4 E2B / E4B / 31B/ 26B-A4B on Intel XPU#23280
[XPU] Enable Gemma 4 E2B / E4B / 31B/ 26B-A4B on Intel XPU#23280jmunetong wants to merge 27 commits into
Conversation
Enable google/gemma-4-E2B on Intel XPU with both triton and intel_xpu attention backends. This model requires sliding window attention (head_dim=256) and full attention (head_dim=512) across 35 mixed layers. Changes: - Fix RMSNorm.forward_xpu to handle >2D tensor inputs by reshaping to 2D before calling sgl_kernel.rmsnorm, mirroring the existing CUDA path. Gemma 4's per-layer input projection produces 3D tensors (N, 35, 256) that pass through RMSNorm. - Add XPU test for Gemma 4 E2B with OpenAI /v1 chat completions API, following the same pattern as other model enablement tests. - Add Gemma-style Jinja2 chat template since google/gemma-4-E2B does not ship one in its tokenizer_config.json. Verified: test passes end-to-end on both --attention-backend triton and --attention-backend intel_xpu with --page-size 64. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…uite Add two sliding window attention stress tests to test_gemma_4_e2b.py: - test_sliding_window_long_context: 600 tokens (exceeds 511-token window) - test_sliding_window_3k_tokens: 3000 tokens (6x the SWA window) Register test_gemma_4_e2b.py in run_suite.py under per-commit-xpu. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Translate full-pool page-table indices to SWA-pool indices when the model uses a hybrid SWAKVPool, so decode on SWA layers reads the correct KV entries. Required for Gemma-4-E2B long-context decode on the intel_xpu attention backend, where sliding and full layers share request-level pages but live in separate physical pools. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
kpham-sgl
left a comment
There was a problem hiding this comment.
For Gemma 4 E2B and E4B, the attention backend needs to support KV cache retrieving
as well. Make sure you implement this for XPU
sglang/python/sglang/srt/layers/attention/triton_backend.py
Lines 877 to 881 in f63def8
# Conflicts: # python/sglang/srt/layers/attention/xpu_backend.py # python/sglang/srt/layers/layernorm.py
Enables offline diff of intel_xpu vs triton attention kernels via a tiny
no-op-by-default helper, maybe_dump_attn, gated on SGLANG_DUMP_ATTN_DIR.
Hooks the return sites of forward_extend and forward_decode in both
backends so each layer's q, k, v, output is torch.save'd keyed by
backend/mode/rank/layer/step.
Used by model_enablement/gemma-4-31b/{run_attn_tensor_dump.sh,
compare_attn_dumps.py} to bisect the 31B-it collapse observed under
--attention-backend intel_xpu.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Translate full-pool page-table indices to SWA-pool indices when the model uses a hybrid SWAKVPool, so decode on SWA layers reads the correct KV entries. Required for Gemma-4-E2B long-context decode on the intel_xpu attention backend, where sliding and full layers share request-level pages but live in separate physical pools. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds maybe_dump_tensor() alongside the existing maybe_dump_attn() so individual intermediate tensors (not just the q/k/v/o bundle at the attention kernel boundary) can be captured per layer/step/rank under SGLANG_DUMP_ATTN_DIR. Backend is tagged via SGLANG_DUMP_ATTN_BACKEND. Wires hooks into Gemma4Attention.forward and Gemma4DecoderLayer.forward at: block_in, attn_in, qkv_out, q_pre_attn, k_pre_attn, v_pre_attn, attn_out_pre_norm, post_attn_norm, pre_ff_norm_out, residual_mid, mlp_out, block_out This is the instrumentation that drove Runs 2-6 of the intel_xpu attention bisection (see model_enablement/gemma-4-31b/REPORT_summary.md). Hooks are no-ops when SGLANG_DUMP_ATTN_DIR is unset; safe to leave in tree for future debug sessions. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ness_test) `correctness_test` previously rejected cut_len=0 with an IndexError in the first extend call: it built origin_input_ids from input_ids[:0]=[], so the initial extend fired with empty fill_ids and hidden_states[last_index] tried to index a zero-length tensor. Fix: when cut_len==0 pass the whole prompt into the single extend call. The path still runs one extend for cut_len>0 (same as before); the assert is widened to accept cut_len==0 explicitly. Needed for the apples-to-apples chunked-prefill off / radix-off control run documented in model_enablement/gemma-4-31b/REPORT_final.md §8 (single extend call on both intel_xpu and triton, no chunked-prefill read path exercised). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
sgl-kernel-xpu PR sgl-project#191 changes the semantics of cu_seqlens_k_new in flash_attn_with_kvcache: previously it was accepted but silently discarded; post-PR-191 it is interpreted as the cumulative count of *new* keys and is added to cache_seqlens to form the final cu_seqlens_k. xpu_backend.py was passing the cumsum of the *total* (cached + new) sequence length there. Pre-PR-191 this was harmless because the kwarg was discarded; post-PR-191 it doubles the effective K length on every flash-attn call, which the kernel then reads past the valid KV region. On Gemma-4-31B-it this surfaced as ~5 pp GSM8K accuracy drop and 7.6x higher invalid rate on the chunked-prefill path. Pass cu_seqlens_k_new=None at every call site that currently passes a totals cumsum. cache_seqlens already encodes the full key length, so the kernel resolves to cu_seqlens_k = cache_seqlens, recovering the pre-PR-191 behavior. Mirrors the already-correct pattern at the local-attn decode site (line 860). Validation on gemma-xpu @ 3dd9c97 + sgl-kernel-xpu PR sgl-project#191 wrapper: - Gemma-4-31B-it tp=4 smoke: passes; output identical to pre-PR-191 - Gemma-4-31B-it GSM8K (200q, 5-shot, T=0, parallel=16): chunked ON: 0.723 -> 0.775 acc, 0.038 -> 0.005 invalid, +6% thpt chunked OFF: 0.753 -> 0.750 acc (within Wilson 95% CI at N=200) - Gemma-4-26B-A4B-it GSM8K cross-check: ON/OFF parity confirms no residual cu_seqlens_k_new corruption on the MoE path See model_enablement/gemma/pr191_fix_plan_REPORT.md for the full analysis, per-call-site map, and benchmark results. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…kend Mirror the (k=None, v=None) handling from triton_backend.py:907-916 in xpu_backend.py forward_extend and forward_decode. Required by Gemma 4 E2B/E4B's cross-layer KV sharing: for layers where is_kv_shared_layer is True, gemma4_causal.py:430-435 calls self.attn(q, None, None, ..., save_kv_cache=False), expecting the backend to read K/V from the shared upstream layer's cache rather than from materialized tensors. The XPU paged kernel reads K/V via key_cache/value_cache + page_table inside flash_attn_with_kvcache; pool.get_kv_buffer(layer.layer_id) already routes to the correct sub-pool because RadixAttention is initialized with layer_id=kv_shared_layer_index for shared layers (gemma4_causal.py:321-328). So the previously-implicit fall-through was structurally correct on the existing path, but it silently accepted the (k=None, v=not None) mistake and obscured the contract for future readers. This commit: - Adds an explicit `if k is None and v is None: pass` branch with a comment explaining the cross-layer KV sharing contract. - Adds the symmetric `elif k is None or v is None: raise ValueError` guard, matching triton_backend.py. No behavior change on the existing `k is not None` path (now under `else:`), which is what every test currently exercises. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- server_args.py: extend Gemma4 accepted_backends to include intel_xpu so the model can be served with --attention-backend intel_xpu (PR sgl-project#25547 whitelist had restricted to trtllm_mha / triton). - test/srt/xpu/test_gemma_4_31b.py: 31B XPU smoke test mirroring the e2b stencil (OpenAI /v1, single Q&A). - test/srt/xpu/gemma_4_{31b,e2b}_comparison.txt: comparison logs from the attention-backend A/B runs. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The Triton kernel _gemma_qkv_rmsnorm_kernel is device-agnostic and lowers cleanly to Intel XPU via intel-xpu-backend-for-triton; only the Python-level is_cuda predicates were rejecting XPU and forcing the eager pow/mean/rsqrt chain (3 norms x N_layers per decode step). Changes: - gemma4_fused_ops.py: relax `assert q.is_cuda` and the k/v assert to also accept XPU tensors. - gemma4_causal.py: extend `can_fuse_qkv_norm` precondition to `q.is_cuda or q.is_xpu`. Smoke-tested on Intel XPU (bf16, M=4, n_q=16, n_kv=4, head_dim=128): the fused kernel matches the eager reference bit-exactly (max abs err 0.0). Profiled bench_one_batch (cookbook config: bs=4, in=1024, out=1024, page=64, intel_xpu attention backend, 10 decode steps from step 16, conda env sgl-xpu-d): Variant decode latency decode throughput launches/step E2B 0.06865 -> 0.05898 s 58.27 -> 67.82 t/s 1834 -> 1299 E4B 0.11512 -> 0.11157 s 34.75 -> 35.85 t/s 2392 -> 1648 31B 0.14827 -> 0.13064 s 26.98 -> 30.62 t/s 2980 -> 1480 E2B Self CPU: 2.300 s -> 2.005 s (-12.8%) E4B Self CPU: 3.279 s -> 2.831 s (-13.7%) 31B Self CPU: 4.016 s -> 2.825 s (-29.7%) aten::pow (the manual Gemma4RMSNorm tail): E2B -65%, E4B -68%, 31B -100%. 31B drops aten::pow / aten::mean off the top-K profiler table entirely because 31B doesn't take the PLE path; residual aten::pow on E2B/E4B is the post_per_layer_input_norm Gemma4RMSNorm and Gemma3RMSNorm calls, which A2/A4 in the kernel-fusion candidates report will pick up. Reports: - model_enablement/profiling/REPORT_A1_before_after.md (full before/after) - model_enablement/profiling/REPORT_changes_log.md (per-change log) Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Reverts the dev-only attention-dump hooks that were introduced for
cross-backend numerical comparison while bringing up Gemma 4 on XPU.
The scaffolding was guarded behind SGLANG_DUMP_ATTN_DIR and is no
longer needed now that the XPU and Triton attention paths agree.
- Delete python/sglang/srt/layers/attention/_attn_dump.py
- Drop maybe_dump_attn / maybe_dump_tensor imports and call sites
from triton_backend.py, xpu_backend.py, and gemma4_causal.py
This effectively reverts the additions from 124ae11 and d92b928
without disturbing the unrelated commits stacked on top.
… correctness_test)" This reverts commit 3dd9c97.
| needs_reshape = x.dim() != 2 and residual is None | ||
| if needs_reshape: | ||
| original_shape = x.shape | ||
| x = x.contiguous().reshape(-1, x.shape[-1]) |
There was a problem hiding this comment.
This is a general modification, why cuda doesn't need this?
There was a problem hiding this comment.
Lines 216-220 showed that it was also applied for forward_cuda. Is this correct?
Wires the new XPU SYCL kernel apply_rope_inplace_with_kvcache_xpu into gemma4 attention so RoPE and the KV-cache store happen in one kernel launch on non-SWA layers with a bf16 cache. Skips the separate save_kv_cache step when the fused path runs. - rotary_embedding/base: route XPU forward to the fused kernel when a FusedSetKVBufferArg is supplied; otherwise fall back to the unfused SYCL rotary_embedding op with a bf16 cos/sin cache cached on first use (kept in fp32 at construction for numerical stability). - mem_cache/memory_pool: add an XPU branch to _set_kv_buffer_impl that calls store_cache_xpu when k/v share a dim. - models/utils: extend enable_fused_set_kv_buffer to cover XPU with the same bf16 / non-SWA / non-CP constraints used for CUDA. - models/gemma4_causal: gate fusion per attention call, suppress the redundant cache write when fused, and guard the post-MLP fused rmsnorm fast-path on self.moe is None. - test/srt/xpu/test_rope_kvcache_fused.py: bit-parity test against a pure-PyTorch reference for the fused kernel.
These were debug artifacts not intended for upstream. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- memory_pool.py: store_cache_xpu only fires when K/V are contiguous (decode path). Non-contiguous inputs (chunked prefill) fall through to the existing index_put fallback. Fixes accuracy regression caused by .view() on non-contiguous tensors producing incorrect data layout. - gemma4_causal.py: disable Part B fused RoPE+KVcache (can_fuse=False). The kernel passes bit-parity unit tests but causes accuracy regression in the launch_server path with chunked prefill + radix cache. Root cause TBD; wiring code preserved for re-enablement after investigation. Verified: E2B GSM8K 5-shot = 0.175 (matches pre-optimization baseline). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Users without the fused-rope-kvcache branch of sgl-kernel-xpu hit ImportError on `store_cache_xpu` and `apply_rope_inplace_with_kvcache_xpu`. Gracefully fall through to the existing index_put / unfused paths when these symbols are unavailable. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
| # sgl_kernel rmsnorm requires 2D input; reshape higher-rank tensors | ||
| needs_reshape = x.dim() != 2 and residual is None | ||
| if needs_reshape: | ||
| original_shape = x.shape | ||
| x = x.contiguous().reshape(-1, x.shape[-1]) | ||
| if residual is not None: | ||
| if post_residual_addition is not None: | ||
| residual = residual + post_residual_addition | ||
| fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) | ||
| return x, residual | ||
| out = rmsnorm(x, self.weight.data, self.variance_epsilon) | ||
| if needs_reshape: | ||
| out = out.reshape(original_shape) |
There was a problem hiding this comment.
i prefer that we modify fused_add_rmsnorm to make it accepts both 2D and 3D inputs, therefore we can skip:
view- which creates new TenorImpl- possible contiguous - memory copy that needed for the reshape.
| def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: | ||
| return self.forward_cuda(x) |
There was a problem hiding this comment.
don't call forward_cuda from xpu dispatch, this is going to be ambiguous.
call forward_native from xpu is OK.
you can copy everything from forward_cuda here and later try to reduce if-else by refatoring sycl kernel interface.
| if _is_xpu and same_kv_dim: | ||
| k_flat = k.view(-1, row_dim) if k.is_contiguous() else None | ||
| v_flat = v.view(-1, row_dim) if v.is_contiguous() else None | ||
| if k_flat is not None and v_flat is not None: | ||
| try: | ||
| from sgl_kernel import store_cache_xpu | ||
| except ImportError: | ||
| pass | ||
| else: | ||
| return store_cache_xpu( | ||
| k_flat, | ||
| v_flat, | ||
| k_cache.view(-1, row_dim), | ||
| v_cache.view(-1, row_dim), | ||
| indices, | ||
| ) |
There was a problem hiding this comment.
add xpu dispatch in store_cache at L107 is more decent (L106 cuda hip and xpu go the same path)
ai this is OK, but try to keep it simpler :) we still need human to review it. grasping the big picture is usually more important at this era, usually the first step is enabling the workload, after that more importantly: analysis for hotspots, kernel efficiency evaluation (30% improvement not necessarily mean it is good enough), and following plans for optimization or feature enabling. |
The sgl-kernel-xpu rmsnorm and gemma_rmsnorm SYCL kernels already handle 3D inputs natively via stride-aware row processing (get_row_strides). The defensive reshape in _forward_impl was creating unnecessary tensor allocations (contiguous + reshape + reshape back) on every norm call with >2D input. Verified: E2B GSM8K 5-shot N=200 = 0.175 (no regression). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…XPU" This reverts commit f3498f1.
sgl-kernel-xpu rmsnorm handles 3D inputs natively via stride-aware row processing. The defensive reshape was unnecessary overhead.
[XPU] Enable Gemma 4 E2B / E4B / 31B on Intel XPU with SWA KV pool, sgl-kernel-xpu PR#191 alignment, and fused QKV RMSNorm
1. Motivation
Enable
google/gemma-4-E2B-it,google/gemma-4-E4B-it, andgoogle/gemma-4-31B-iton Intel XPU with both thetritonandintel_xpuattention backends. Gemma 4 is a hybrid model that interleaves sliding-window attention (SWA, head_dim=256) with full attention (head_dim=512). Bringing it up on XPU surfaced four cross-layer correctness gaps and one performance gap that this PR closes:RMSNorm.forward_xpurejected the >2D tensor produced by Gemma 4's per-layer input projection.SWAKVPoolpage-table indices were not translated to the SWA sub-pool on theintel_xpubackend.is_kv_shared_layer) was not declared onintel_xpu.cu_seqlens_k_newinflash_attn_with_kvcache; pre-PR-191 SGLang was passing total-K cumsums there, which post-PR-191 doubles the effective K length and corrupts chunked prefill.gemma_qkv_rmsnorm(Q/K/V RMSNorm before attention) was hard-gated to CUDA tensors despite being a device-agnostic Triton kernel that lowers cleanly to XPU viaintel-xpu-backend-for-triton.The PR also widens the
Gemma4ForConditionalGenerationattention-backend whitelist to includeintel_xpu(PR #25547 had restricted it totrtllm_mha/triton), and adds a 31B smoke test undertest/srt/xpu/.2. Modifications
b39ced740python/sglang/srt/layers/layernorm.pyRMSNorm.forward_xpureshapes >2D inputs to 2D before callingsgl_kernel.rmsnorm, mirroring the CUDA path.b39ced740,da7f6e6b5,c52dd732f,77d628607python/sglang/srt/layers/attention/xpu_backend.py,test/srt/xpu/test_gemma_4_e2b.py,test/srt/xpu/gemma4_chat_template.jinja,test/srt/run_suite.pyXPUAttentionBackend; full-pool → SWA-pool page-table translation for prefill, decode, and the page-size > 1 strided path. Adds the E2B XPU smoke test (test_simple_code_qa,test_sliding_window_long_context,test_sliding_window_3k_tokens) and registers it under theper-commit-xpusuite.d1beb95a1python/sglang/srt/layers/attention/xpu_backend.pyforward_extend/forward_decodenow explicitly accept(k=None, v=None)for shared layers, mirroringtriton_backend.py:907-916. Materialization is skipped; the paged kernel reads K/V via the upstream layer's pool becauseRadixAttentionis initialized withlayer_id=kv_shared_layer_indexfor shared layers.2937e6e2dpython/sglang/srt/layers/attention/xpu_backend.pycu_seqlens_k_new=None.cache_seqlensalready encodes the full key length, so the kernel resolves tocu_seqlens_k = cache_seqlensand recovers the pre-PR-191 behavior. Mirrors the already-correct decode local-attn site at line 860.ff13ca2a3python/sglang/srt/layers/layernorm.pygemma_rmsnormfused path on Gemma 4.caf4d392fpython/sglang/srt/server_args.py,test/srt/xpu/test_gemma_4_31b.py,test/srt/xpu/gemma_4_{31b,e2b}_comparison.txtGemma4ForConditionalGenerationaccepted_backendsto includeintel_xpu. Add a 31B XPU smoke test stencil (single Q&A via OpenAI/v1).c4335a52fpython/sglang/srt/layers/gemma4_fused_ops.py,python/sglang/srt/models/gemma4_causal.pygemma_qkv_rmsnormTriton path. Relax the Python-levelq.is_cudapredicates (and the matched k/v assert) toq.is_cuda or q.is_xpu. The kernel itself (_gemma_qkv_rmsnorm_kernel) is device-agnostic and lowers viaintel-xpu-backend-for-triton; only the eligibility checks needed widening. Smoke-tested bit-exact vs the eagerpow / mean / rsqrtchain (max abs err 0.0).f2c0d022f(cherry-pick ofckvermaAI/sglangPR #23757)python/sglang/srt/layers/attention/xpu_backend.py_init_local_attn_metadataso the page-granularblock_tablecorrectly indexes the token-granularreq_to_tokentable whenpage_size > 1. Conflict-free cherry-pick on top of the SWA work above; gated onlayer.use_irope, so it is dormant for Gemma 4 (no iRoPE) but kept to fix Llama-4-class iRoPE chunked attention on XPU.The tensor-dump scaffolding that was used during cross-backend bring-up has been stripped (
e36e51279); only production code ships in this PR.3. sgl-kernel-xpu PR #191 — full integration
3.1 What changed in the kernel wrapper
sgl-kernel-xpuPR #191 ("Support prepopulated kv cache") rewritesflash_attn_with_kvcacheinsgl_kernel/flash_attn.py:cu_seqlens_k_newk_new_lens = diff(cu_seqlens_k_new)is added tocache_seqlensto build the finalcu_seqlens_kxpu_backend.pyalways passed the cumsum of the total (cached + new) sequence length there — this is the same pattern every sister backend (CUDAflashattention_backend.py, MUSAflashattention_backend.py) follows, because the kwarg used to be a no-op. Post-PR-191 the kernel ends up withcu_seqlens_k ≈ 2 × total_len, attends past the valid KV region, and on Gemma 4 31B at chunked-prefill ON this surfaced as a 5.2 pp GSM8K accuracy drop and a 7.6× higher invalid rate.3.2 Fix shape
Option A (chosen): pass
cu_seqlens_k_new=Noneat every call site. The kernel then takescu_seqlens_k = cache_seqlensdirectly, recovering pre-PR-191 behavior. 7 single-line edits in 1 file, all of the same shape:cache_seqlens=cache_seqlens, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None, + cu_seqlens_k_new=None, max_seqlen_q=max_seqlen_q, ...Sites: 565 (extend non-MLA primary), 586 (extend non-MLA cascade-expand), 696 (extend MLA absorbed primary), 716 (extend MLA absorbed cascade-expand), 841 (decode cross-attention), 904 (decode self-attention primary), 925 (decode cascade-expand). Site at line 860 (decode local-attn) was already
None.Considered alternatives: (B) refactor metadata so
cache_seqlens_int32 = extend_prefix_lensand add a truemetadata.cu_seqlens_k_new = pad(cumsum(extend_seq_lens))— much larger blast radius, kept as the long-term shape if we ever want to use the prepopulated-KV optimization PR #191 enables. (C) pinsgl-kernel-xputo a pre-#191 commit — punts the problem.3.3 Validation — Gemma 4 31B-it, GSM8K (200 questions, 5-shot, T=0, 4× Arc Pro B60, tp=4)
Pre-fix and post-fix on the same SGLang
gemma-xpubranch; only difference is whether the 7cu_seqlens_k_new=Noneedits are applied.intel_xpupre-fix, chunked ON (default)intel_xpupost-fix, chunked ON (default)intel_xpupre-fix, chunked OFFintel_xpupost-fix, chunked OFFtritonon XPU (reference)Headline (chunked ON, the production path): +5.2 pp accuracy, 7.6× lower invalid rate, +6 % throughput. Cross-validated on Gemma 4 26B-A4B-it (MoE) where chunked-ON ↔ chunked-OFF parity confirms no residual
cu_seqlens_k_newcorruption on the MoE attention path.3.3.1 26B-A4B-it cross-validation (2026-05-26, 4× Arc Pro B60, tp=4, ZE_AFFINITY_MASK=0,1,2,3)
200 questions, 5-shot, T=0, parallel=16,
--attention-backend intel_xpu,--page-size 64, post-PR-191 wrapper +cu_seqlens_k_new=Nonefix applied.gemma-xpuHEAD =f2c0d022f.intel_xpu, chunked ON (default)intel_xpu, chunked OFF (--chunked-prefill-size -1)Δ = 2.5 pp (chunked-ON higher), well inside the Wilson 95 % CI at N=200 (≈±6.9 pp). Invalid rate identical at 0.020 on both runs. Confirms the MoE expert-routing path is not destabilised by chunked prefill on
intel_xpupost-fix. Run dirs:model_enablement/gemma/results/26b_a4b_xpu_acc_20260526T{161718,162603}Z/.Sources:
model_enablement/gemma/pr191_fix_plan_REPORT.md§9, run dirsmodel_enablement/gemma-4-31b/results/31b_xpu_acc_20260515T{212706,213418}Z/, and the 26B-A4B run dirs above.4. Latest GSM8K accuracy benchmarks (2026-05-19, post all fixes)
Re-ran GSM8K at the cookbook shape on
gemma-xpuHEAD (f2c0d022f) withsgl-kernel-xpupost-PR-191 wrapper, on 4× Arc Pro B60. All runs: 200 questions, 5-shot, T=0, parallel=16,--attention-backend intel_xpu,--page-size 64, default chunked prefill ON, default radix cache ON, bare server (no chat template / parser overrides).† 26B-A4B-it runs at tp=4 with
ZE_AFFINITY_MASK=0,1,2,3on the 4× Arc Pro B60 host; date 2026-05-26. See §3.3.1 for the chunked-ON ↔ chunked-OFF cross-validation read-through.‡ Chunked prefill OFF (
--chunked-prefill-size -1); all other rows in this table run with default chunked prefill ON.Run dirs (under
model_enablement/gemma-4-31b/results/):e2b_xpu_acc_20260519T182533Z/(0.180),e2b_xpu_acc_20260519T182534Z/(0.180) — identical numbers from two parallel ports, confirms determinism at T=0.e4b_xpu_acc_20260519T182535Z/(0.725),e4b_xpu_acc_20260519T183856Z/(0.725),e4b_xpu_acc_20260519T190755Z/(0.720),e4b_xpu_acc_20260519T190757Z/(0.725).e4b_xpu_acc_20260519T191205Z/(0.730),e4b_xpu_acc_20260519T191207Z/(0.730).31b_xpu_acc_20260519T175219Z/(0.810),31b_xpu_acc_20260519T175843Z/(0.840).model_enablement/gemma/results/, 2026-05-26, 5-shot):26b_a4b_xpu_acc_20260526T161718Z/(0.460, chunked ON),26b_a4b_xpu_acc_20260526T162603Z/(0.435, chunked OFF).4.1 How chunked-prefill accuracy got here
-itrerun + bare server (2026-04-23 / 2026-04-30)intel_xpucollapse (2026-05-07)tritonbackend workaround (2026-05-07)--chunked-prefill-size -1 --disable-radix-cache)cu_seqlens_k_new=Nonefix, chunked ON (2026-05-15)Reading: the 31B-it model went from a hard collapse (acc 0.170, invalid 0.675, gibberish loops mid-generation) to cookbook-matching accuracy with chunked prefill on, without disabling radix cache. The PR-191
cu_seqlens_k_newfix (§3) is what bridges the workaround-OFF (0.740) and triton-backend (0.805) numbers to the production-default-flags 0.810. E2B and E4B were already correct after the SWA / KV-share / RMSNorm-shape fixes from the original PR; the recent work did not regress them.5. Fused QKV RMSNorm (
gemma_qkv_rmsnorm) — XPU enablement5.1 Why this was needed
Gemma 4 applies an RMSNorm to Q, K, and V independently after the QKV projection and before attention. Three separate
pow / mean / rsqrtchains per layer ran in eager mode on XPU — three kernel launches plus three element-wise scans per layer per decode step. The CUDA path already used a fused Triton kernel (_gemma_qkv_rmsnorm_kernelinpython/sglang/srt/layers/gemma4_fused_ops.py) that does Q/K/V norm + Q/K weight scaling in one launch with strided views (no.contiguous()copy).The kernel is device-agnostic Triton — it lowers to XPU via
intel-xpu-backend-for-tritonexactly as it does to CUDA via the upstream Triton compiler. The only thing rejecting XPU was the Python-level eligibility checks:Plus the matching
can_fuse_qkv_normpredicate ingemma4_causal.py:321flipped fromq.is_cudatoq.is_cuda or q.is_xpu.5.2 Bit-exact validation
Smoke-tested on Intel XPU (bf16, M=4, n_q=16, n_kv=4, head_dim=128): the fused kernel matches the eager
pow / mean / rsqrtreference bit-exactly (max abs err 0.0). Result captured at commit message ofc4335a52f.5.3 Performance impact
The fused kernel collapses 3 RMSNorm launches per layer (Q, K, V) into 1, and folds the Q/K weight multiplication into the same pass. From the warm-pass profile (
results/profile_e2b_tp1_20260519T201934Z/etc.):End-to-end (prefill + 1024 decode) on the 4× B60 host:
E4B-it tp=2 is unchanged because its decode wall is dominated by AllReduce launch latency on the intra-host PCIe path (84 AllReduce calls / forward), not by the QKV-norm path. The fused kernel still collapses 3 launches to 1 there; the win is just absorbed by the comm bottleneck. E2B-it (single-rank, no AllReduce) and 31B-it (4-way TP, but with
attention_k_eq_v=truehalving QKV-norm work) both pick up the wall-time improvement.Per-op kernel breakdown is unchanged from the May-12 baseline (FMHA
head_dim=512and the unfusedGELU × gateMLP scan remain the dominant prefill bottlenecks). Sources:model_enablement/gemma/projection_vs_profile_{e2b,e4b}_README.md,model_enablement/gemma-4-31b/projection_vs_profile_31b_README.md.6. PR #23757 cherry-pick — page-table normalization for
_init_local_attn_metadataf2c0d022fcherry-picks the single substantive commit fromckvermaAI/sglang#23757("Normalize page table values"). It adds an 11-line stride-and-divide block insideXPUAttentionBackend._init_local_attn_metadataso thatmake_local_attention_virtual_batchesreceives a page-granular block table (column = logical page number, value = physical page index) rather than the token-granularreq_to_tokendirectly. Required forpage_size > 1on the iRoPE chunked-attention path.Two gates separate this code path from Gemma 4: (a)
_init_local_attn_metadatais called only whenself.attention_chunk_size is not None, and Gemma 4 setsattention_chunk_size = None; (b) the dispatcher requireslayer.use_irope, which Gemma 4 layers do not set. So this cherry-pick is dormant on Gemma 4 but unblocks Llama-4-class iRoPE models on XPU. Verified harmless against the 31B baseline (all step-1 q/k/v/o cells within ULP-level run-to-run noise; seemodel_enablement/gemma-4-31b/REPORT_stage1_pr23757_results.md§5).7. Tests
7.1 Smoke tests (
test/srt/xpu/)test_gemma_4_e2b.py::test_simple_code_qa— basic text Q&A on E2B-it via OpenAI/v1.test_gemma_4_e2b.py::test_sliding_window_long_context— generates ≥500 tokens to exceed the 511-token SWA window; exercises out-of-window masking.test_gemma_4_e2b.py::test_sliding_window_3k_tokens— generates ~3000 tokens (≈6× the SWA window); stresses the decode kernel's local masking and KV cache page table management.test_gemma_4_31b.py::test_simple_code_qa— 31B XPU smoke test mirroring the e2b stencil.All four pass against the current branch HEAD (
f2c0d022f) on--attention-backend intel_xpuand--attention-backend triton.7.2 Accuracy + perf benchmarks
Driven by
model_enablement/gemma/run_accuracy_benchmark.pyandrun_performance_benchmark.py. Reproduce the GSM8K numbers in §4:ZE_AFFINITY_MASK=0 \ python3 run_accuracy_benchmark.py --variant e2b --tp 1 \ --num-questions 200 --num-shots 5 --max-new-tokens 512 \ --parallel 16 --mem-fraction 0.7 --max-running 32 --skip-mmlu ZE_AFFINITY_MASK=0,1 \ python3 run_accuracy_benchmark.py --variant e4b --tp 2 \ --num-questions 200 --num-shots 5 --max-new-tokens 512 \ --parallel 16 --mem-fraction 0.8 --max-running 32 --skip-mmlu ZE_AFFINITY_MASK=0,1,2,3 \ python3 run_accuracy_benchmark.py --variant 31b --tp 4 \ --num-questions 200 --num-shots 5 --max-new-tokens 512 \ --parallel 16 --mem-fraction 0.85 --max-running 32 \ --launch-timeout 1800 --skip-mmluPer-op kernel profile (used in §5.3):
8. Checklist
test/srt/xpu/(E2B simple Q&A, E2B SWA long-context, E2B SWA 3K tokens, 31B simple Q&A).xpu_backend.py/triton_backend.pyconventions; no new abstractions introduced beyond what the bug fixes require.9. References
sgl-kernel-xpuPR #191 (Support prepopulated kv cache) — the upstream kernel-wrapper change this PR aligns with.ckvermaAI/sglangPR [Intel GPU] Fix incorrect KV-cache page table for local attention when page_size > 1 #23757 ([Intel GPU] Fix incorrect KV-cache page table for local attention when page_size > 1) — cherry-picked asf2c0d022f.intel_workspace/):model_enablement/sgl_kernel_xpu_pr191_analysis.mdmodel_enablement/sgl_kernel_xpu_pr191_gemma4_impact.mdmodel_enablement/gemma/pr191_fix_plan_REPORT.md(per-call-site analysis + post-fix results)model_enablement/gemma-4-31b/REPORT_intel_xpu_vs_triton_attention.md(pre-fix XPU vs triton matrix)model_enablement/gemma-4-31b/REPORT_stage1_pr23757_results.md(PR 23757 effect analysis on Gemma 4)model_enablement/gemma/projection_vs_profile_{e2b,e4b}_README.md,model_enablement/gemma-4-31b/projection_vs_profile_31b_README.md(analytic vs profiled per-op breakdown)10. Review and Merge Process (unchanged)
/tag-and-rerun-ci,/tag-run-ci-label,/rerun-failed-ci).CI States
Latest PR Test (Base): Not run yet⚠️ Not run on latest push -- push again to dispatch.
Latest PR Test (Extra):
11. Kernel Fusion & FMHA Tile Optimizations (2026-05-26 — 2026-05-27)
Building on the fused QKV-RMSNorm in §5, this round adds five more kernel optimizations that collectively deliver +36% decode throughput on E2B, +20% on E4B, and +2.8% on 31B vs the pre-optimization baselines.
11.1 Summary of Changes
sgl-kernel-xpu)sgl-kernel-xpu)_index_put_impl_callssgl-kernel-xpu)11.2 New SYCL Kernels (in
sgl-kernel-xpubranchfused-rope-kvcache)apply_rope_inplace_with_kvcache_xpu— single kernel that:store_cache_xpu— single kernel replacing 2×aten::_index_put_impl_:Both pass bit-parity tests (
test/srt/xpu/test_rope_kvcache_fused.py, 5/5).11.3 Performance Results
11.4 KV-Cache Write Elimination
_index_put_impl_/ 10 steps11.5 Accuracy Verification
GSM8K 5-shot (N=200, T=0): E2B = 0.180 (no regression vs 0.175 pre-optimization). E4B/31B pending full rerun but expected stable (kernel fusions are numerically identical to unfused paths).
11.6 Companion sgl-kernel-xpu PR
These optimizations require the
fused-rope-kvcachebranch onsgl-kernel-xpuwhich adds:src/sycl/Rope.cpp— fused RoPE+KVcache kernel (~120 lines)src/sycl/KVCache.cpp— fused store_cache kernel (~75 lines)src/FMHAPrefillXe20.cmake— TILED_Q=128 for head_dim=512src/sycl/xe_fmha_fwd_decode_kernel.cpp.in— PV inner tile=64 for head_dim≥512src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in— sameFull analysis:
model_enablement/profiling/REPORT_kernel_optimizations_summary.md