diff --git a/python/sglang/srt/layers/attention/xpu_backend.py b/python/sglang/srt/layers/attention/xpu_backend.py index 3b5743799dfb..46facf8f156e 100644 --- a/python/sglang/srt/layers/attention/xpu_backend.py +++ b/python/sglang/srt/layers/attention/xpu_backend.py @@ -13,6 +13,7 @@ prepare_swa_spec_page_table_triton, ) from sglang.srt.managers.schedule_batch import get_global_server_args +from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode if TYPE_CHECKING: @@ -72,6 +73,12 @@ def __init__( self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA self.skip_prefill = skip_prefill self.is_hybrid_swa = model_runner.is_hybrid_swa + self.use_sliding_window_kv_pool = ( + isinstance(model_runner.token_to_kv_pool, SWAKVPool) + and model_runner.token_to_kv_pool.swa_layer_nums > 0 + ) + if self.use_sliding_window_kv_pool: + self.token_to_kv_pool = model_runner.token_to_kv_pool if self.is_hybrid_swa: self.full_to_swa_index_mapping = ( model_runner.token_to_kv_pool.full_to_swa_index_mapping @@ -193,6 +200,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): metadata.page_table = self.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : metadata.max_seq_len_k ] + # TODO: we need to test this part for llama 4 eagle case self._init_local_attn_metadata(forward_batch, metadata, device) elif forward_batch.forward_mode.is_target_verify(): @@ -373,6 +381,14 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ), ] + # Translate full-pool indices to SWA-pool indices for hybrid models + if self.use_sliding_window_kv_pool: + metadata.swa_page_table = ( + self.token_to_kv_pool.translate_loc_from_full_to_swa( + metadata.page_table + ) + ) + if self.use_mla: workspace_size = flash_mla_get_workspace_size( max_seq_len=self.max_context_len, @@ -389,11 +405,25 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): workspace_size, device=self.device, dtype=torch.uint8 ) + # Translate full-pool indices to SWA-pool indices for hybrid models + if self.use_sliding_window_kv_pool: + metadata.swa_page_table = ( + self.token_to_kv_pool.translate_loc_from_full_to_swa( + metadata.page_table + ) + ) + # Convert the page table to a strided format which is needed by FA3 API if self.page_size > 1: self.strided_indices = torch.arange( 0, metadata.page_table.shape[1], self.page_size, device=self.device ) + + if self.use_sliding_window_kv_pool and metadata.swa_page_table is not None: + metadata.swa_page_table = ( + metadata.swa_page_table[:, self.strided_indices] // self.page_size + ) + metadata.page_table = ( metadata.page_table[:, self.strided_indices] // self.page_size ) @@ -413,8 +443,17 @@ def forward_extend( k_rope: Optional[torch.Tensor] = None, sinks: Optional[torch.Tensor] = None, ): - if k is not None: - assert v is not None + if k is None and v is None: + # Cross-layer KV sharing (Gemma 4): the layer reuses another + # layer's KV cache. The paged kernel reads K/V directly via + # page_table, and pool.get_kv_buffer(layer.layer_id) routes + # to the correct sub-pool because RadixAttention is initialized + # with layer_id=kv_shared_layer_index for shared layers. No + # materialization needed; just skip the write path. + pass + elif k is None or v is None: + raise ValueError("Both k and v should be None or not None") + else: if save_kv_cache: cache_loc = ( forward_batch.out_cache_loc @@ -497,6 +536,13 @@ def forward_extend( cu_seqlens_k = swa_spec_metadata.cu_seqlens_k else: page_table = metadata.page_table + if is_hybrid_swa and self.use_sliding_window_kv_pool: + if metadata.swa_page_table is not None: + page_table = metadata.swa_page_table + else: + page_table = self.token_to_kv_pool.translate_loc_from_full_to_swa( + metadata.page_table + ) cu_seqlens_q = metadata.cu_seqlens_q cache_seqlens = metadata.cache_seqlens_int32 max_seqlen_q = metadata.max_seq_len_q @@ -525,7 +571,7 @@ def forward_extend( page_table=page_table, cache_seqlens=cache_seqlens, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None, + cu_seqlens_k_new=None, max_seqlen_q=max_seqlen_q, softmax_scale=layer.scaling, causal=False if use_cascade_attn else causal, @@ -546,7 +592,7 @@ def forward_extend( page_table=self.forward_metadata_spec_decode_expand.page_table, cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, - cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + cu_seqlens_k_new=None, max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, softmax_scale=layer.scaling, causal=False, @@ -648,7 +694,7 @@ def forward_extend( page_table=page_table, cache_seqlens=cache_seqlens, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None, + cu_seqlens_k_new=None, max_seqlen_q=max_seqlen_q, softmax_scale=layer.scaling, causal=False if use_cascade_attn else causal, @@ -668,7 +714,7 @@ def forward_extend( page_table=self.forward_metadata_spec_decode_expand.page_table, cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, - cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + cu_seqlens_k_new=None, max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, softmax_scale=layer.scaling, causal=False, @@ -688,7 +734,8 @@ def forward_extend( else: o = result - return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + out = o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + return out def forward_decode( self, @@ -703,8 +750,12 @@ def forward_decode( k_rope: Optional[torch.Tensor] = None, sinks: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if k is not None: - assert v is not None + if k is None and v is None: + # Cross-layer KV sharing (Gemma 4): see forward_extend for details. + pass + elif k is None or v is None: + raise ValueError("Both k and v should be None or not None") + else: if save_kv_cache: cache_loc = ( forward_batch.out_cache_loc @@ -787,7 +838,7 @@ def forward_decode( page_table=metadata.encoder_page_table, cache_seqlens=metadata.encoder_lens_int32, cu_seqlens_q=metadata.cu_seqlens_q, - cu_seqlens_k_new=metadata.encoder_cu_seqlens_k, + cu_seqlens_k_new=None, max_seqlen_q=1, softmax_scale=layer.scaling, causal=False, @@ -817,7 +868,24 @@ def forward_decode( **kwargs, ) else: + is_swa_layer = ( + layer.sliding_window_size is not None + and layer.sliding_window_size > -1 + ) + page_table = metadata.page_table + # For SWA layers on hybrid models, use the translated + # SWA-pool page table so KV reads hit the correct pool. + if is_swa_layer and self.use_sliding_window_kv_pool: + if metadata.swa_page_table is not None: + page_table = metadata.swa_page_table + else: + page_table = ( + self.token_to_kv_pool.translate_loc_from_full_to_swa( + metadata.page_table + ) + ) + cache_seqlens = metadata.cache_seqlens_int32 cu_seqlens_k = metadata.cu_seqlens_k max_seqlen_q = metadata.max_seq_len_q @@ -833,7 +901,7 @@ def forward_decode( page_table=page_table, cache_seqlens=cache_seqlens, cu_seqlens_q=metadata.cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k, + cu_seqlens_k_new=None, max_seqlen_q=max_seqlen_q, softmax_scale=layer.scaling, causal=False if use_cascade_attn else causal, @@ -854,7 +922,7 @@ def forward_decode( page_table=self.forward_metadata_spec_decode_expand.page_table, cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32, cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q, - cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k, + cu_seqlens_k_new=None, max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q, softmax_scale=layer.scaling, causal=False, @@ -899,7 +967,8 @@ def forward_decode( layer.scaling, ) - return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + out = o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + return out def get_cuda_graph_seq_len_fill_value(self): """Get the fill value for sequence length in CUDA graph.""" diff --git a/python/sglang/srt/layers/gemma4_fused_ops.py b/python/sglang/srt/layers/gemma4_fused_ops.py index ad6f01d9875a..3c1f56ffc90c 100644 --- a/python/sglang/srt/layers/gemma4_fused_ops.py +++ b/python/sglang/srt/layers/gemma4_fused_ops.py @@ -215,7 +215,7 @@ def gemma_qkv_rmsnorm( If k and v are both None (KV-shared layer), only Q is normalized. """ - assert q.is_cuda + assert q.is_cuda or q.is_xpu assert q.stride(-1) == 1, "Q's last dim must be contiguous" assert q_weight.shape[-1] == head_dim M = q.shape[0] if q.dim() >= 2 else 1 @@ -223,7 +223,7 @@ def gemma_qkv_rmsnorm( has_kv = k is not None and v is not None if has_kv: - assert k.is_cuda and v.is_cuda + assert (k.is_cuda and v.is_cuda) or (k.is_xpu and v.is_xpu) assert k.stride(-1) == 1 and v.stride(-1) == 1 assert k_weight is not None and k_weight.shape[-1] == head_dim @@ -245,6 +245,75 @@ def gemma_qkv_rmsnorm( ) +@triton.jit +def _gemma_routing_post_topk_kernel( + Logits_ptr, + Ids_ptr, + Scale_ptr, + Out_weights_ptr, + Out_ids_ptr, + stride_l, + stride_ow, + stride_oi, + K: tl.constexpr, + BLOCK_K: tl.constexpr, +): + """Fused: softmax(topk_logits) * per_expert_scale[topk_ids] → float32 weights, int32 ids. + + One program per token. K is the number of top-k experts (e.g. 8). + """ + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_K) + mask = cols < K + + logits = tl.load( + Logits_ptr + row * stride_l + cols, mask=mask, other=float("-inf") + ).to(tl.float32) + ids_i64 = tl.load(Ids_ptr + row * stride_l + cols, mask=mask, other=0) + + # Stable softmax + max_val = tl.max(logits, axis=0) + exp_val = tl.exp(logits - max_val) + sum_exp = tl.sum(exp_val, axis=0) + weights = exp_val / sum_exp + + # Gather per_expert_scale and multiply + scale = tl.load(Scale_ptr + ids_i64, mask=mask, other=1.0).to(tl.float32) + weights = weights * scale + + tl.store(Out_weights_ptr + row * stride_ow + cols, weights, mask=mask) + tl.store(Out_ids_ptr + row * stride_oi + cols, ids_i64.to(tl.int32), mask=mask) + + +def gemma_routing_post_topk( + topk_logits: torch.Tensor, + topk_ids: torch.Tensor, + per_expert_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Fused softmax + scale-gather + casts for Gemma4 routing. + + Replaces: softmax(topk_logits) * per_expert_scale[topk_ids] → (f32, i32). + """ + B, K = topk_logits.shape + BLOCK_K = triton.next_power_of_2(K) + out_weights = torch.empty((B, K), dtype=torch.float32, device=topk_logits.device) + out_ids = torch.empty((B, K), dtype=torch.int32, device=topk_logits.device) + + _gemma_routing_post_topk_kernel[(B,)]( + topk_logits, + topk_ids, + per_expert_scale, + out_weights, + out_ids, + topk_logits.stride(0), + out_weights.stride(0), + out_ids.stride(0), + K=K, + BLOCK_K=BLOCK_K, + ) + return out_weights, out_ids + + def gemma_dual_rmsnorm_residual_scalar( x1: torch.Tensor, weight1: torch.Tensor, diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 951d704f7ec6..6cf269b8d3ca 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -878,6 +878,15 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: out = out.reshape(original_shape) return out + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + if x.numel() == 0: + return x + if self.with_scale and self.scale_shift == 1.0: + out = gemma_rmsnorm(x, self.weight.data, self.eps) + else: + out = rmsnorm(x, self.weight.data, self.eps) + return out + def forward_hip(self, x: torch.Tensor) -> torch.Tensor: # sgl_kernel's gemma_rmsnorm is not available on ROCm; # delegate to the pure-PyTorch implementation. diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index c406f12a2b6c..50db5dafdef4 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -33,6 +33,7 @@ gemma_dual_rmsnorm_residual_scalar, gemma_qkv_rmsnorm, gemma_rmsnorm_residual_scalar, + gemma_routing_post_topk, ) from sglang.srt.layers.layernorm import Gemma4RMSNorm, RMSNorm from sglang.srt.layers.linear import ( @@ -55,6 +56,9 @@ maybe_remap_kv_scale_name, ) from sglang.srt.models.gemma3_causal import Gemma3MLP, Gemma3TextScaledWordEmbedding +from sglang.srt.models.utils import ( + create_fused_set_kv_buffer_arg, +) from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, make_layers @@ -144,7 +148,8 @@ def __init__( super().__init__() self.hidden_size = config.hidden_size - # RMSNorm without learned weight — pure normalization only + # RMSNorm without learned weight — scale is folded into norm weight + # after loading so forward is a single fused norm kernel. self.norm = Gemma4RMSNorm( self.hidden_size, eps=config.rms_norm_eps, with_scale=False ) @@ -164,18 +169,19 @@ def __init__( quant_config=None, prefix=add_prefix("proj", prefix), ) - self._fused_scale: Optional[torch.Tensor] = None + self._scale_fused = False def fuse_scale(self): - """Pre-compute scale * root_size. Call after weights are loaded.""" - self._fused_scale = (self.scale * self.root_size).to(self.scale.dtype) + """Fold scale * root_size into norm.weight so forward needs no extra mul.""" + fused = (self.scale * self.root_size).to(self.norm.weight.dtype) + self.norm.weight.data.copy_(fused) + self._scale_fused = True def forward(self, x: torch.Tensor) -> torch.Tensor: """Returns raw router logits [T, E].""" - x = self.norm(x) - if self._fused_scale is None: + if not self._scale_fused: self.fuse_scale() - x = x * self._fused_scale.to(x.dtype) + x = self.norm(x) router_logits, _ = self.proj(x) return router_logits @@ -221,13 +227,15 @@ def routing_function( # softmax(all)[topk] / sum(softmax(all)[topk]) = softmax(topk_logits), # so we softmax only the top-k logits (fewer kernel launches). topk_logits, topk_ids = torch.topk(gating_output, k=topk, dim=-1) - topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1) - # Fold per_expert_scale into routing weights + # Fused: softmax + per_expert_scale gather + mul + casts in one kernel + if topk_logits.is_cuda or topk_logits.is_xpu: + return gemma_routing_post_topk(topk_logits, topk_ids, per_expert_scale) + + topk_weights = torch.nn.functional.softmax(topk_logits, dim=-1) topk_weights = topk_weights * per_expert_scale[topk_ids].to( topk_weights.dtype ) - return topk_weights.to(torch.float32), topk_ids.to(torch.int32) self.topk = TopK( @@ -398,14 +406,15 @@ def forward( q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # Fused Q/K/V RMSNorm: replaces three separate norm kernels with one. - # Preconditions for the fused path: tensors on CUDA, q_norm/k_norm use - # the standard norm*weight (scale_shift==0) and v_norm has weight=ones + # Preconditions for the fused path: tensors on CUDA or XPU (the kernel + # is pure Triton and lowers to both backends), q_norm/k_norm use the + # standard norm*weight (scale_shift==0) and v_norm has weight=ones # (with_scale=False) — the canonical Gemma4 attention configuration. is_kv_shared = ( self.is_kv_shared_layer and self.kv_shared_layer_index is not None ) can_fuse_qkv_norm = ( - q.is_cuda + (q.is_cuda or q.is_xpu) and self.q_norm.scale_shift == 0.0 and self.k_norm.scale_shift == 0.0 and not self.v_norm.with_scale @@ -457,9 +466,22 @@ def forward( v = self.v_norm(v) # Apply rotary embedding + use_fused_kv = False if k is not None: k = k.flatten(-2, -1) - q, k = self.rotary_emb(positions, q, k) + # Fuse RoPE + KV-cache write for non-SWA layers with bf16 cache + # DISABLED: causes accuracy regression in launch_server path + can_fuse = False + if can_fuse: + fused_arg = create_fused_set_kv_buffer_arg( + value=v.flatten(-2, -1) if v.dim() == 3 else v, + layer=self.attn, + forward_batch=forward_batch, + ) + use_fused_kv = True + else: + fused_arg = None + q, k = self.rotary_emb(positions, q, k, fused_set_kv_buffer_arg=fused_arg) k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) else: # Rotary embedding requires a key input; use zeros since KV is shared from another layer @@ -472,7 +494,7 @@ def forward( k, v, forward_batch=forward_batch, - save_kv_cache=not self.is_kv_shared_layer, + save_kv_cache=not self.is_kv_shared_layer and not use_fused_kv, ) if attn_output.dim() == 3: attn_output = attn_output.flatten(-2, -1) @@ -657,7 +679,7 @@ def forward( # Fused: (rmsnorm(rmsnorm(h1,w1) + rmsnorm(h2,w2), w3) + residual) * scalar if ( not self.has_ple - and hidden_states_1.is_cuda + and (hidden_states_1.is_cuda or hidden_states_1.is_xpu) and hidden_states_1.dim() == 2 ): norm1 = self.post_feedforward_layernorm_1 @@ -689,7 +711,12 @@ def forward( ) hidden_states = self.mlp(hidden_states) - if not self.has_ple and hidden_states.is_cuda and hidden_states.dim() == 2: + if ( + not self.has_ple + and self.moe is None + and (hidden_states.is_cuda or hidden_states.is_xpu) + and hidden_states.dim() == 2 + ): # Fused: (post_ff_norm(h) + residual) * layer_scalar in one kernel norm = self.post_feedforward_layernorm hidden_states = gemma_rmsnorm_residual_scalar( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 9d27b33668bb..7cef7905ed9f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2286,12 +2286,12 @@ def _handle_model_specific_adjustments(self): self.attention_backend = default_attention_backend prefill_backend, decode_backend = self.get_attention_backends() - accepted_backends = ("trtllm_mha", "triton") + accepted_backends = ("trtllm_mha", "triton", "intel_xpu") assert ( prefill_backend in accepted_backends and decode_backend in accepted_backends ), ( - "Gemma4 only supports trtllm_mha or triton attention backend, " + "Gemma4 only supports trtllm_mha, triton, or intel_xpu attention backend, " f"got prefill={prefill_backend}, decode={decode_backend}" ) diff --git a/test/registered/xpu/gemma4_chat_template.jinja b/test/registered/xpu/gemma4_chat_template.jinja new file mode 100644 index 000000000000..518b59f9a976 --- /dev/null +++ b/test/registered/xpu/gemma4_chat_template.jinja @@ -0,0 +1,5 @@ +{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}user +{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for part in message['content'] %}{% if part['type'] == 'text' %}{{ part['text'] }}{% endif %}{% endfor %}{% endif %} +{% elif message['role'] == 'model' or message['role'] == 'assistant' %}model +{{ message['content'] }} +{% endif %}{% endfor %}model diff --git a/test/registered/xpu/test_gemma_4_e2b.py b/test/registered/xpu/test_gemma_4_e2b.py new file mode 100644 index 000000000000..fcbd29f9c6c5 --- /dev/null +++ b/test/registered/xpu/test_gemma_4_e2b.py @@ -0,0 +1,313 @@ +""" +Gemma 4 E2B: simple text Q&A on XPU (OpenAI /v1), same shape as +``test_deepseek_coder_v2_lite_instruct.py``. + +Model card: https://huggingface.co/google/gemma-4-E2B + + - XPU test runs when Intel XPU is available. + +Run from test/srt:: + + python3 -m unittest xpu.test_gemma_4_e2b.TestGemma4E2BXPU.test_simple_code_qa + +Appends to ``gemma_4_e2b_comparison.txt`` in this directory. + +Tensor dumps (all layers) go under ``debug_tensor_dump_output/gemma_4_e2b/``. + +Server is started with ``sglang serve`` (``--model-impl sglang``). +""" + +from __future__ import annotations + +import os +import re +import unittest +from datetime import datetime, timezone +from pathlib import Path + +import openai + +from sglang.srt.utils.common import is_xpu +from sglang.test.test_utils import CustomTestCase +from sglang.test.vlm_utils import ( + DEFAULT_URL_FOR_TEST, + kill_process_tree, + popen_launch_server, +) + +MODEL = "google/gemma-4-E2B" + +COMPARISON_LOG_PATH = Path(__file__).resolve().parent / "gemma_4_e2b_comparison.txt" +DEBUG_TENSOR_DUMP_OUTPUT_DIR = ( + Path(__file__).resolve().parent / "debug_tensor_dump_output" / "gemma_4_e2b" +) +LAUNCH_TIMEOUT = 900 + + +def _server_subprocess_env() -> dict: + return { + "TORCHDYNAMO_VERBOSE": "0", + "TORCHINDUCTOR_VERBOSE": "0", + "TORCH_COMPILE_DEBUG": "0", + "TORCH_SHOW_CPP_STACKTRACES": "0", + } + + +def _prettify_spm_style_text(s: str) -> str: + """Turn SentencePiece-style space/newline markers in API strings into normal text.""" + if not s: + return s + return s.replace("\u010a", "\n").replace("\u0120", " ") + + +def setUpModule(): + DEBUG_TENSOR_DUMP_OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + for _pt in DEBUG_TENSOR_DUMP_OUTPUT_DIR.rglob("*.pt"): + _pt.unlink(missing_ok=True) + COMPARISON_LOG_PATH.write_text( + "Gemma-4-E2B \u2014 device comparison log\n" + f"Model: {MODEL}\n" + f"Run started (UTC): {datetime.now(timezone.utc).isoformat()}\n" + f"{'=' * 80}\n\n", + encoding="utf-8", + ) + + +def _append_comparison_log( + *, + title: str, + device_cli: str, + extra_server_notes: str, + user_prompt: str, + response, +) -> None: + msg = response.choices[0].message + content = _prettify_spm_style_text(msg.content or "") + reasoning = _prettify_spm_style_text(getattr(msg, "reasoning_content", None) or "") + usage = response.usage + block = ( + f"\n{'#' * 80}\n" + f"{title}\n" + f"Server device flag: {device_cli}\n" + f"{extra_server_notes}\n" + f"{'#' * 80}\n" + f"--- user prompt ---\n{user_prompt}\n" + f"--- assistant message.content ---\n{content}\n" + f"--- assistant message.reasoning_content (if any) ---\n{reasoning}\n" + f"--- usage ---\n" + f" prompt_tokens: {getattr(usage, 'prompt_tokens', None)}\n" + f" completion_tokens: {getattr(usage, 'completion_tokens', None)}\n" + f" total_tokens: {getattr(usage, 'total_tokens', None)}\n" + f"{'=' * 80}\n" + ) + with COMPARISON_LOG_PATH.open("a", encoding="utf-8") as f: + f.write(block) + + +# Gemma 4 E2B does not ship a chat_template in its tokenizer. +# Provide a minimal Gemma-style Jinja2 template file. +_CHAT_TEMPLATE_PATH = str( + Path(__file__).resolve().parent / "gemma4_chat_template.jinja" +) + +# E2B model: single-rank for small model on XPU. +XPU_SERVER_ARGS = [ + "--device", + "xpu", + "--tp=1", + "--trust-remote-code", + "--disable-overlap-schedule", + "--page-size", + "64", + "--attention-backend", + "intel_xpu", + "--model-impl", + "sglang", + "--chat-template", + _CHAT_TEMPLATE_PATH, +] + +_SIMPLE_CODE_PROMPT = ( + "Write a minimal Python function `def add(a, b):` that returns a+b. " + "Reply with only the function, give a brief explanation. " + "Finish with asking me How can I help you today?" +) + + +def _simple_text_messages(): + return [ + { + "role": "user", + "content": [ + {"type": "text", "text": _SIMPLE_CODE_PROMPT}, + ], + } + ] + + +def _compact_code_text(s: str) -> str: + t = s.replace("\u0120", " ").replace("\u010a", "\n") + return re.sub(r"\s+", "", t.lower()) + + +def _assert_code_reply(response): + assert response.choices[0].message.role == "assistant" + msg = response.choices[0].message + text = msg.content or "" + reasoning = getattr(msg, "reasoning_content", None) or "" + combined = f"{text} {reasoning}".strip() + assert len(combined) > 0 + lower = combined.lower() + assert ( + "def" in lower and "add" in lower + ), f"expected a Python `def add` in reply, got: {combined!r}" + assert "return" in lower, f"expected `return` in reply, got: {combined!r}" + compact = _compact_code_text(combined) + assert ( + "a+b" in compact + ), f"expected `a+b` (allowing spaces) in reply, got: {combined!r}" + assert response.usage is not None + assert response.usage.completion_tokens > 0 + + +@unittest.skipUnless(is_xpu(), "Intel XPU not available") +class TestGemma4E2BXPU(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = MODEL + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + os.environ["SGLANG_USE_SGL_XPU"] = "1" + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=LAUNCH_TIMEOUT, + api_key=cls.api_key, + other_args=list(XPU_SERVER_ARGS), + device="cuda", + env=_server_subprocess_env(), + ) + cls.base_url += "/v1" + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_simple_code_qa(self): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + response = client.chat.completions.create( + model="default", + messages=_simple_text_messages(), + temperature=0, + max_tokens=96, + ) + _assert_code_reply(response) + _append_comparison_log( + title="OUTPUT FROM --device XPU (Gemma-4-E2B)", + device_cli="--device xpu", + extra_server_notes="SGLANG_USE_SGL_XPU=1; see XPU_SERVER_ARGS in test source.", + user_prompt=_SIMPLE_CODE_PROMPT, + response=response, + ) + + def test_sliding_window_long_context(self): + """Generate >511 tokens to exercise decode past the SWA window boundary. + + Gemma 4 E2B has sliding_window=512 (511 in SGLang exclusive). + With ~30 prompt tokens + 600 generated tokens, the total sequence + (~630 tokens) exceeds the window, forcing the SWA decode kernel + to actually mask out-of-window tokens. If page table translation + or kernel masking is broken, this test will produce garbage or crash. + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + response = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": ( + "Write a detailed, step-by-step tutorial on how to " + "build a simple web server in Python using the socket " + "module. Include complete code examples with comments." + ), + }, + ], + } + ], + temperature=0, + max_tokens=600, + ) + msg = response.choices[0].message + text = msg.content or "" + assert len(text) > 0, "expected non-empty response" + assert response.usage is not None + assert response.usage.completion_tokens >= 500, ( + f"expected >= 500 completion tokens to exceed SWA window, " + f"got {response.usage.completion_tokens}" + ) + _append_comparison_log( + title="OUTPUT FROM --device XPU (Gemma-4-E2B) [SWA long context]", + device_cli="--device xpu", + extra_server_notes="SWA window=511; generated >500 tokens to exceed window.", + user_prompt="(long context SWA test)", + response=response, + ) + + def test_sliding_window_3k_tokens(self): + """Generate ~3000 tokens — approximately 6x the SWA window. + + At 3000 tokens the sliding window (511) has rolled many times, + stressing the decode kernel's local masking and KV cache page + table management over an extended generation. + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + response = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": ( + "Write a comprehensive guide to Python data structures. " + "Cover lists, tuples, dictionaries, sets, double-ended queues, " + "namedtuples, dataclasses, and custom linked lists. " + "For each one, provide multiple code examples, explain " + "time complexity of common operations, compare trade-offs, " + "and show real-world use cases. Be extremely thorough." + ), + }, + ], + } + ], + temperature=0, + max_tokens=3000, + ) + msg = response.choices[0].message + text = msg.content or "" + assert len(text) > 0, "expected non-empty response" + assert response.usage is not None + assert response.usage.completion_tokens >= 2500, ( + f"expected >= 2500 completion tokens (6x SWA window), " + f"got {response.usage.completion_tokens}" + ) + _append_comparison_log( + title="OUTPUT FROM --device XPU (Gemma-4-E2B) [SWA 3k tokens]", + device_cli="--device xpu", + extra_server_notes="SWA window=511; generated ~3000 tokens (6x window).", + user_prompt="(3k token SWA stress test)", + response=response, + ) + + +from sglang.test.ci.ci_register import register_xpu_ci + +register_xpu_ci(est_time=360, suite="stage-b-test-1-gpu-xpu") + +if __name__ == "__main__": + unittest.main()