Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
b39ced7
[XPU] Add Gemma 4 E2B model enablement for intel_xpu attention backend
jmunetong Apr 15, 2026
da7f6e6
[XPU] Add SWA long-context tests for Gemma 4 E2B and register in CI s…
jmunetong Apr 16, 2026
c52dd73
[XPU] Wire SWA KV pool translation into intel_xpu attention backend
jmunetong Apr 20, 2026
139c457
Merge branch 'main' into gemma-xpu
jmunetong May 12, 2026
124ae11
attention: add per-layer Q/K/V/O dump hooks for backend comparison
jmunetong May 7, 2026
77d6286
[XPU] Wire SWA KV pool translation into intel_xpu attention backend
jmunetong Apr 20, 2026
d92b928
attention: extend _attn_dump with per-tensor hooks, wire Gemma-4 decoder
jmunetong May 8, 2026
3dd9c97
bench_one_batch: allow --cut-len 0 (single-extend prefill for correct…
jmunetong May 11, 2026
2937e6e
[XPU] Fix cu_seqlens_k_new mis-passing in flash_attn_with_kvcache calls
jmunetong May 15, 2026
2897b90
Merge branch 'sgl-project:main' into gemma-xpu
jmunetong May 15, 2026
d1beb95
[XPU] Support cross-layer KV-cache sharing in intel_xpu attention bac…
jmunetong May 15, 2026
de0e81e
Merge remote-tracking branch 'upstream/main' into gemma-xpu
jmunetong May 18, 2026
ff13ca2
enable sgl-kernel norm for gemma4
airMeng May 18, 2026
caf4d39
[XPU] Allow intel_xpu attention backend for Gemma 4 + add 31B smoke test
jmunetong May 18, 2026
c4335a5
[XPU] Admit XPU into the fused gemma_qkv_rmsnorm path
jmunetong May 19, 2026
e36e512
[XPU] Remove attention tensor-dump scaffolding from gemma-xpu
jmunetong May 20, 2026
5622919
Revert "bench_one_batch: allow --cut-len 0 (single-extend prefill for…
jmunetong May 20, 2026
f2c0d02
Normalize page table values
ckvermaAI Apr 26, 2026
37647ce
Merge branch 'main' into gemma-xpu
jmunetong May 20, 2026
d78c386
enable fusion of post rounting
airMeng May 26, 2026
fd98b73
gemma4-xpu: fuse RoPE + KV-cache write for non-SWA layers
jmunetong May 27, 2026
6818544
test/srt/xpu: remove comparison .txt files from PR
jmunetong May 27, 2026
3a79f23
store_cache_xpu: add contiguity guard; disable Part B fused RoPE+KV
jmunetong May 27, 2026
f8f854e
guard sgl-kernel-xpu imports with try/except for backwards compat
jmunetong May 27, 2026
f3498f1
layernorm: remove 3D→2D reshape for rmsnorm/gemma_rmsnorm on XPU
jmunetong May 27, 2026
1ed9d9f
Revert "layernorm: remove 3D→2D reshape for rmsnorm/gemma_rmsnorm on …
jmunetong May 27, 2026
117c407
layernorm: remove reshape in RMSNorm.forward_xpu
jmunetong May 27, 2026
9c6a8c3
memory_pool: revert to upstream main (moved to separate PR #26594)
jmunetong May 28, 2026
78a0001
Merge upstream/main into gemma-xpu
jmunetong May 28, 2026
bcf09b6
layernorm: Gemma4RMSNorm.forward_xpu self-contained, no forward_cuda …
jmunetong May 28, 2026
cad22db
rotary_embedding: revert to upstream main (moved to separate PR #26595)
jmunetong May 28, 2026
62b184a
test: remove test_gemma_4_31b and test_rope_kvcache_fused from XPU suite
jmunetong May 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 82 additions & 13 deletions python/sglang/srt/layers/attention/xpu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
prepare_swa_spec_page_table_triton,
)
from sglang.srt.managers.schedule_batch import get_global_server_args
from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
Expand Down Expand Up @@ -72,6 +73,12 @@ def __init__(
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
self.skip_prefill = skip_prefill
self.is_hybrid_swa = model_runner.is_hybrid_swa
self.use_sliding_window_kv_pool = (
isinstance(model_runner.token_to_kv_pool, SWAKVPool)
and model_runner.token_to_kv_pool.swa_layer_nums > 0
)
if self.use_sliding_window_kv_pool:
self.token_to_kv_pool = model_runner.token_to_kv_pool
if self.is_hybrid_swa:
self.full_to_swa_index_mapping = (
model_runner.token_to_kv_pool.full_to_swa_index_mapping
Expand Down Expand Up @@ -193,6 +200,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
metadata.page_table = self.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]

# TODO: we need to test this part for llama 4 eagle case
self._init_local_attn_metadata(forward_batch, metadata, device)
elif forward_batch.forward_mode.is_target_verify():
Expand Down Expand Up @@ -373,6 +381,14 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
),
]

# Translate full-pool indices to SWA-pool indices for hybrid models
if self.use_sliding_window_kv_pool:
metadata.swa_page_table = (
self.token_to_kv_pool.translate_loc_from_full_to_swa(
metadata.page_table
)
)

if self.use_mla:
workspace_size = flash_mla_get_workspace_size(
max_seq_len=self.max_context_len,
Expand All @@ -389,11 +405,25 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
workspace_size, device=self.device, dtype=torch.uint8
)

# Translate full-pool indices to SWA-pool indices for hybrid models
if self.use_sliding_window_kv_pool:
metadata.swa_page_table = (
self.token_to_kv_pool.translate_loc_from_full_to_swa(
metadata.page_table
)
)

# Convert the page table to a strided format which is needed by FA3 API
if self.page_size > 1:
self.strided_indices = torch.arange(
0, metadata.page_table.shape[1], self.page_size, device=self.device
)

if self.use_sliding_window_kv_pool and metadata.swa_page_table is not None:
metadata.swa_page_table = (
metadata.swa_page_table[:, self.strided_indices] // self.page_size
)

metadata.page_table = (
metadata.page_table[:, self.strided_indices] // self.page_size
)
Expand All @@ -413,8 +443,17 @@ def forward_extend(
k_rope: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
):
if k is not None:
assert v is not None
if k is None and v is None:
# Cross-layer KV sharing (Gemma 4): the layer reuses another
# layer's KV cache. The paged kernel reads K/V directly via
# page_table, and pool.get_kv_buffer(layer.layer_id) routes
# to the correct sub-pool because RadixAttention is initialized
# with layer_id=kv_shared_layer_index for shared layers. No
# materialization needed; just skip the write path.
pass
elif k is None or v is None:
raise ValueError("Both k and v should be None or not None")
else:
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
Expand Down Expand Up @@ -497,6 +536,13 @@ def forward_extend(
cu_seqlens_k = swa_spec_metadata.cu_seqlens_k
else:
page_table = metadata.page_table
if is_hybrid_swa and self.use_sliding_window_kv_pool:
if metadata.swa_page_table is not None:
page_table = metadata.swa_page_table
else:
page_table = self.token_to_kv_pool.translate_loc_from_full_to_swa(
metadata.page_table
)
cu_seqlens_q = metadata.cu_seqlens_q
cache_seqlens = metadata.cache_seqlens_int32
max_seqlen_q = metadata.max_seq_len_q
Expand Down Expand Up @@ -525,7 +571,7 @@ def forward_extend(
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
cu_seqlens_k_new=None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
Expand All @@ -546,7 +592,7 @@ def forward_extend(
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
cu_seqlens_k_new=None,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
Expand Down Expand Up @@ -648,7 +694,7 @@ def forward_extend(
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
cu_seqlens_k_new=None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
Expand All @@ -668,7 +714,7 @@ def forward_extend(
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
cu_seqlens_k_new=None,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
Expand All @@ -688,7 +734,8 @@ def forward_extend(
else:
o = result

return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
out = o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
return out

def forward_decode(
self,
Expand All @@ -703,8 +750,12 @@ def forward_decode(
k_rope: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if k is not None:
assert v is not None
if k is None and v is None:
# Cross-layer KV sharing (Gemma 4): see forward_extend for details.
pass
elif k is None or v is None:
raise ValueError("Both k and v should be None or not None")
else:
if save_kv_cache:
cache_loc = (
forward_batch.out_cache_loc
Expand Down Expand Up @@ -787,7 +838,7 @@ def forward_decode(
page_table=metadata.encoder_page_table,
cache_seqlens=metadata.encoder_lens_int32,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
cu_seqlens_k_new=None,
max_seqlen_q=1,
softmax_scale=layer.scaling,
causal=False,
Expand Down Expand Up @@ -817,7 +868,24 @@ def forward_decode(
**kwargs,
)
else:
is_swa_layer = (
layer.sliding_window_size is not None
and layer.sliding_window_size > -1
)

page_table = metadata.page_table
# For SWA layers on hybrid models, use the translated
# SWA-pool page table so KV reads hit the correct pool.
if is_swa_layer and self.use_sliding_window_kv_pool:
if metadata.swa_page_table is not None:
page_table = metadata.swa_page_table
else:
page_table = (
self.token_to_kv_pool.translate_loc_from_full_to_swa(
metadata.page_table
)
)

cache_seqlens = metadata.cache_seqlens_int32
cu_seqlens_k = metadata.cu_seqlens_k
max_seqlen_q = metadata.max_seq_len_q
Expand All @@ -833,7 +901,7 @@ def forward_decode(
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k,
cu_seqlens_k_new=None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
Expand All @@ -854,7 +922,7 @@ def forward_decode(
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
cu_seqlens_k_new=None,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
Expand Down Expand Up @@ -899,7 +967,8 @@ def forward_decode(
layer.scaling,
)

return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
out = o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
return out

def get_cuda_graph_seq_len_fill_value(self):
"""Get the fill value for sequence length in CUDA graph."""
Expand Down
73 changes: 71 additions & 2 deletions python/sglang/srt/layers/gemma4_fused_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,15 +215,15 @@ def gemma_qkv_rmsnorm(

If k and v are both None (KV-shared layer), only Q is normalized.
"""
assert q.is_cuda
assert q.is_cuda or q.is_xpu
assert q.stride(-1) == 1, "Q's last dim must be contiguous"
assert q_weight.shape[-1] == head_dim
M = q.shape[0] if q.dim() >= 2 else 1
BLOCK = triton.next_power_of_2(head_dim)

has_kv = k is not None and v is not None
if has_kv:
assert k.is_cuda and v.is_cuda
assert (k.is_cuda and v.is_cuda) or (k.is_xpu and v.is_xpu)
assert k.stride(-1) == 1 and v.stride(-1) == 1
assert k_weight is not None and k_weight.shape[-1] == head_dim

Expand All @@ -245,6 +245,75 @@ def gemma_qkv_rmsnorm(
)


@triton.jit
def _gemma_routing_post_topk_kernel(
    Logits_ptr,
    Ids_ptr,
    Scale_ptr,
    Out_weights_ptr,
    Out_ids_ptr,
    stride_l,
    stride_ow,
    stride_oi,
    K: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Fused softmax(topk_logits) * per_expert_scale[topk_ids] for one token.

    Each program instance handles the row selected by ``tl.program_id(0)``
    and writes float32 weights plus int32 ids for that row.

    Layout assumptions baked into the addressing below:
      * ``Logits_ptr`` and ``Ids_ptr`` are BOTH indexed with ``stride_l`` as
        the row stride, so the two inputs must share the same row stride.
      * Columns are addressed with a unit stride for inputs and outputs.
      * ``Scale_ptr`` is indexed directly by the loaded expert ids.

    ``K`` is the number of top-k experts (e.g. 8); ``BLOCK_K`` is the lane
    count of the program, with lanes ``>= K`` masked off.
    """
    token = tl.program_id(0)
    lane = tl.arange(0, BLOCK_K)
    valid = lane < K

    # Logits and ids share one offset vector (same row stride, see docstring).
    in_offsets = token * stride_l + lane
    expert_ids = tl.load(Ids_ptr + in_offsets, mask=valid, other=0)
    scores = tl.load(
        Logits_ptr + in_offsets, mask=valid, other=float("-inf")
    ).to(tl.float32)

    # Numerically stable softmax: masked lanes hold -inf and contribute
    # exp(-inf) == 0 to the normalizer.
    shifted = scores - tl.max(scores, axis=0)
    numer = tl.exp(shifted)
    probs = numer / tl.sum(numer, axis=0)

    # Gather the per-expert scale for each selected expert and apply it.
    gate = tl.load(Scale_ptr + expert_ids, mask=valid, other=1.0).to(tl.float32)
    scaled = probs * gate

    tl.store(Out_weights_ptr + token * stride_ow + lane, scaled, mask=valid)
    tl.store(
        Out_ids_ptr + token * stride_oi + lane,
        expert_ids.to(tl.int32),
        mask=valid,
    )


def gemma_routing_post_topk(
    topk_logits: torch.Tensor,
    topk_ids: torch.Tensor,
    per_expert_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused softmax + scale-gather + casts for Gemma4 routing.

    Replaces: softmax(topk_logits) * per_expert_scale[topk_ids] -> (f32, i32).

    Args:
        topk_logits: 2-D tensor of shape (B, K) holding the top-k logits of
            each token.
        topk_ids: Tensor with the same shape as ``topk_logits`` holding the
            selected expert indices.
        per_expert_scale: Per-expert scale values; the kernel gathers from it
            as a flat array using the values of ``topk_ids``.

    Returns:
        A ``(weights, ids)`` tuple of freshly allocated (B, K) tensors on the
        device of ``topk_logits``: ``weights`` is float32, ``ids`` is int32.

    Raises:
        ValueError: If ``topk_ids`` does not have the same shape as
            ``topk_logits``.
    """
    B, K = topk_logits.shape
    if topk_ids.shape != topk_logits.shape:
        raise ValueError(
            "topk_ids must have the same shape as topk_logits, got "
            f"{tuple(topk_ids.shape)} vs {tuple(topk_logits.shape)}"
        )

    out_weights = torch.empty((B, K), dtype=torch.float32, device=topk_logits.device)
    out_ids = torch.empty((B, K), dtype=torch.int32, device=topk_logits.device)

    # Nothing to compute for an empty batch; skip the kernel launch entirely.
    if topk_logits.numel() == 0:
        return out_weights, out_ids

    # The kernel addresses logits AND ids with a single row stride and assumes
    # unit stride along the last dim (and a flat scale table). Making the
    # inputs contiguous guarantees those assumptions; it is a no-op when the
    # tensors are already contiguous.
    topk_logits = topk_logits.contiguous()
    topk_ids = topk_ids.contiguous()
    per_expert_scale = per_expert_scale.contiguous()

    BLOCK_K = triton.next_power_of_2(K)

    # One program per token (row).
    _gemma_routing_post_topk_kernel[(B,)](
        topk_logits,
        topk_ids,
        per_expert_scale,
        out_weights,
        out_ids,
        topk_logits.stride(0),
        out_weights.stride(0),
        out_ids.stride(0),
        K=K,
        BLOCK_K=BLOCK_K,
    )
    return out_weights, out_ids


def gemma_dual_rmsnorm_residual_scalar(
x1: torch.Tensor,
weight1: torch.Tensor,
Expand Down
9 changes: 9 additions & 0 deletions python/sglang/srt/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,15 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
out = out.reshape(original_shape)
return out

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
    """Apply the norm on XPU by dispatching to one of two kernels.

    Empty inputs are returned unchanged without calling any kernel.
    Otherwise ``gemma_rmsnorm`` is used when ``self.with_scale`` is set and
    ``self.scale_shift`` equals 1.0, and ``rmsnorm`` is used in every other
    case. ``x`` is handed to the kernel as-is (no reshape), together with
    ``self.weight.data`` and ``self.eps``.
    """
    if x.numel() == 0:
        return x
    # Pick the kernel once, then make a single call with shared arguments.
    norm_fn = (
        gemma_rmsnorm if self.with_scale and self.scale_shift == 1.0 else rmsnorm
    )
    return norm_fn(x, self.weight.data, self.eps)

def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
# sgl_kernel's gemma_rmsnorm is not available on ROCm;
# delegate to the pure-PyTorch implementation.
Expand Down
Loading
Loading