diff --git a/.codespellrc b/.codespellrc index 808a344b4e6f..5b14597698f4 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,3 +1,3 @@ [codespell] -ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS +ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS, ather skip = *.json,*.jsonl,*.patch,*.txt diff --git a/benchmark/kernels/fused_moe_triton/common_utils.py b/benchmark/kernels/fused_moe_triton/common_utils.py index 37a9607b6014..d08d2bb75d83 100644 --- a/benchmark/kernels/fused_moe_triton/common_utils.py +++ b/benchmark/kernels/fused_moe_triton/common_utils.py @@ -134,6 +134,10 @@ def get_model_config( topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size hidden_size = getattr(config, "moe_latent_size", None) or hidden_size + elif architecture == "Gemma4ForConditionalGeneration": + E = config.num_experts // ep_size + topk = config.top_k_experts + intermediate_size = config.moe_intermediate_size else: # Default: Mixtral E = config.num_local_experts // ep_size diff --git a/benchmark/mmlu/bench_hf.py b/benchmark/mmlu/bench_hf.py new file mode 100644 index 000000000000..c76a18db685b --- /dev/null +++ b/benchmark/mmlu/bench_hf.py @@ -0,0 +1,151 @@ +""" +Usage: +python3 bench_hf.py --model-path meta-llama/Llama-2-7b-hf --data-dir data --ntrain 5 +""" + +import argparse +import json +import os +import time + +import numpy as np +import pandas as pd +import torch +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +choices = ["A", "B", "C", "D"] + + +def format_subject(subject): + l = subject.split("_") + s = "" + for entry in l: + s += " " + entry + return s + + +def format_example(df, idx, include_answer=True): + prompt = df.iloc[idx, 0] + k = df.shape[1] - 2 + for j in range(k): + prompt += "\n{}. 
{}".format(choices[j], df.iloc[idx, j + 1]) + prompt += "\nAnswer:" + if include_answer: + prompt += " {}\n\n".format(df.iloc[idx, k + 1]) + return prompt + + +def gen_prompt(train_df, subject, k=-1): + prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format( + format_subject(subject) + ) + if k == -1: + k = train_df.shape[0] + for i in range(k): + prompt += format_example(train_df, i) + return prompt + + +@torch.no_grad() +def main(args): + print(f"Loading model: {args.model_path}") + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained( + args.model_path, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + device_map="auto", + ).eval() + + subjects = sorted( + [ + f.split("_test.csv")[0] + for f in os.listdir(os.path.join(args.data_dir, "test")) + if "_test.csv" in f + ] + ) + + all_cors = [] + num_requests = 0 + total_latency = 0 + + for subject in tqdm(subjects[: args.nsub]): + dev_df = pd.read_csv( + os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None + )[: args.ntrain] + test_df = pd.read_csv( + os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None + ) + + k = args.ntrain + few_shot_examples = gen_prompt(dev_df, subject, k) + while len(tokenizer.encode(few_shot_examples)) > 1536: + k -= 1 + if k < 0: + break + few_shot_examples = gen_prompt(dev_df, subject, k) + + preds = [] + labels = [] + tic = time.perf_counter() + + for i in range(test_df.shape[0]): + prompt_end = format_example(test_df, i, include_answer=False) + prompt = few_shot_examples + prompt_end + + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device) + output_ids = model.generate( + input_ids, + max_new_tokens=1, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + ) + + output_str = tokenizer.decode( + output_ids[0][input_ids.shape[-1] :], skip_special_tokens=True + ) + preds.append(output_str.strip()[0] if 
len(output_str.strip()) > 0 else "") + labels.append(test_df.iloc[i, test_df.shape[1] - 1]) + + latency = time.perf_counter() - tic + total_latency += latency + + cors = [pred == label for pred, label in zip(preds, labels)] + all_cors.append(cors) + num_requests += len(test_df) + + print( + f"Subject: {subject}, Accuracy: {np.mean(cors):.3f}, Latency: {latency:.3f}s" + ) + + weighted_acc = np.mean(np.concatenate(all_cors)) + print(f"Total Latency: {total_latency:.3f}s") + print(f"Average Accuracy: {weighted_acc:.3f}") + + if args.output: + with open(args.output, "a") as fout: + value = { + "task": "mmlu", + "backend": "hf", + "model": args.model_path, + "latency": round(total_latency, 3), + "accuracy": round(weighted_acc, 3), + "num_requests": num_requests, + "other": { + "nsub": args.nsub, + "ntrain": args.ntrain, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, required=True) + parser.add_argument("--ntrain", type=int, default=5) + parser.add_argument("--data-dir", type=str, default="data") + parser.add_argument("--nsub", type=int, default=60) + parser.add_argument("--output", type=str, help="Output file path") + args = parser.parse_args() + main(args) diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 850270308d2b..86a0fc15b287 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -404,6 +404,19 @@ def get_chat_template_by_model_path(model_path): ) ) +register_chat_template( + ChatTemplate( + name="gemma-4-it", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ("", ""), + "user": ("<|turn>user\n", "\n"), + "assistant": ("<|turn>assistant\n", "\n"), + }, + style=ChatTemplateStyle.PLAIN, + ) +) + register_chat_template( ChatTemplate( name="dbrx-instruct", @@ -611,8 +624,10 @@ def match_chat_yi(model_path: str): 
@register_chat_template_matching_function -def match_gemma_it(model_path: str): - if re.search(r"gemma.*it", model_path, re.IGNORECASE): +def match_gemma(model_path: str): + if re.search(r"gemma-4.*it", model_path, re.IGNORECASE): + return "gemma-4-it" + if re.search(r"(gemma.*it)|(gemma-3)", model_path, re.IGNORECASE): return "gemma-it" @@ -636,12 +651,6 @@ def match_granite_instruct(model_path: str): return "granite-3-instruct" -@register_chat_template_matching_function -def match_gemma3_instruct(model_path: str): - if re.search(r"gemma-3", model_path, re.IGNORECASE): - return "gemma-it" - - @register_chat_template_matching_function def match_internvl_chat(model_path: str): if re.search(r"internvl2_5", model_path, re.IGNORECASE): diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 89e90516ef12..3504fe7b1bac 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -376,6 +376,8 @@ def _derive_hybrid_model(self): self.is_hybrid_swa_compress = self.hf_config.architectures[0] in [ "MiMoV2FlashForCausalLM", "MiMoV2MTP", + "Gemma4ForCausalLM", + "Gemma4ForConditionalGeneration", ] def _derive_context_length(self, context_length: int): @@ -433,7 +435,7 @@ def _derive_model_shapes(self): self.swa_v_head_dim = getattr( self.hf_text_config, "swa_v_head_dim", - self.v_head_dim, + self.swa_head_dim, ) # FIXME: temporary special judge for MLA architecture if ( @@ -1301,6 +1303,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal "Ernie4_5_VLMoeForConditionalGeneration", "Gemma3ForConditionalGeneration", "Gemma3nForConditionalGeneration", + "Gemma4ForConditionalGeneration", "Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration", "GlmOcrForConditionalGeneration", @@ -1447,6 +1450,8 @@ def is_hybrid_swa_model(model_architectures: List[str]): "MiMoV2MTP", "Step3p5ForCausalLM", "Step3p5MTP", + "Gemma4ForCausalLM", + 
"Gemma4ForConditionalGeneration", } return any(arch in hybrid_swa_archs for arch in model_architectures) @@ -1464,7 +1469,7 @@ def get_hybrid_layer_ids( i for i in range(num_hidden_layers) if (i + 1) % 4 == 0 ] elif "GptOssForCausalLM" in model_architectures: - layer_types = getattr(hf_text_config, "layer_types", None) + layer_types = getattr(hf_text_config, "layer_types", []) swa_attention_layer_ids = [ i for i, x in enumerate(layer_types) if x == "sliding_attention" ] @@ -1497,6 +1502,17 @@ def get_hybrid_layer_ids( elif "Step3p5MTP" in model_architectures: swa_attention_layer_ids = [0] full_attention_layer_ids = [] + elif ( + "Gemma4ForCausalLM" in model_architectures + or "Gemma4ForConditionalGeneration" in model_architectures + ): + layer_types = getattr(hf_text_config, "layer_types", []) + swa_attention_layer_ids = [ + i for i, x in enumerate(layer_types) if x == "sliding_attention" + ] + full_attention_layer_ids = [ + i for i, x in enumerate(layer_types) if x == "full_attention" + ] else: swa_attention_layer_ids = None full_attention_layer_ids = None diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index d0d05e3a483e..b3e5b62dc0d2 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -120,6 +120,11 @@ def __init__( and hasattr(self.tokenizer_manager.model_config.hf_config, "model_type") and self.tokenizer_manager.model_config.hf_config.model_type == "gpt_oss" ) + self.is_gemma4 = ( + hasattr(self.tokenizer_manager.model_config, "hf_config") + and hasattr(self.tokenizer_manager.model_config.hf_config, "model_type") + and self.tokenizer_manager.model_config.hf_config.model_type == "gemma4" + ) self.use_dpsk_v32_encoding = self._use_dpsk_v32_encoding() @@ -331,7 +336,7 @@ def _process_messages( ) -> MessageProcessingResult: """Process chat messages and apply chat template""" # GptOss model needs to keep special tokens 
for harmony parsing - if self.is_gpt_oss: + if self.is_gpt_oss or self.is_gemma4: request.skip_special_tokens = False self._patch_mistral_skip_special_tokens(request) @@ -1280,12 +1285,18 @@ def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool: """ if not self.reasoning_parser: return False - if self.reasoning_parser in ["deepseek-v3"]: + + if self.reasoning_parser == "deepseek-v3": # Models that require explicit enable thinking (thinking=True) return ( request.chat_template_kwargs is not None and request.chat_template_kwargs.get("thinking") is True ) + if self.reasoning_parser == "gemma4": + return ( + request.chat_template_kwargs is not None + and request.chat_template_kwargs.get("enable_thinking") is True + ) if self.reasoning_parser in ["kimi_k2"]: # Models that thinking by default, and can be disabled by setting thinking=False return ( diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py index ca066e196d0f..84196d8cb057 100644 --- a/python/sglang/srt/function_call/function_call_parser.py +++ b/python/sglang/srt/function_call/function_call_parser.py @@ -14,6 +14,7 @@ from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector from sglang.srt.function_call.deepseekv31_detector import DeepSeekV31Detector from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector +from sglang.srt.function_call.gemma4_detector import Gemma4Detector from sglang.srt.function_call.gigachat3_detector import GigaChat3Detector from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector @@ -69,6 +70,7 @@ class FunctionCallParser: "interns1": InternlmDetector, "hermes": HermesDetector, "gigachat3": GigaChat3Detector, + "gemma4": Gemma4Detector, } def __init__(self, tools: List[Tool], tool_call_parser: str): diff --git a/python/sglang/srt/function_call/gemma4_detector.py 
b/python/sglang/srt/function_call/gemma4_detector.py new file mode 100644 index 000000000000..2b4b9e05a16b --- /dev/null +++ b/python/sglang/srt/function_call/gemma4_detector.py @@ -0,0 +1,445 @@ +import json +import logging +from typing import List, Optional + +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.base_format_detector import BaseFormatDetector +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + ToolCallItem, + _GetInfoFunc, +) + +logger = logging.getLogger(__name__) + +# Gemma4 special tokens for tool calls +TOOL_CALL_START = "<|tool_call>" +TOOL_CALL_END = "" +STRING_DELIM = '<|"|>' + + +def _parse_gemma4_value(value_str: str) -> object: + """Parse a single Gemma4 value (after key:) into a Python object.""" + value_str = value_str.strip() + if not value_str: + return value_str + + # Boolean + if value_str == "true": + return True + if value_str == "false": + return False + + # Number (int or float) + try: + if "." 
in value_str: + return float(value_str) + return int(value_str) + except ValueError: + pass + + # Bare string (no <|"|> delimiters) + return value_str + + +def _parse_gemma4_array(arr_str: str) -> list: + """Parse a Gemma4 array content string into a Python list.""" + items: list = [] + i = 0 + n = len(arr_str) + + while i < n: + while i < n and arr_str[i] in (" ", ",", "\n", "\t"): + i += 1 + if i >= n: + break + + # String element + if arr_str[i : i + len(STRING_DELIM)] == STRING_DELIM: + i += len(STRING_DELIM) + end_pos = arr_str.find(STRING_DELIM, i) + if end_pos == -1: + items.append(arr_str[i:]) + break + items.append(arr_str[i:end_pos]) + i = end_pos + len(STRING_DELIM) + + # Nested object + elif arr_str[i] == "{": + depth = 1 + obj_start = i + 1 + i += 1 + while i < n and depth > 0: + if arr_str[i : i + len(STRING_DELIM)] == STRING_DELIM: + i += len(STRING_DELIM) + next_delim = arr_str.find(STRING_DELIM, i) + i = next_delim + len(STRING_DELIM) if next_delim != -1 else n + continue + if arr_str[i] == "{": + depth += 1 + elif arr_str[i] == "}": + depth -= 1 + i += 1 + items.append(_parse_gemma4_args(arr_str[obj_start : i - 1])) + + # Nested array + elif arr_str[i] == "[": + depth = 1 + sub_start = i + 1 + i += 1 + while i < n and depth > 0: + if arr_str[i] == "[": + depth += 1 + elif arr_str[i] == "]": + depth -= 1 + i += 1 + items.append(_parse_gemma4_array(arr_str[sub_start : i - 1])) + + # Bare value + else: + val_start = i + while i < n and arr_str[i] not in (",", "]"): + i += 1 + items.append(_parse_gemma4_value(arr_str[val_start:i])) + + return items + + +def _parse_gemma4_args(args_str: str) -> dict: + """Parse Gemma4's custom key:value format into a Python dict.""" + if not args_str or not args_str.strip(): + return {} + + result: dict = {} + i = 0 + n = len(args_str) + + while i < n: + # Skip whitespace and commas + while i < n and args_str[i] in (" ", ",", "\n", "\t"): + i += 1 + if i >= n: + break + + # Parse key (unquoted, ends at ':') + key_start 
= i + while i < n and args_str[i] != ":": + i += 1 + if i >= n: + break + key = args_str[key_start:i].strip() + i += 1 # skip ':' + + # Parse value + if i >= n: + result[key] = "" + break + + # Skip whitespace after ':' + while i < n and args_str[i] in (" ", "\n", "\t"): + i += 1 + if i >= n: + result[key] = "" + break + + # String value: <|"|>...<|"|> + if args_str[i : i + len(STRING_DELIM)] == STRING_DELIM: + i += len(STRING_DELIM) + val_start = i + end_pos = args_str.find(STRING_DELIM, i) + if end_pos == -1: + # Unterminated string — take rest + result[key] = args_str[val_start:] + break + result[key] = args_str[val_start:end_pos] + i = end_pos + len(STRING_DELIM) + + # Nested object: {...} + elif args_str[i] == "{": + depth = 1 + obj_start = i + 1 + i += 1 + while i < n and depth > 0: + if args_str[i : i + len(STRING_DELIM)] == STRING_DELIM: + # Skip over string contents + i += len(STRING_DELIM) + next_delim = args_str.find(STRING_DELIM, i) + if next_delim == -1: + i = n + else: + i = next_delim + len(STRING_DELIM) + continue + if args_str[i] == "{": + depth += 1 + elif args_str[i] == "}": + depth -= 1 + i += 1 + result[key] = _parse_gemma4_args(args_str[obj_start : i - 1]) + + # Array: [...] + elif args_str[i] == "[": + depth = 1 + arr_start = i + 1 + i += 1 + while i < n and depth > 0: + if args_str[i : i + len(STRING_DELIM)] == STRING_DELIM: + i += len(STRING_DELIM) + next_delim = args_str.find(STRING_DELIM, i) + if next_delim == -1: + i = n + else: + i = next_delim + len(STRING_DELIM) + continue + if args_str[i] == "[": + depth += 1 + elif args_str[i] == "]": + depth -= 1 + i += 1 + arr_content = args_str[arr_start : i - 1] + result[key] = _parse_gemma4_array(arr_content) + + # Bare value (number, boolean, etc.) 
+ else: + val_start = i + while i < n and args_str[i] not in (",", "}", "]"): + i += 1 + result[key] = _parse_gemma4_value(args_str[val_start:i]) + + return result + + +def _find_matching_brace(text: str) -> int: + """Find index of matching '}' in text, respecting STRING_DELIM and nesting. + + Assumes text starts just after the opening '{'. + Returns index of closing brace, or -1 if not found (incomplete). + """ + depth = 1 + i = 0 + n = len(text) + delim_len = len(STRING_DELIM) + while i < n and depth > 0: + if text[i : i + delim_len] == STRING_DELIM: + i += delim_len + next_delim = text.find(STRING_DELIM, i) + if next_delim == -1: + return -1 + i = next_delim + delim_len + continue + if text[i] == "{": + depth += 1 + elif text[i] == "}": + depth -= 1 + i += 1 + return (i - 1) if depth == 0 else -1 + + +class Gemma4Detector(BaseFormatDetector): + def __init__(self): + super().__init__() + self.tool_call_start_token = TOOL_CALL_START + self.tool_call_end_token = TOOL_CALL_END + + # Streaming state + self.parsed_pos: int = 0 + self.is_inside_tool_call: bool = False + self.current_func_name: Optional[str] = None + self._tool_indices: Optional[dict] = None + + @staticmethod + def _extract_tool_calls(text: str) -> list: + """Extract (func_name, args_str) pairs using brace-balanced parsing.""" + results = [] + search_from = 0 + while True: + start = text.find(TOOL_CALL_START, search_from) + if start == -1: + break + end = text.find(TOOL_CALL_END, start) + if end == -1: + break + inner = text[start + len(TOOL_CALL_START) : end] + if inner.startswith("call:"): + brace = inner.find("{") + if brace != -1: + func_name = inner[5:brace] + args_content = inner[brace + 1 :] + match_idx = _find_matching_brace(args_content) + args_str = ( + args_content[:match_idx] if match_idx != -1 else args_content + ) + results.append((func_name, args_str)) + search_from = end + len(TOOL_CALL_END) + return results + + def has_tool_call(self, text: str) -> bool: + return 
self.tool_call_start_token in text + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + if self.tool_call_start_token not in text: + return StreamingParseResult(normal_text=text) + + calls = [] + try: + matches = self._extract_tool_calls(text) + if not matches: + return StreamingParseResult(normal_text=text) + + tool_indices = self._get_tool_indices(tools) + for func_name, args_str in matches: + arguments = _parse_gemma4_args(args_str) + calls.append( + ToolCallItem( + tool_index=tool_indices.get(func_name, -1), + name=func_name, + parameters=json.dumps(arguments, ensure_ascii=False), + ) + ) + + # Content = text before first tool call + content_end = text.find(self.tool_call_start_token) + normal_text = text[:content_end] if content_end > 0 else "" + + return StreamingParseResult(normal_text=normal_text, calls=calls) + + except (ValueError, IndexError, TypeError, KeyError) as e: + logger.error(f"Error in detect_and_parse: {e}", exc_info=True) + return StreamingParseResult(normal_text=text) + + def parse_streaming_increment( + self, new_text: str, tools: List[Tool] + ) -> StreamingParseResult: + self._buffer += new_text + + if not self._buffer: + return StreamingParseResult() + + calls = [] + normal_text_chunks = [] + if self._tool_indices is None: + self._tool_indices = self._get_tool_indices(tools) + + try: + while True: + current_slice = self._buffer[self.parsed_pos :] + if not current_slice: + break + + if not self.is_inside_tool_call: + # Outside tool call block + next_start = current_slice.find(self.tool_call_start_token) + if next_start == -1: + # Check for partial match at the end + partial_len = self._ends_with_partial_token( + current_slice, self.tool_call_start_token + ) + if partial_len > 0: + text_to_append = current_slice[:-partial_len] + if text_to_append: + normal_text_chunks.append(text_to_append) + self.parsed_pos += len(text_to_append) + break + else: + normal_text_chunks.append(current_slice) + self.parsed_pos 
+= len(current_slice) + continue + elif next_start == 0: + self.parsed_pos += len(self.tool_call_start_token) + self.is_inside_tool_call = True + continue + else: + normal_text_chunks.append(current_slice[:next_start]) + self.parsed_pos += next_start + continue + else: + # Inside tool call block + + # Check for TOOL_CALL_END first + if current_slice.startswith(self.tool_call_end_token): + self.parsed_pos += len(self.tool_call_end_token) + self.is_inside_tool_call = False + self.current_func_name = None + continue + + if not self.current_func_name: + # Skip leading whitespace + if current_slice[0] in (" ", "\n", "\t"): + self.parsed_pos += 1 + continue + + if current_slice.startswith("call:"): + brace_pos = current_slice.find("{") + if brace_pos != -1: + func_name = current_slice[5:brace_pos] + self.current_tool_id += 1 + self.current_func_name = func_name + self.current_tool_name_sent = True + + calls.append( + ToolCallItem( + tool_index=self._tool_indices.get( + func_name, -1 + ), + name=func_name, + parameters="", + ) + ) + self.parsed_pos += brace_pos + 1 + continue + else: + # Incomplete call:name{ + break + else: + # Check for partial matches + if "call:".startswith( + current_slice + ) or self.tool_call_end_token.startswith(current_slice): + break + + # Unexpected content, skip + self.parsed_pos += 1 + continue + else: + # Parsing arguments (looking for balancing }) + match_idx = _find_matching_brace(current_slice) + if match_idx != -1: + args_str = current_slice[:match_idx] + arguments = _parse_gemma4_args(args_str) + + calls.append( + ToolCallItem( + tool_index=self._tool_indices.get( + self.current_func_name, -1 + ), + parameters=json.dumps( + arguments, ensure_ascii=False + ), + ) + ) + self.parsed_pos += match_idx + 1 + self.current_func_name = None + continue + else: + # Incomplete arguments block + break + + except (ValueError, IndexError, TypeError, KeyError) as e: + logger.error(f"Error in parse_streaming_increment: {e}", exc_info=True) + # Reset 
parser state to prevent corruption + self.is_inside_tool_call = False + self.current_func_name = None + self._buffer = "" + self.parsed_pos = 0 + + if self.parsed_pos > 0: + self._buffer = self._buffer[self.parsed_pos :] + self.parsed_pos = 0 + + normal_text = "".join(normal_text_chunks) if normal_text_chunks else "" + return StreamingParseResult(calls=calls, normal_text=normal_text) + + def supports_structural_tag(self) -> bool: + return False + + def structure_info(self) -> _GetInfoFunc: + raise NotImplementedError diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 18ed55572cfe..dd7f1be15814 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -12,6 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.radix_attention import AttentionType +from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices from sglang.srt.utils import ( @@ -51,6 +52,8 @@ class ForwardMetadata: window_kv_indices: torch.Tensor window_num_kv_splits: torch.Tensor window_kv_offsets: torch.Tensor + # Separate attn_logits for SWA layers when v_head_dim differs + swa_attn_logits: Optional[torch.Tensor] = None class TritonAttnBackend(AttentionBackend): @@ -94,16 +97,29 @@ def __init__( self.num_kv_head = model_runner.model_config.get_num_kv_heads( get_attention_tp_size() ) - if ( + # The decode triton kernel derives attn_lse offsets from attn_logits + # strides via integer division by v_head_dim (the "// Lv" trick in + # _fwd_kernel_stage1/stage2), so attn_logits.shape[-1] must exactly + # match the layer's v_head_dim. 
For hybrid SWA models where SWA and + # full-attention layers use different v_head_dim (e.g. Gemma 4: + # swa=256, full=512), we allocate a second buffer for SWA layers. + full_v_head_dim = model_runner.model_config.v_head_dim + swa_v_head_dim = model_runner.model_config.swa_v_head_dim + if self.sliding_window_size is not None and swa_v_head_dim != full_v_head_dim: + self.v_head_dim = full_v_head_dim + self.swa_v_head_dim = swa_v_head_dim + elif ( model_runner.hybrid_gdn_config is not None or model_runner.kimi_linear_config is not None ): # For hybrid linear models, layer_id = 0 may not be full attention self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim() + self.swa_v_head_dim = None else: self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[ -1 ] + self.swa_v_head_dim = None self.max_context_len = model_runner.model_config.context_len self.device = model_runner.device self.device_core_count = get_device_core_count(model_runner.gpu_id) @@ -242,6 +258,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): window_kv_indices = None window_num_kv_splits = None window_kv_offsets = None + swa_attn_logits = None spec_info = forward_batch.spec_info if forward_batch.forward_mode.is_decode_or_idle(): @@ -290,6 +307,14 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): dtype=torch.float32, device=self.device, ) + if self.swa_v_head_dim is not None: + swa_attn_logits = torch.empty( + (bs, self.num_head, self.max_kv_splits, self.swa_v_head_dim), + dtype=torch.float32, + device=self.device, + ) + else: + swa_attn_logits = None attn_lse = torch.empty( (bs, self.num_head, self.max_kv_splits), dtype=torch.float32, @@ -436,6 +461,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): window_kv_indices, window_num_kv_splits, window_kv_offsets, + swa_attn_logits=swa_attn_logits, ) def init_cuda_graph_state( @@ -450,6 +476,19 @@ def init_cuda_graph_state( dtype=torch.float32, device=self.device, ) + if 
self.swa_v_head_dim is not None: + self.cuda_graph_swa_attn_logits = torch.zeros( + ( + max_num_tokens, + self.num_head, + self.max_kv_splits, + self.swa_v_head_dim, + ), + dtype=torch.float32, + device=self.device, + ) + else: + self.cuda_graph_swa_attn_logits = None self.cuda_graph_attn_lse = torch.zeros( (max_num_tokens, self.num_head, self.max_kv_splits), dtype=torch.float32, @@ -520,6 +559,7 @@ def init_forward_metadata_capture_cuda_graph( window_kv_indices = None window_num_kv_splits = None window_kv_offsets = None + swa_attn_logits = None if forward_mode.is_decode_or_idle(): if spec_info is None: @@ -558,6 +598,7 @@ def init_forward_metadata_capture_cuda_graph( kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices attn_logits = self.cuda_graph_attn_logits + swa_attn_logits = self.cuda_graph_swa_attn_logits attn_lse = self.cuda_graph_attn_lse max_extend_len = None num_kv_splits = self.cuda_graph_num_kv_splits @@ -659,6 +700,7 @@ def init_forward_metadata_capture_cuda_graph( window_kv_indices, window_num_kv_splits, window_kv_offsets, + swa_attn_logits=swa_attn_logits, ) def init_forward_metadata_replay_cuda_graph( @@ -819,26 +861,37 @@ def forward_extend( else: o = torch.empty_like(q) - # Save KV cache first (must do this before unified kernel) - if save_kv_cache: - if ( - self.use_mla or layer.k_scale is None - ): # Triton MLA currently doesn't support quantized kv cache - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, - forward_batch.out_cache_loc, - k, - v, - ) - else: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, - forward_batch.out_cache_loc, - k.clone(), # cloned to protect k,v from in-place mutation in set_kv_buffer - v.clone(), - layer.k_scale, - layer.v_scale, - ) + if k is None and v is None: + pool = forward_batch.token_to_kv_pool + cache_loc = forward_batch.out_cache_loc + if isinstance(pool, SWAKVPool) and pool.layers_mapping[layer.layer_id][1]: + cache_loc = pool.translate_loc_from_full_to_swa(cache_loc) + k_buffer, 
v_buffer = pool.get_kv_buffer(layer.layer_id) + k = k_buffer[cache_loc] + v = v_buffer[cache_loc] + elif k is None or v is None: + raise ValueError("Both k and v should be None or not None") + else: + # Save KV cache first (must do this before unified kernel) + if save_kv_cache: + if ( + self.use_mla or layer.k_scale is None + ): # Triton MLA currently doesn't support quantized kv cache + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, + forward_batch.out_cache_loc, + k, + v, + ) + else: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, + forward_batch.out_cache_loc, + k.clone(), # cloned to protect k,v from in-place mutation in set_kv_buffer + v.clone(), + layer.k_scale, + layer.v_scale, + ) logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap) @@ -1089,6 +1142,16 @@ def forward_decode( k_descale = 1.0 v_descale = 1.0 + # Select the correctly-sized attn_logits buffer for this layer. + # The triton kernel's // Lv stride trick requires attn_logits.shape[-1] + # to exactly match the layer's v_head_dim. 
+ attn_logits = self.forward_metadata.attn_logits + if ( + self.forward_metadata.swa_attn_logits is not None + and layer.v_head_dim == self.swa_v_head_dim + ): + attn_logits = self.forward_metadata.swa_attn_logits + self.decode_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), @@ -1096,7 +1159,7 @@ def forward_decode( o.view(-1, layer.tp_q_head_num, layer.v_head_dim), kv_indptr, kv_indices, - self.forward_metadata.attn_logits, + attn_logits, self.forward_metadata.attn_lse, self.forward_metadata.num_kv_splits, self.max_kv_splits, diff --git a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py index ac0fc72af140..a50b89787f2a 100644 --- a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py @@ -168,13 +168,14 @@ def _fwd_kernel( def context_attention_fwd( - q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True + q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True, sm_scale=None ): """ q, k, v: [b * s, head, head_dim] b_start_loc: [b] b_seq_len: [b] out: [b * s, head, head_dim] + sm_scale: softmax scale, defaults to 1/sqrt(head_dim) """ if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8: BLOCK = 128 @@ -183,7 +184,8 @@ def context_attention_fwd( Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - sm_scale = 1.0 / (Lq**0.5) + if sm_scale is None: + sm_scale = 1.0 / (Lq**0.5) batch, head = b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 3fd45aac0101..23dba24584e9 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -167,6 +167,7 @@ def __init__( dropout: float = 0.0, flatten_batch: bool = False, 
softmax_in_single_precision: bool = False, + softmax_scale: float | None = None, **kwargs, ): super().__init__() @@ -176,7 +177,11 @@ def __init__( self.flatten_batch = flatten_batch self.softmax_in_single_precision = softmax_in_single_precision self.dropout = dropout - self.scale = 1.0 / math.sqrt(self.head_size) + self.scale = ( + softmax_scale + if softmax_scale is not None + else 1.0 / math.sqrt(self.head_size) + ) @staticmethod @lru_cache(maxsize=128) @@ -242,6 +247,7 @@ def forward( bsz: int, cu_seqlens: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" @@ -298,6 +304,7 @@ def forward( attn_mask=attention_mask, dropout_p=self.dropout, is_causal=False, + scale=self.scale, ) # [b, h, s, head_size] --> [b * s, h, head_size] @@ -329,11 +336,13 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" Args: cu_seqlens: [b] + softmax_scale: override softmax scale (default 1/sqrt(head_dim)) Returns: [b * s, h, head_size] """ @@ -354,6 +363,7 @@ def forward( cu_seqlens[1], cu_seqlens[2], is_causal=False, + sm_scale=softmax_scale, ) else: cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device) @@ -372,6 +382,7 @@ def forward( seq_lens.to(q.device), max_seqlen, is_causal=False, + sm_scale=softmax_scale, ) return output @@ -398,6 +409,7 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" @@ -416,6 +428,7 @@ def forward( cu_seqlens_k=cu_seqlens[0], max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, ) else: cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device) @@ -431,6 +444,7 @@ def forward( cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, + 
softmax_scale=softmax_scale, ) return output @@ -453,6 +467,7 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" @@ -482,6 +497,7 @@ def forward( cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, ver=4, ) @@ -508,6 +524,7 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" @@ -583,7 +600,7 @@ def forward( raise RuntimeError("offset + len out of bounds; packed indptr is wrong") _, _, head_size = q.shape - scale = head_size**-0.5 + scale = softmax_scale if softmax_scale is not None else head_size**-0.5 output, _ = cudnn_batch_prefill_with_kv_cache( q, @@ -635,6 +652,7 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device) @@ -651,6 +669,7 @@ def forward( cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, ) @@ -672,6 +691,7 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" @@ -684,7 +704,6 @@ def forward( if "output_ws" not in kwargs: raise RuntimeError("output_ws should be prepared for npu-graph mode") output = kwargs["output_ws"] - # graph mode: runner already passes seq_lens (int32 on CPU) seq_len_arg = cu_seqlens else: cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device="cpu") @@ -697,12 +716,14 @@ def forward( _, num_heads, head_size = q.shape num_kv_heads = k.shape[1] + scale_value = softmax_scale if softmax_scale is not None else head_size**-0.5 + torch_npu._npu_flash_attention_unpad( query=q, key=k, value=v, seq_len=seq_len_arg, - 
scale_value=head_size**-0.5, + scale_value=scale_value, num_heads=num_heads, num_kv_heads=num_kv_heads, out=output, @@ -744,6 +765,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, dropout: float = 0.0, softmax_in_single_precision: bool = False, + softmax_scale: Optional[float] = None, flatten_batch: bool = False, prefix: str = "", proj_bias: bool = True, @@ -808,6 +830,7 @@ def __init__( self.customized_position_embedding_applier = ( customized_position_embedding_applier ) + self.softmax_scale = softmax_scale self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend]( head_dim=self.head_size, num_heads=self.num_attention_heads_per_partition, @@ -815,6 +838,7 @@ def __init__( dropout=dropout, flatten_batch=flatten_batch, softmax_in_single_precision=softmax_in_single_precision, + softmax_scale=softmax_scale, use_data_parallel=use_data_parallel, workspace_buffer=workspace_buffer, ) @@ -1116,6 +1140,7 @@ def forward( sequence_lengths=sequence_lengths, max_seqlen=max_seqlen, output_ws=attn_output_ws, + softmax_scale=self.softmax_scale, ) assert output.dim() == 3, output.shape diff --git a/python/sglang/srt/layers/clippable_linear.py b/python/sglang/srt/layers/clippable_linear.py new file mode 100644 index 000000000000..a253bb42197a --- /dev/null +++ b/python/sglang/srt/layers/clippable_linear.py @@ -0,0 +1,283 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""TP-sharded linear wrappers with per-tensor activation clamping. + +Used by the Gemma 4 vision and audio encoders. Each wrapper owns a parallel +linear and four scalar clip buffers (``input_min/max``, ``output_min/max``) +that default to ±inf (no-op) and are populated from the checkpoint. + +For fused projections (QKV, GateUp), input bounds are shared (the checkpoint +stores identical copies per projection — last write wins during loading) and +output bounds are per-projection. +""" + +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.utils import add_prefix + +_INF = float("inf") + + +class ClippableRowParallelLinear(nn.Module): + """``RowParallelLinear`` with input/output activation clamping. + + Checkpoint weight at ``.weight`` is remapped to ``.linear.weight`` + by the model's ``load_weights``. 
+ """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.linear = RowParallelLinear( + input_size=input_size, + output_size=output_size, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("linear", prefix), + ) + self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.clamp(x, self.input_min, self.input_max) + x, _ = self.linear(x) + x = torch.clamp(x, self.output_min, self.output_max) + return x + + +class ClippableColumnParallelLinear(nn.Module): + """``ColumnParallelLinear`` with input/output activation clamping.""" + + def __init__( + self, + input_size: int, + output_size: int, + *, + bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.linear = ColumnParallelLinear( + input_size=input_size, + output_size=output_size, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("linear", prefix), + ) + self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.clamp(x, self.input_min, self.input_max) + x, _ = self.linear(x) + x = torch.clamp(x, self.output_min, self.output_max) + return x + + +class ClippableQKVParallelLinear(nn.Module): + """Fused QKV projection with per-projection activation clamping. 
+ + Owns a single ``QKVParallelLinear`` for the fused matmul. Clip bounds + are stored as flat buffers: shared ``input_min/max`` (applied before the + matmul) and per-projection ``q/k/v_output_min/max`` (applied after split). + """ + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: int, + *, + bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + tp_size = get_attention_tp_size() + self.q_size = (total_num_heads // tp_size) * head_size + self.kv_size = (total_num_kv_heads // tp_size) * head_size + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=head_size, + total_num_heads=total_num_heads, + total_num_kv_heads=total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.q_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.q_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.k_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.k_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.v_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.v_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + + def forward( + self, hidden_states: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + x = torch.clamp(hidden_states, self.input_min, self.input_max) + qkv, _ = self.qkv_proj(x) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = torch.clamp(q, self.q_output_min, self.q_output_max) + k = torch.clamp(k, self.k_output_min, self.k_output_max) + v = torch.clamp(v, self.v_output_min, self.v_output_max) + return q, k, v + + +class 
ClippableGLUParallelLinear(nn.Module): + """Fused linear + GLU gating with correct TP sharding. + + Used by the audio encoder's ``LightConv1d``, where a single linear + projects to ``[hidden * 2]`` and GLU splits into value/gate halves. + A plain ``ColumnParallelLinear`` is *incorrect* here under TP because it + shards the output contiguously, mixing value and gate across ranks. + This wrapper uses ``MergedColumnParallelLinear`` to shard each half + independently, then applies GLU (``value * sigmoid(gate)``) on each + rank's correctly-paired shard. + + Output clamping is applied once *after* the GLU gate, using a single + ``output_min/max`` pair (matching the checkpoint layout). + + The checkpoint stores a single fused ``[hidden * 2, input]`` weight. + A custom ``weight_loader`` on the inner param automatically splits it + into value (first half) and gate (second half) shards, so no special + handling is needed in the model's ``load_weights``. + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + *, + bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + tp_size = get_attention_tp_size() + self.proj_size = hidden_size // tp_size + + self.linear = MergedColumnParallelLinear( + input_size=input_size, + output_sizes=[hidden_size, hidden_size], + bias=bias, + quant_config=quant_config, + prefix=add_prefix("linear", prefix), + ) + + # The checkpoint has a single fused weight; MergedColumnParallelLinear + # expects per-shard loading. Wrap the original weight_loader so that + # a call *without* shard_id (the generic load_weights path) splits + # automatically. 
+ orig_loader = self.linear.weight.weight_loader + + def _fused_weight_loader(param, loaded_weight, loaded_shard_id=None): + if loaded_shard_id is not None: + return orig_loader(param, loaded_weight, loaded_shard_id) + half = loaded_weight.shape[0] // 2 + orig_loader(param, loaded_weight[:half], 0) + orig_loader(param, loaded_weight[half:], 1) + + self.linear.weight.weight_loader = _fused_weight_loader + + self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.clamp(x, self.input_min, self.input_max) + merged, _ = self.linear(x) + value, gate = merged.split([self.proj_size, self.proj_size], dim=-1) + x = value * torch.sigmoid(gate) + x = torch.clamp(x, self.output_min, self.output_max) + return x + + +class ClippableGateUpParallelLinear(nn.Module): + """Fused gate/up projection with per-projection activation clamping. + + Used by the MLP layers in the vision/audio encoders. Owns a single + ``MergedColumnParallelLinear`` for the fused matmul and returns the + two projections separately so the caller can apply its own activation + (e.g. ``SiLU(gate) * up``). + + Output clamping is applied *per-projection before* the caller's + activation, using separate ``gate_output_min/max`` and + ``up_output_min/max`` bounds. 
+ """ + + def __init__( + self, + input_size: int, + intermediate_size: int, + *, + bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + tp_size = get_attention_tp_size() + self.proj_size = intermediate_size // tp_size + + self.gate_up_proj = MergedColumnParallelLinear( + input_size=input_size, + output_sizes=[intermediate_size, intermediate_size], + bias=bias, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + ) + self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.gate_output_min = nn.parameter.Buffer( + torch.tensor(-_INF), persistent=False + ) + self.gate_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.up_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.up_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x = torch.clamp(x, self.input_min, self.input_max) + gate_up, _ = self.gate_up_proj(x) + gate, up = gate_up.split([self.proj_size, self.proj_size], dim=-1) + gate = torch.clamp(gate, self.gate_output_min, self.gate_output_max) + up = torch.clamp(up, self.up_output_min, self.up_output_max) + return gate, up diff --git a/python/sglang/srt/layers/gemma4_fused_ops.py b/python/sglang/srt/layers/gemma4_fused_ops.py new file mode 100644 index 000000000000..5f227db82853 --- /dev/null +++ b/python/sglang/srt/layers/gemma4_fused_ops.py @@ -0,0 +1,79 @@ +"""Fused triton kernels for Gemma4 decoder layer operations. + +Fuses standard RMSNorm + residual-add (+ optional scalar multiply) into +a single kernel pass to reduce kernel launch overhead. 
+""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _gemma_rmsnorm_residual_kernel( + X_ptr, + W_ptr, + Residual_ptr, + Scalar_ptr, + Out_ptr, + stride_x, + stride_r, + stride_o, + N, + eps, + HAS_SCALAR: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Fused kernel: out = rmsnorm(x, w) + residual [* scalar] + + When HAS_SCALAR is True, also multiplies by a scalar loaded from Scalar_ptr. + """ + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_SIZE) + mask = cols < N + + x = tl.load(X_ptr + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + w = tl.load(W_ptr + cols, mask=mask, other=0.0).to(tl.float32) + r = tl.load(Residual_ptr + row * stride_r + cols, mask=mask, other=0.0).to( + tl.float32 + ) + + var = tl.sum(x * x, axis=0) / N + rrms = tl.rsqrt(var + eps) + out = x * rrms * w + r + + if HAS_SCALAR: + scalar = tl.load(Scalar_ptr).to(tl.float32) + out = out * scalar + + tl.store(Out_ptr + row * stride_o + cols, out.to(x.dtype), mask=mask) + + +def gemma_rmsnorm_residual_scalar( + x: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + scalar: torch.Tensor, + eps: float = 1e-6, +) -> torch.Tensor: + """Fused (rmsnorm(x) + residual) * scalar.""" + assert x.dim() == 2 and x.stride(-1) == 1, "Expected contiguous 2D input" + M, N = x.shape + BLOCK_SIZE = triton.next_power_of_2(N) + out = torch.empty_like(x) + + _gemma_rmsnorm_residual_kernel[(M,)]( + x, + weight, + residual, + scalar, + out, + x.stride(0), + residual.stride(0), + out.stride(0), + N, + eps, + HAS_SCALAR=True, + BLOCK_SIZE=BLOCK_SIZE, + ) + return out diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index e4960bdb42d6..0db6675e648f 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -182,6 +182,11 @@ def forward_cuda( ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if x.numel() == 0: return x + # sgl_kernel rmsnorm requires 2D input; reshape 
higher-rank tensors + needs_reshape = x.dim() != 2 and residual is None + if needs_reshape: + original_shape = x.shape + x = x.contiguous().reshape(-1, original_shape[-1]) if self.variance_size_override is not None: return self.forward_native(x, residual, post_residual_addition) if is_batch_invariant_mode_enabled(): @@ -205,6 +210,8 @@ def forward_cuda( fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) return x, residual out = rmsnorm(x, self.weight.data, self.variance_epsilon) + if needs_reshape: + out = out.reshape(original_shape) return out def forward_npu( @@ -458,6 +465,10 @@ def _forward_impl( residual: Optional[torch.Tensor] = None, post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + needs_reshape = x.dim() != 2 and residual is None + if needs_reshape: + original_shape = x.shape + x = x.contiguous().reshape(-1, original_shape[-1]) if residual is not None: if post_residual_addition is not None: residual = residual + post_residual_addition @@ -466,6 +477,8 @@ def _forward_impl( ) return x, residual out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon) + if needs_reshape: + out = out.reshape(original_shape) return out def forward_native( @@ -631,3 +644,88 @@ def forward_npu(self, x): def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class Gemma4RMSNorm(MultiPlatformOp): + def __init__( + self, + dim: int, + eps: float = 1e-6, + scale_shift: float = 0.0, + with_scale: bool = True, + ): + super().__init__() + self.with_scale = with_scale + + if self.with_scale: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.register_buffer("weight", torch.ones(dim), persistent=False) + + self.eps = eps + self.scale_shift = scale_shift + + def __repr__(self): + dim = self.weight.shape[0] + return ( + f"{self.__class__.__name__}(dim={dim}, eps={self.eps}, " + f"with_scale={self.with_scale}, scale_shift={self.scale_shift})" + ) + + def 
_norm(self, x): + mean_squared = x.pow(2).mean(-1, keepdim=True) + self.eps + return x * torch.pow(mean_squared, -0.5) + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + normed_output = self._norm(x.float()) + if self.with_scale: + normed_output = normed_output * (self.weight.float() + self.scale_shift) + return normed_output.type_as(x) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + if x.numel() == 0: + return x + needs_reshape = x.dim() != 2 + if needs_reshape: + original_shape = x.shape + x = x.contiguous().reshape(-1, original_shape[-1]) + if self.with_scale and self.scale_shift == 1.0: + # gemma_rmsnorm: norm(x) * (1 + weight) + out = gemma_rmsnorm(x, self.weight.data, self.eps) + else: + # rmsnorm: norm(x) * weight + # with_scale=False → weight is ones → norm(x) * 1 = norm(x) + # scale_shift=0.0 → standard RMSNorm without +1 shift + out = rmsnorm(x, self.weight.data, self.eps) + + if needs_reshape: + out = out.reshape(original_shape) + return out + + def forward_hip(self, x: torch.Tensor) -> torch.Tensor: + # sgl_kernel's gemma_rmsnorm is not available on ROCm; + # delegate to the pure-PyTorch implementation. 
+ return self.forward_native(x) + + +class RMSNormWithoutScale(MultiPlatformOp): + def __init__(self, hidden_size: int, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward_native(self, x): + orig_dtype = x.dtype + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + return x.to(orig_dtype) + + def forward_cuda(self, x): + return self.forward_native(x) + + def extra_repr(self): + return f"{self.hidden_size}, eps={self.eps}" diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_B200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_B200.json new file mode 100644 index 000000000000..f0eb57ab8dc0 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + 
"num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000000..60adcf03cea9 --- /dev/null +++ 
b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_B200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_B200.json new file mode 100644 index 000000000000..8ff7c371dab5 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000000..48b07c17d5b7 --- /dev/null +++ 
b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + 
"BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 94f9a1375c14..5fce65159a59 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -1,7 +1,10 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING, List, Optional +logger = logging.getLogger(__name__) + import torch import torch.nn.functional as F from torch.nn.parameter import Parameter @@ -469,28 +472,37 @@ def forward_cuda( topk_weights = torch.ones_like( topk_weights, dtype=torch.float32 ) # topk_weights must be FP32 (float32) - output = fused_moe( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=( - ActivationType.Silu - if moe_runner_config.activation == "silu" - else ActivationType.Gelu - ), - expert_mask=layer.expert_mask_gpu, - ) - return StandardCombineInput(hidden_states=output) - else: - quant_info = TritonMoeQuantInfo( - w13_weight=layer.w13_weight, - w2_weight=layer.w2_weight, - b13=getattr(layer, "w13_weight_bias", None), - b2=getattr(layer, "w2_weight_bias", None), - ) - return self.runner.run(dispatch_output, quant_info) + try: + output = fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=( + 
def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
    """Make ``self.cos_sin_cache`` live on ``query``'s device with its dtype.

    The cache is only re-assigned when it actually differs from ``query``:
    assigning an attribute on an ``nn.Module`` goes through the module's
    ``__setattr__``, which is expensive, so the common already-matching case
    returns without touching the attribute.

    Args:
        query: Tensor whose device and dtype the cache must match.
    """
    cache = self.cos_sin_cache
    if cache.device == query.device and cache.dtype == query.dtype:
        # Nothing to do; skip the costly attribute assignment.
        return
    self.cos_sin_cache = cache.to(query.device, dtype=query.dtype)
class Gemma4RotaryEmbedding(RotaryEmbedding):
    """Gemma4-specific RoPE with cross-mixing.

    Instead of rotating the first `rotary_dim` dimensions contiguously,
    splits the head into two halves and applies rotation across both.

    For a head_dim of D and rotary_dim of R:
    - Standard RoPE rotates: [0, R)
    - Gemma4 RoPE rotates: [0, R/2) cross-mixed with [D/2, D/2 + R/2)

    This is implemented by handing the base class a full-width rotation
    (``rotary_dim == head_size``) whose inverse frequencies are zero for every
    angle that must not rotate, so those dimensions get cos=1 / sin=0.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        """Build the embedding.

        Args:
            head_size: Full attention head dimension (D).
            rotary_dim: Number of dimensions that actually rotate (R); already
                scaled by ``partial_rotary_factor`` by the caller, e.g.
                head_size=512 with factor 0.25 gives rotary_dim=128.
            max_position_embeddings: Forwarded unchanged to the base class.
            base: RoPE frequency base, forwarded to the base class.
            is_neox_style: Forwarded unchanged to the base class.
            dtype: Forwarded unchanged to the base class.

        Raises:
            ValueError: If ``rotary_dim`` is negative or larger than
                ``head_size``.
        """
        # A rotary_dim above head_size would make nope_angles negative; the
        # zero-padding in _compute_inv_freq would then be skipped and inv_freq
        # would end up longer than head_size // 2. Fail loudly instead.
        if not 0 <= rotary_dim <= head_size:
            raise ValueError(
                f"rotary_dim must be in [0, head_size]; got rotary_dim="
                f"{rotary_dim} with head_size={head_size}"
            )

        # These must exist before super().__init__() runs — presumably because
        # the base constructor calls _compute_inv_freq while building its
        # cos/sin cache (confirm against RotaryEmbedding.__init__).
        self.rope_angles = rotary_dim // 2  # rotation angles per half
        self.nope_angles = (head_size // 2) - self.rope_angles  # identity angles per half

        # head_size is deliberately passed as the base class's rotary_dim so
        # the whole head is covered; non-rotated angles are neutralised by the
        # zero frequencies below.
        super().__init__(
            head_size,
            head_size,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Compute frequencies only for the rotated dimensions.

        The exponents are normalised by ``head_size`` (not by the rotary
        dimension). Non-rotated dims are padded with 0.0 to produce identity
        rotation.

        Returns:
            1-D float tensor of ``rope_angles + max(nope_angles, 0)`` entries.
        """
        freq_exponents = (
            torch.arange(0, 2 * self.rope_angles, 2, dtype=torch.float) / self.head_size
        )
        inv_freq = 1.0 / (base**freq_exponents)

        # Zero-pad for non-rotated dims (identity rotation: cos=1, sin=0)
        if self.nope_angles > 0:
            inv_freq = torch.cat(
                [
                    inv_freq,
                    torch.zeros(self.nope_angles, dtype=torch.float),
                ]
            )
        return inv_freq

    def extra_repr(self) -> str:
        """Return the module summary shown by ``repr()``."""
        # NOTE(review): self.rotary_dim here is whatever the base class stored,
        # which looks like head_size (see __init__); the effective rotary
        # width is 2 * rope_angles — confirm before relying on this field.
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", rope_angles={self.rope_angles}, nope_angles={self.nope_angles}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        return s
Optional, Tuple import torch @@ -306,7 +305,7 @@ def __init__( self.clear() self._kvcache = kvcache - self._kvcache.register_mapping(weakref.proxy(self.full_to_swa_index_mapping)) + self._kvcache.register_mapping(self.full_to_swa_index_mapping) def available_size(self): return min( diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py index 0481cae0eeba..6a38e7ebad9a 100644 --- a/python/sglang/srt/models/gemma3_causal.py +++ b/python/sglang/srt/models/gemma3_causal.py @@ -95,11 +95,12 @@ def __init__( ) if hidden_activation != "gelu_pytorch_tanh": raise ValueError( - "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " + f"{self.__class__.__name__} uses `gelu_pytorch_tanh` as the hidden activation " "function. Please set `hidden_activation` to " "`gelu_pytorch_tanh`." ) self.act_fn = GeluAndMul() + self.prefix = prefix def forward(self, x: torch.Tensor) -> torch.Tensor: gate_up, _ = self.gate_up_proj(x) diff --git a/python/sglang/srt/models/gemma4_audio.py b/python/sglang/srt/models/gemma4_audio.py new file mode 100644 index 000000000000..db825165fe29 --- /dev/null +++ b/python/sglang/srt/models/gemma4_audio.py @@ -0,0 +1,873 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SGLang-native TP-sharded audio encoder for Gemma 4. 
class Gemma4AudioRelativePositionEmbedding(nn.Module):
    """Relative-position attention logits for the Gemma 4 audio encoder.

    ``forward`` returns the sum of a content term (queries x keys) and a
    position term (queries x projected sinusoidal embeddings of the relative
    offsets ``max_backward, ..., -max_forward``); the position term is
    re-aligned onto the key-context axis by ``_relative_shift``.
    """

    def __init__(
        self,
        config: Gemma4AudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        attn_tp_size = get_attention_tp_size()
        all_heads = config.num_attention_heads
        self.channels = config.hidden_size
        self.head_dim = self.channels // all_heads
        # Only this rank's share of the heads is handled locally.
        self.num_heads = all_heads // attn_tp_size
        self.max_backward = max(0, config.attention_context_left - 1)
        self.max_forward = config.attention_context_right

        self.pos_proj = ColumnParallelLinear(
            self.channels,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("pos_proj", prefix),
        )

        # Inverse timescales form a geometric progression starting at
        # 1 / min_timescale and ending at 1 / max_timescale.
        min_timescale, max_timescale = 1.0, 1.0e4
        half_channels = self.channels // 2
        log_step = math.log(float(max_timescale) / float(min_timescale)) / max(
            half_channels - 1, 1
        )
        inv_timescales = min_timescale * torch.exp(
            torch.arange(half_channels) * -log_step
        )
        # Stored as [1, 1, channels // 2] so it broadcasts against
        # [batch, positions, 1]; non-persistent, i.e. not part of checkpoints.
        self.register_buffer(
            "inv_timescales",
            inv_timescales.float()[None, None, :],
            persistent=False,
        )

    def _get_timing_signal_1d_pos(
        self, position: torch.Tensor, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return ``[sin | cos]`` timing features for 2-D ``position``.

        The math is done in float32 and the result is cast to ``dtype``; the
        output has one extra trailing axis of size ``2 * (channels // 2)``.
        """
        assert position.ndim == 2
        pos_f32 = position.float().unsqueeze(-1)
        phase = pos_f32 * self.inv_timescales.to(
            device=pos_f32.device, dtype=torch.float32
        )
        features = torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)
        return features.type(dtype)

    def _relative_shift(
        self,
        term_bd_before_shift: torch.Tensor,
        batch_size: int,
        num_heads: int,
        num_query_blocks: int,
        query_block_size: int,
        key_context_size: int,
        max_span_plus_1: int,
    ) -> torch.Tensor:
        """Skew the position term so its last axis indexes key-context slots.

        Each row is right-padded to ``key_context_size + 1`` entries, the
        (query, padded) axes are flattened, the tail is cut back to
        ``query_block_size * key_context_size`` and the result is re-split.
        Because the padded row is one element longer than the output row,
        every successive query row ends up shifted one slot to the right.
        """
        padded_width = key_context_size + 1
        padded = F.pad(term_bd_before_shift, (0, padded_width - max_span_plus_1))
        flat = padded.reshape(
            batch_size,
            num_heads,
            num_query_blocks,
            query_block_size * padded_width,
        )
        trimmed = flat[..., : query_block_size * key_context_size]
        return trimmed.reshape(
            batch_size,
            num_heads,
            num_query_blocks,
            query_block_size,
            key_context_size,
        )

    def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
        """Compute content + relative-position logits.

        Args:
            queries: 5-D tensor unpacked as
                [batch, query_blocks, query_block_size, heads, head_dim].
            keys: 5-D tensor whose third axis is the key-context size and whose
                last two axes are [heads, head_dim].

        Returns:
            Logits of shape
            [batch, heads, query_blocks, query_block_size, key_context_size].
        """
        batch_size, num_query_blocks, query_block_size, num_heads, head_dim = (
            queries.shape
        )
        _, _, key_context_size, _, _ = keys.shape

        # Relative offsets from furthest-back to furthest-forward, as [1, span].
        rel_positions = torch.arange(
            self.max_backward, -self.max_forward - 1, -1, device=queries.device
        ).unsqueeze(0)
        span = rel_positions.shape[1]

        timing = self._get_timing_signal_1d_pos(rel_positions, dtype=queries.dtype)
        # pos_proj is a ColumnParallelLinear (no implicit dtype promotion):
        # feed it its own weight dtype, then return to the queries' dtype.
        projected, _ = self.pos_proj(timing.to(self.pos_proj.weight.dtype))
        pos_emb = (
            projected.to(queries.dtype)
            .reshape(1, span, self.num_heads, self.head_dim)
            .squeeze(0)
        )

        # Heads-first view of the queries, shared by both terms.
        q_heads_first = queries.permute(0, 3, 1, 2, 4)

        # Content term: queries against keys.
        content_logits = torch.matmul(q_heads_first, keys.permute(0, 3, 1, 4, 2))

        # Position term: queries against the projected position embeddings.
        q_flat = q_heads_first.reshape(
            batch_size, num_heads, num_query_blocks * query_block_size, head_dim
        )
        position_logits = torch.matmul(q_flat, pos_emb.permute(1, 2, 0)).reshape(
            batch_size,
            num_heads,
            num_query_blocks,
            query_block_size,
            span,
        )

        shifted_position_logits = self._relative_shift(
            position_logits,
            batch_size,
            num_heads,
            num_query_blocks,
            query_block_size,
            key_context_size,
            span,
        )
        return content_logits + shifted_position_logits
def _pad_dim1(
    self, x: torch.Tensor, dim10_val: int, dim11_val: int
) -> torch.Tensor:
    """Zero-pad ``x`` along dimension 1 only.

    Args:
        x: Tensor to pad; the callers in this class pass tensors with at
            least two dimensions.
        dim10_val: Number of elements added before dimension 1.
        dim11_val: Number of elements added after dimension 1.

    Returns:
        The padded tensor; every other dimension is left unchanged.
    """
    # F.pad reads (before, after) pairs starting from the LAST dimension, so
    # reaching dimension 1 takes one (0, 0) pair per trailing dimension.
    trailing_pairs = (0, 0) * (x.ndim - 2)
    return F.pad(x, trailing_pairs + (dim10_val, dim11_val))
def _extract_block_context(self, x: torch.Tensor) -> torch.Tensor:
    """Slice ``x`` into overlapping context windows along dimension 1.

    Dimension 1 is first zero-padded with ``max_past_horizon`` entries in
    front and ``max_future_horizon + chunk_size - 1`` entries behind, then cut
    into windows of ``context_size`` entries taken every ``chunk_size`` steps.

    Returns:
        A contiguous tensor with the window index at dimension 1. For inputs
        with more than two dimensions the window contents sit at dimension 2
        (followed by the original trailing dimensions); for 2-D inputs they
        are the last dimension.
    """
    padded = self._pad_dim1(
        x,
        self.max_past_horizon,
        self.max_future_horizon + self.chunk_size - 1,
    )
    windows = padded.unfold(
        dimension=1, size=self.context_size, step=self.chunk_size
    )
    # unfold appends the window axis last; for inputs with feature dimensions
    # move it right behind the block axis.
    if padded.ndim > 2 and windows.ndim > 3:
        windows = torch.movedim(windows, source=-1, destination=2)
    return windows.contiguous()
class Gemma4AudioSSCPConvBlock(nn.Module):
    """Single 2D conv block with LayerNorm and semicausal padding.

    Pipeline: zero out padded frames -> explicit zero padding -> strided
    bias-free Conv2d -> LayerNorm over channels -> ReLU. The padding mask is
    subsampled alongside so it stays aligned with the time axis.
    """

    def __init__(
        self,
        config: Gemma4AudioConfig,
        idx: int,
        input_freq_dim: int,
    ):
        super().__init__()
        self.config = config

        channels = config.subsampling_conv_channels
        # The first block consumes the single-channel spectrogram.
        c_in = channels[idx - 1] if idx != 0 else 1
        c_out = channels[idx]
        kernel_t, kernel_f = _SSCP_CONV_KERNEL_SIZES[idx]
        stride_t, stride_f = _SSCP_CONV_STRIDE_SIZES[idx]
        self.time_stride = stride_t

        # Semicausal padding (hardcoded — streaming is not supported)
        # NOTE(review): the time padding below is symmetric (kernel_t // 2 on
        # both sides) — confirm this is what "semicausal" is meant to denote.
        time_pad = kernel_t // 2
        freq_pad = 1

        # F.pad order: (freq_left, freq_right, time_top, time_bottom).
        self.manual_padding = (freq_pad, freq_pad, time_pad, time_pad)

        # Padding is applied by hand in forward(), hence padding=(0, 0) here.
        self.conv = nn.Conv2d(
            in_channels=c_in,
            out_channels=c_out,
            kernel_size=(kernel_t, kernel_f),
            stride=(stride_t, stride_f),
            padding=(0, 0),
            bias=False,
        )

        # Frequency bins left after this block (standard conv output size).
        self.f_out_conv = (input_freq_dim + 2 * freq_pad - kernel_f) // stride_f + 1

        self.norm = nn.LayerNorm(
            [c_out],
            eps=config.rms_norm_eps,
            elementwise_affine=True,
            bias=False,
        )
        self.activation = nn.ReLU()

    def forward(
        self, audio_encodings: torch.Tensor, audio_mel_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block.

        Args:
            audio_encodings: 4-D input laid out as [batch, channels, time, freq]
                (the layout ``nn.Conv2d`` expects).
            audio_mel_mask: [batch, time] boolean mask; True entries are zeroed
                before the convolution.

        Returns:
            ``(activations, mask)`` where ``activations`` is the conv output
            after LayerNorm and ReLU, and ``mask`` is the input mask sampled
            every ``time_stride`` frames and cut to the conv's output length.
        """
        # Broadcast the [batch, time] mask over channels and frequency.
        fill_mask = audio_mel_mask[:, None, :, None]
        cleaned = audio_encodings.masked_fill(fill_mask, 0.0)

        padded = F.pad(cleaned, self.manual_padding, mode="constant", value=0.0)
        conv_out = self.conv(padded.to(self.conv.weight.dtype))

        strided_mask = audio_mel_mask[:, :: self.time_stride]
        output_mask = strided_mask[:, : conv_out.shape[2]]

        # LayerNorm acts on the last axis, so move channels there and back.
        channels_last = conv_out.permute(0, 2, 3, 1)
        normed = self.norm(channels_last).permute(0, 3, 1, 2).contiguous()
        return self.activation(normed), output_mask
class Gemma4AudioConformerAttention(nn.Module):
    """Attention sub-layer of a conformer block.

    Computes ``x + post_norm(clamp(post(attn(pre_attn_norm(clamp(x))))))``,
    where both clamps limit values to ``[-gradient_clipping,
    gradient_clipping]``.
    """

    def __init__(
        self,
        config: Gemma4AudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.post_in_features = config.hidden_size

        # Clamp bound kept as a non-persistent buffer (not saved in checkpoints).
        self.register_buffer(
            "gradient_clipping",
            torch.tensor(config.gradient_clipping),
            persistent=False,
        )

        hidden = config.hidden_size
        self.pre_attn_norm = Gemma4RMSNorm(hidden, scale_shift=0.0)
        self.attn = Gemma4AudioAttention(
            config, quant_config, prefix=add_prefix("attn", prefix)
        )
        self.post = ClippableRowParallelLinear(
            self.post_in_features,
            hidden,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("post", prefix),
        )
        self.post_norm = Gemma4RMSNorm(hidden, scale_shift=0.0)

    def forward(
        self,
        audio_encodings: torch.Tensor,
        audio_mel_mask: torch.BoolTensor,
        causal_valid_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        """Apply the attention sub-layer and add the residual.

        Args:
            audio_encodings: Input activations; also used, unclamped, as the
                residual.
            audio_mel_mask: Padding mask forwarded to the attention module.
            causal_valid_mask: Causality mask forwarded to the attention module.

        Returns:
            ``audio_encodings`` plus the normalised attention branch.
        """
        residual = audio_encodings
        bound = self.gradient_clipping

        normed = self.pre_attn_norm(torch.clamp(audio_encodings, -bound, bound))
        attn_out = self.attn(normed, audio_mel_mask, causal_valid_mask)

        # The attention output is unpacked as four axes whose last two are
        # (heads, head_dim); merge those two and restore the input dtype.
        b, t, num_heads, head_dim = attn_out.shape
        merged = attn_out.reshape(b, t, num_heads * head_dim).to(dtype=residual.dtype)

        projected = torch.clamp(self.post(merged), -bound, bound)
        return residual + self.post_norm(projected)
class Gemma4AudioConformerLightConv1d(nn.Module):
    """Light-weight convolution sub-layer of a conformer block.

    Pipeline: RMSNorm -> ``linear_start`` -> left-padded depthwise Conv1d ->
    clamp -> RMSNorm -> SiLU -> ``linear_end`` -> residual add. The depthwise
    convolution and ``conv_norm`` work on this rank's slice of the hidden
    channels (``hidden_size // attention_tp_size``).
    """

    def __init__(
        self,
        config: Gemma4AudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        # Left-only padding of kernel_size - 1 keeps the conv output length
        # equal to the input length without looking at later frames.
        self.causal_padding = config.conv_kernel_size - 1
        tp_size = get_attention_tp_size()
        local_channels = config.hidden_size // tp_size

        self.register_buffer(
            "gradient_clipping",
            torch.tensor(config.gradient_clipping),
            persistent=False,
        )

        self.pre_layer_norm = Gemma4RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, scale_shift=0.0
        )
        self.linear_start = ClippableGLUParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("linear_start", prefix),
        )
        # groups == channels makes this a depthwise convolution.
        self.depthwise_conv1d = nn.Conv1d(
            in_channels=local_channels,
            out_channels=local_channels,
            kernel_size=config.conv_kernel_size,
            stride=1,
            padding=0,
            groups=local_channels,
            bias=False,
        )
        # NOTE(review): this norm only sees the local channel slice; with
        # tp_size > 1 that differs from normalising over the full hidden_size
        # unless Gemma4RMSNorm reduces across ranks — confirm.
        self.conv_norm = Gemma4RMSNorm(
            local_channels, eps=config.rms_norm_eps, scale_shift=0.0
        )

        tp_rank = get_attention_tp_rank()

        def _shard_dim0(param, loaded_weight, _rank=tp_rank, _tp=tp_size):
            # Copy this rank's contiguous slice of dim 0 out of the full
            # checkpoint tensor; the slice length is the parameter's own dim 0.
            rows = param.shape[0]
            param.data.copy_(loaded_weight.narrow(0, _rank * rows, rows))

        # Both locally-sized parameters load through the slicing loader.
        for sharded_param in (self.depthwise_conv1d.weight, self.conv_norm.weight):
            set_weight_attrs(sharded_param, {"weight_loader": _shard_dim0})

        self.linear_end = ClippableRowParallelLinear(
            config.hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("linear_end", prefix),
        )

    def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
        """Apply the convolution sub-layer and add the residual.

        Args:
            audio_encodings: 3-D activations; the last axis is the channel
                axis and the middle axis is the one convolved over.

        Returns:
            Tensor of the same shape as the input.
        """
        residual = audio_encodings

        hidden = self.linear_start(self.pre_layer_norm(audio_encodings))

        # Conv1d wants channels in the middle: permute, pad the left edge of
        # the convolved axis, convolve, permute back.
        channels_first = F.pad(hidden.permute(0, 2, 1), (self.causal_padding, 0))
        hidden = self.depthwise_conv1d(channels_first).permute(0, 2, 1)

        hidden = torch.clamp(
            hidden, -self.gradient_clipping, self.gradient_clipping
        )
        hidden = F.silu(self.conv_norm(hidden))
        return self.linear_end(hidden) + residual
class Gemma4AudioEncoder(nn.Module):
    """SGLang-native TP-sharded Gemma 4 audio encoder (USM Conformer + SSCP).

    Runs the sub-sampling convolution projection, then every conformer block,
    then an optional output projection, and finally zeroes padded frames.
    """

    def __init__(
        self,
        config: Gemma4AudioConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        self.subsample_conv_projection = Gemma4AudioSubSampleConvProjection(
            config, quant_config, prefix=add_prefix("subsample_conv_projection", prefix)
        )
        self.conformer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Gemma4AudioConformerBlock(
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=add_prefix("conformer", prefix),
        )

        # The output projection exists only when the config asks for one.
        self.output_proj = (
            RowParallelLinear(
                config.hidden_size,
                config.output_proj_dims,
                bias=True,
                input_is_parallel=False,
                quant_config=quant_config,
                prefix=add_prefix("output_proj", prefix),
            )
            if config.output_proj_dims is not None
            else None
        )

        # Precompute causal_valid_mask — depends only on static config values.
        # Query row w of a chunk may see context column c exactly when
        # 0 <= c - w <= past_horizon + future_horizon.
        chunk_size = config.attention_chunk_size
        future_horizon = config.attention_context_right
        past_horizon = max(0, config.attention_context_left - 1)
        visible_span = past_horizon + future_horizon
        context_size = chunk_size + visible_span

        query_idx = torch.arange(chunk_size).unsqueeze(1)
        context_idx = torch.arange(context_size).unsqueeze(0)
        offset = context_idx - query_idx
        self.register_buffer(
            "causal_valid_mask",
            (offset >= 0) & (offset <= visible_span),
            persistent=False,
        )

    @property
    def device(self):
        """Device of the first parameter registered on this module."""
        return next(self.parameters()).device

    def forward(
        self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
    ) -> Tuple[torch.Tensor, torch.BoolTensor]:
        """Encode a batch of mel spectrograms.

        Args:
            audio_mel: [batch, num_frames, mel_bins]
            audio_mel_mask: [batch, num_frames], True = padding

        Returns:
            audio_encodings: [batch, reduced_frames, hidden_size/output_proj_dims]
            audio_mel_mask: [batch, reduced_frames], True = padding
        """
        encodings, frame_mask = self.subsample_conv_projection(
            audio_mel, audio_mel_mask
        )

        for layer in self.conformer:
            encodings = layer(encodings, frame_mask, self.causal_valid_mask)

        if self.output_proj is not None:
            encodings, _ = self.output_proj(encodings)

        # Keep the mask exactly as long as the encodings: extend it with
        # True (= padding) or truncate it.
        mask_len = frame_mask.shape[1]
        target_len = encodings.shape[1]
        if target_len > mask_len:
            frame_mask = F.pad(frame_mask, (0, target_len - mask_len), value=True)
        elif target_len < mask_len:
            frame_mask = frame_mask[:, :target_len]

        # Padded frames contribute zeros downstream.
        encodings = encodings.masked_fill(frame_mask.unsqueeze(-1), 0.0)
        return encodings, frame_mask
def get_attention_sliding_window_size(config):
    """Return the sliding-window size in the convention SGLang expects.

    The HF-aligned ``config.sliding_window`` counts the window inclusive of
    the last token, while SGLang assumes an exclusive window, hence the
    ``- 1`` adjustment.

    Args:
        config: Model config exposing a ``sliding_window`` attribute.

    Returns:
        ``config.sliding_window - 1``, or ``None`` when the config defines no
        sliding window (``sliding_window is None``) instead of failing with a
        ``TypeError`` on ``None - 1``.
    """
    sliding_window = config.sliding_window
    if sliding_window is None:
        return None
    return sliding_window - 1
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Returns raw router logits [T, E].

    The input is normalized with ``self.norm``, multiplied by the cached
    ``_fused_scale`` (built on first use via ``fuse_scale()``), and then
    projected with ``self.proj``.
    """
    normed = self.norm(x)

    # The fused scale is computed lazily: the first call after construction
    # populates the cache, later calls reuse it.
    if self._fused_scale is None:
        self.fuse_scale()

    scaled = normed * self._fused_scale.to(normed.dtype)

    # The projection returns a 2-tuple; only the logits are needed.
    router_logits, _ = self.proj(scaled)
    return router_logits
+ """ + + def __init__( + self, + hidden_size: int, + layer_id: int, + config: Gemma4TextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_id = layer_id + self.hidden_size = hidden_size + self.num_experts = config.num_experts + self.tp_size = get_tensor_model_parallel_world_size() + + # Per-expert output scale folded into routing weights so that + # MoE's fused kernel computes: Σ_e (expert_e * w_e * scale_e) + self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts)) + + # Capture param directly to avoid closing over self in the routing closure. + per_expert_scale = self.per_expert_scale + + def routing_function( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, # always True for Gemma4; softmax identity only holds when renormalizing + ) -> tuple[torch.Tensor, torch.Tensor]: + # softmax(all)[topk] / sum(softmax(all)[topk]) = softmax(topk_logits), + # so we softmax only the top-k logits (fewer kernel launches). 
def forward(
    self, hidden_states: torch.Tensor, router_logits: torch.Tensor
) -> torch.Tensor:
    """Dispatch tokens to the experts and return the combined output.

    Args:
        hidden_states: 2-D activations (unpacked as ``(num_tokens, hidden_dim)``).
        router_logits: Router output handed to ``self.topk`` together with
            ``hidden_states``.

    Returns:
        The experts' output viewed back as ``(num_tokens, hidden_dim)``.
    """
    # Remember the incoming 2-D shape so the result can be viewed back to it.
    token_count, width = hidden_states.shape

    selection = self.topk(hidden_states, router_logits)
    expert_out = self.experts(hidden_states, selection)
    return expert_out.view(token_count, width)
config.num_key_value_heads + ) + else: + self.total_num_kv_heads = config.num_key_value_heads + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + + hidden_size = config.hidden_size + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), + ) + + self.q_norm = Gemma4RMSNorm( + self.head_dim, + eps=config.rms_norm_eps, + ) + self.k_norm = Gemma4RMSNorm( + self.head_dim, + eps=config.rms_norm_eps, + ) + self.v_norm = Gemma4RMSNorm( + self.head_dim, eps=config.rms_norm_eps, scale_shift=0.0, with_scale=False + ) + + if layer_type in config.rope_parameters: + rope_parameters = dict(config.rope_parameters[layer_type]) + else: + rope_parameters = dict( + rope_type="default", + rope_theta=10000.0, + ) + + # KV sharing logic + num_kv_shared_layers = getattr(config, "num_kv_shared_layers", 0) + first_kv_shared_layer_idx = config.num_hidden_layers - num_kv_shared_layers + self.is_kv_shared_layer = ( + layer_id >= first_kv_shared_layer_idx and num_kv_shared_layers > 0 + ) + + self.kv_shared_layer_index = None + if num_kv_shared_layers > 0 and self.layer_id >= first_kv_shared_layer_idx: + prev_layers = config.layer_types[:first_kv_shared_layer_idx] + current_layer_type = config.layer_types[self.layer_id] + if current_layer_type not in prev_layers: + raise ValueError( + f"KV sharing layer {self.layer_id} has type '{current_layer_type}' " + 
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    forward_batch: ForwardBatch,
    **kwargs,
):
    """Run attention for one layer, optionally reusing another layer's KV.

    Q is always normalized per head and rotated. K and V are normalized (and
    K rotated) only when this layer owns its KV; on a KV-shared layer they
    are passed to the attention backend as ``None`` and nothing is written
    to the KV cache (``save_kv_cache=False``).

    Args:
        positions: Token positions handed to the rotary embedding.
        hidden_states: Input to the fused QKV projection.
        forward_batch: Batch metadata forwarded to the attention backend.
        **kwargs: Accepted for call-site compatibility; unused here.

    Returns:
        First element of the ``o_proj`` output.
    """
    qkv, _ = self.qkv_proj(hidden_states)
    query, key, value = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

    q_heads = (self.num_heads, self.head_dim)
    kv_heads = (self.num_kv_heads, self.head_dim)

    # Per-head normalization of Q, then back to the flat layout for rotary.
    query = self.q_norm(query.unflatten(-1, q_heads)).flatten(-2, -1)

    reuse_shared_kv = (
        self.is_kv_shared_layer and self.kv_shared_layer_index is not None
    )
    if reuse_shared_kv:
        # K/V come from the cache of the layer this one shares with, so the
        # local projections are discarded.
        key = None
        value = None
        # Rotary embedding requires a key input; use zeros since KV is shared
        # from another layer.
        placeholder_key = torch.zeros_like(query[:, : self.kv_size])
        query, _ = self.rotary_emb(positions, query, placeholder_key)
    else:
        key = self.k_norm(key.unflatten(-1, kv_heads))
        value = self.v_norm(value.unflatten(-1, kv_heads))
        query, key = self.rotary_emb(positions, query, key.flatten(-2, -1))
        key = key.unflatten(-1, kv_heads)

    query = query.unflatten(-1, q_heads)
    attn_output = self.attn(
        query,
        key,
        value,
        forward_batch=forward_batch,
        save_kv_cache=not self.is_kv_shared_layer,
    )
    # Merge the head axes again if the backend returned a per-head layout.
    if attn_output.dim() == 3:
        attn_output = attn_output.flatten(-2, -1)

    output, _ = self.o_proj(attn_output)
    return output
eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + # Per-Layer Embedding (PLE) components — present in each decoder layer + if self.hidden_size_per_layer_input > 0: + # Gate: projects hidden_states → per-layer dim for gating + self.per_layer_input_gate = ReplicatedLinear( + self.hidden_size, + self.hidden_size_per_layer_input, + bias=False, + quant_config=quant_config, + prefix=add_prefix("per_layer_input_gate", prefix), + ) + # Projection: projects gated per-layer input back → hidden size + self.per_layer_projection = ReplicatedLinear( + self.hidden_size_per_layer_input, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("per_layer_projection", prefix), + ) + self.post_per_layer_input_norm = Gemma4RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + else: + self.per_layer_input_gate = None + self.per_layer_projection = None + self.post_per_layer_input_norm = None + + # Parallel MoE + self.enable_moe_block = getattr(config, "enable_moe_block", False) + if self.enable_moe_block: + self.router = Gemma4Router( + config, + quant_config=quant_config, + prefix=add_prefix("router", prefix), + ) + self.moe = Gemma4MoE( + hidden_size=self.hidden_size, + layer_id=layer_id, + config=config, + quant_config=quant_config, + prefix=add_prefix("moe", prefix), + ) + + self.post_feedforward_layernorm_1 = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm_2 = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm_2 = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + else: + self.router = None + self.moe = None + self.post_feedforward_layernorm_1 = None + self.post_feedforward_layernorm_2 = None + 
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    per_layer_input: torch.Tensor,
    forward_batch: ForwardBatch,
    **kwargs,
) -> tuple[
    torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
]:
    """Run one decoder layer and return ``(hidden_states, None)``.

    Residual pattern:
      1. input_norm(x) -> attn -> post_attn_norm -> ADD residual
      2. pre_ff_norm -> feed-forward -> post_ff_norm -> ADD residual

    The "add residual, then pre_ff_norm" step of (1)/(2) is done by calling
    ``pre_feedforward_layernorm(h, residual)``, which returns both the
    normalized tensor and the updated residual. With the MoE block enabled,
    a dense MLP branch and a routed-expert branch run in parallel and are
    summed; the router and the expert branch both read the updated residual.

    ``layer_scalar`` is applied exactly once: inside the fused kernel on the
    fused path (per its call-site description), explicitly otherwise.

    Args:
        positions: Token positions for the attention module.
        hidden_states: Layer input; also the initial residual.
        per_layer_input: Per-layer embedding slice, or ``None``; only used
            when the layer has PLE components.
        forward_batch: Batch metadata forwarded to attention.
        **kwargs: Accepted for call-site compatibility; unused here.
    """
    attn_out = self.self_attn(
        positions=positions,
        hidden_states=self.input_layernorm(hidden_states),
        forward_batch=forward_batch,
    )
    attn_out = self.post_attention_layernorm(attn_out)

    # Fuse: attn_out + residual -> residual; pre_ff_norm(residual) -> ff_input
    ff_input, residual = self.pre_feedforward_layernorm(attn_out, hidden_states)

    if self.enable_moe_block:
        # Dense MLP branch
        dense_out = self.post_feedforward_layernorm_1(self.mlp(ff_input))

        # MoE branch: router and experts see the raw (un-normalized) residual,
        # i.e. post_attn_out + old_residual.
        router_logits = self.router(residual)
        expert_in = self.pre_feedforward_layernorm_2(residual)
        expert_out = self.moe(expert_in, router_logits)
        expert_out = self.post_feedforward_layernorm_2(expert_out)

        ff_out = dense_out + expert_out
    else:
        ff_out = self.mlp(ff_input)

    use_fused_epilogue = (
        not self.has_ple and ff_out.is_cuda and ff_out.dim() == 2
    )
    if use_fused_epilogue:
        # Fused: (post_ff_norm(h) + residual) * layer_scalar in one kernel
        post_norm = self.post_feedforward_layernorm
        fused = gemma_rmsnorm_residual_scalar(
            ff_out,
            post_norm.weight.data,
            residual,
            self.layer_scalar,
            post_norm.variance_epsilon,
        )
        return fused, None

    out = self.post_feedforward_layernorm(ff_out) + residual

    if self.has_ple and per_layer_input is not None:
        # Gate the per-layer input with a projection of the current state,
        # project it back to hidden size, normalize, and add it in.
        gate, _ = self.per_layer_input_gate(out)
        gate = torch.nn.functional.gelu(gate, approximate="tanh")
        contribution, _ = self.per_layer_projection(gate * per_layer_input)
        out = out + self.post_per_layer_input_norm(contribution)

    return out * self.layer_scalar, None
def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
    """Look up the packed per-layer embeddings for ``input_ids``.

    Returns ``None`` when the model has no per-layer embedding table.
    Otherwise the result has shape
    ``(*input_ids.shape, num_hidden_layers, hidden_size_per_layer_input)``.
    Any scaling of the embeddings is left to the embedding module itself;
    nothing is rescaled here.
    """
    embedding = self.embed_tokens_per_layer
    if embedding is None:
        return None

    # The per-layer vocabulary may be smaller than the main vocabulary:
    # ids outside [0, vocab_size_per_layer_input) are looked up as token 0.
    in_range = (input_ids >= 0) & (input_ids < self.vocab_size_per_layer_input)
    lookup_ids = torch.where(in_range, input_ids, torch.zeros_like(input_ids))

    # Packed lookup: the last axis holds all layers' slices back to back.
    packed = embedding(lookup_ids)

    # Split the packed axis into (layer, per-layer hidden) axes.
    return packed.reshape(
        *input_ids.shape,
        self.config.num_hidden_layers,
        self.hidden_size_per_layer_input,
    )
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    input_embeds: Optional[torch.Tensor] = None,
    per_layer_inputs: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Run the decoder stack and return the final-normed hidden states.

    Args:
        input_ids: Token ids, or ``None`` when ``input_embeds`` is given.
        positions: Token positions, forwarded to every layer.
        forward_batch: Batch metadata, forwarded to every layer.
        input_embeds: Precomputed embeddings, or ``None`` when ``input_ids``
            is given.
        per_layer_inputs: Optional per-layer inputs; recomputed from
            ``input_ids`` when ids are supplied. Indexed below as
            ``[:, layer_idx, :]``.
        **kwargs: Forwarded unchanged to every layer.

    Returns:
        Output of ``self.norm`` on the last layer's hidden states.

    Raises:
        ValueError: If both or neither of ``input_ids`` / ``input_embeds``
            are provided.
    """
    # Exactly one of the two inputs must be present.
    if (input_ids is None) == (input_embeds is None):
        raise ValueError(
            "You must specify exactly one of input_ids or input_embeds"
        )

    if input_ids is not None:
        input_embeds = self.embed_tokens(input_ids)
        per_layer_inputs = self.get_per_layer_inputs(input_ids)
    # Runs for both input kinds; yields None when the model has no PLE.
    per_layer_inputs = self.project_per_layer_inputs(input_embeds, per_layer_inputs)

    hidden_states = input_embeds
    # Initialized up front so an empty layer list still reaches the final
    # norm instead of failing on an unbound local.
    residual = None

    for layer_idx, layer in enumerate(self.layers):
        per_layer_input = (
            None if per_layer_inputs is None else per_layer_inputs[:, layer_idx, :]
        )
        layer_outputs = layer(
            positions=positions,
            hidden_states=hidden_states,
            per_layer_input=per_layer_input,
            forward_batch=forward_batch,
            **kwargs,
        )
        hidden_states = layer_outputs[0]
        residual = layer_outputs[1] if len(layer_outputs) > 1 else None

    if residual is None:
        return self.norm(hidden_states)

    # With a residual, the norm returns a pair; only the normed tensor is used.
    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states
@torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + per_layer_inputs: Optional[torch.Tensor] = None, + **kwargs, + ) -> LogitsProcessor: + hidden_states = self.model( + input_ids, + positions, + forward_batch, + input_embeds, + per_layer_inputs, + **kwargs, + ) + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + def _get_k_eq_v_layers(self) -> set: + """Return set of layer indices where attention_k_eq_v applies (full-attention layers).""" + if not getattr(self.config, "attention_k_eq_v", False): + return set() + return { + i for i, lt in enumerate(self.config.layer_types) if lt == "full_attention" + } + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = [ + # (param_name, ckpt_weight_name, shard_ids) + # gate_up_proj is fused [E, 2*I, H] — chunk into w1 (gate) + w3 (up) + ("experts.w13_weight", "experts.gate_up_proj", ("w1", "w3")), + ("experts.w2_weight", "experts.down_proj", ("w2",)), + ] + num_experts = self.config.num_experts + + k_eq_v_layers = self._get_k_eq_v_layers() + + params_dict = dict(self.named_parameters()) + params_dict.update(dict(self.named_buffers())) + non_persistent_buffers: Set[str] = set() + for mod_name, mod in self.named_modules(): + for buf_name in getattr(mod, "_non_persistent_buffers_set", set()): + full = f"{mod_name}.{buf_name}" if mod_name else buf_name + non_persistent_buffers.add(full) + + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + name = name.replace("model.language_model.", "model.") + + # HF has router.per_expert_scale and experts.* on the decoder layer; + # remap into 
our moe.* subtree since Gemma4MoE owns both. + name = name.replace(".router.per_expert_scale", ".moe.per_expert_scale") + if ".experts." in name and ".moe.experts." not in name: + name = name.replace(".experts.", ".moe.experts.") + + # attention_k_eq_v: full-attention layers have no v_proj in the + # checkpoint (K and V share weights). When we see a k_proj weight + # for one of these layers, load it into both the "k" and "v" shards + # of the fused QKV so the forward produces v_raw == k_raw. + should_dup_k_to_v = ( + ".k_proj." in name + and k_eq_v_layers + and (m := re.search(r"layers\.(\d+)\.", name)) is not None + and int(m.group(1)) in k_eq_v_layers + ) + + # MoE expert weights checked first (gate_up_proj contains "up_proj" + # which would false-match the stacked dense MLP mapping). + orig_name = name + for param_name, weight_name, shard_ids in expert_params_mapping: + name = orig_name + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + for i in range(num_experts): + chunks = loaded_weight[i].chunk(len(shard_ids), dim=0) + for chunk, sid in zip(chunks, shard_ids): + weight_loader(param, chunk, name, sid, i) + break + else: + for param_name, weight_name, shard_id in stacked_params_mapping: + name = orig_name + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + if should_dup_k_to_v: + weight_loader(param, loaded_weight, "v") + break + else: + name = orig_name + if name.endswith(".bias") and name not in params_dict: + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", 
default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + param_names = set(dict(self.named_parameters()).keys()) + buckets = { + logging.WARNING: ( + "Some weights are not initialized from checkpoints", + lambda p: p in param_names, + ), + logging.INFO: ( + "Persistent buffers not in checkpoint (using default init)", + lambda p: p not in param_names and p not in non_persistent_buffers, + ), + logging.DEBUG: ( + "Non-persistent buffers not in checkpoint (expected)", + lambda p: p in non_persistent_buffers, + ), + } + for level, (msg, pred) in buckets.items(): + names = sorted(p for p in unloaded_params if pred(p)) + if names: + logger.log(level, "%s: %s", msg, names) + return loaded_params + + +EntryClass = Gemma4ForCausalLM diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py new file mode 100644 index 000000000000..4618129fab7a --- /dev/null +++ b/python/sglang/srt/models/gemma4_mm.py @@ -0,0 +1,878 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class Gemma4MultimodalEmbedder(nn.Module):
    """Map soft tokens from a vision/audio tower into the LM embedding space.

    The input is first passed through a scale-free ``Gemma4RMSNorm`` and then
    through a bias-free ``ReplicatedLinear`` projecting to the text hidden
    size.
    """

    def __init__(
        self,
        multimodal_config: Union[Gemma4AudioConfig, Gemma4VisionConfig],
        text_config: Gemma4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.eps = multimodal_config.rms_norm_eps
        self.text_hidden_size = text_config.hidden_size

        # Prefer ``output_proj_dims`` when the tower config defines a truthy
        # value for it; otherwise the tower's ``hidden_size`` is the input
        # width of the projection.
        tower_dim = getattr(multimodal_config, "output_proj_dims", None)
        if not tower_dim:
            tower_dim = multimodal_config.hidden_size

        projection = ReplicatedLinear(
            tower_dim,
            self.text_hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("embedding_projection", prefix),
        )
        self.embedding_projection = projection

        pre_norm = Gemma4RMSNorm(tower_dim, eps=self.eps, with_scale=False)
        self.embedding_pre_projection_norm = pre_norm

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        """Normalize ``inputs_embeds`` and project them into LM space."""
        normalized = self.embedding_pre_projection_norm(inputs_embeds)
        # The linear layer returns a pair; only the first element is used.
        projected, _ = self.embedding_projection(normalized)
        return projected
def __init__(
    self,
    config: Gemma4Config,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """Assemble the vision tower, optional audio tower, text model and logits processor.

    All sub-module prefixes are nested under ``"model"``. The audio tower
    and its embedder are only built when ``config.audio_config`` is present
    and not None; otherwise both attributes are set to None.
    """
    super().__init__(config=config)
    self.config = config
    self.quant_config = quant_config

    root = add_prefix("model", prefix)
    text_cfg = config.text_config
    vision_cfg = config.vision_config

    # Vision components
    self.vision_tower = Gemma4VisionEncoder(
        config=vision_cfg,
        quant_config=quant_config,
        prefix=add_prefix("vision_tower", root),
    )
    self.embed_vision = Gemma4MultimodalEmbedder(
        vision_cfg,
        text_cfg,
        quant_config=quant_config,
        prefix=add_prefix("embed_vision", root),
    )

    # Audio components (optional)
    audio_cfg = getattr(config, "audio_config", None)
    if audio_cfg is None:
        self.audio_tower = None
        self.embed_audio = None
    else:
        self.audio_tower = Gemma4AudioEncoder(
            config=audio_cfg,
            quant_config=quant_config,
            prefix=add_prefix("audio_tower", root),
        )
        self.embed_audio = Gemma4MultimodalEmbedder(
            audio_cfg,
            text_cfg,
            quant_config=quant_config,
            prefix=add_prefix("embed_audio", root),
        )

    self.vocab_size = text_cfg.vocab_size
    # Falls back to the regular vocab size when the text config does not
    # define a separate per-layer-input vocabulary.
    self.vocab_size_per_layer_input = getattr(
        text_cfg, "vocab_size_per_layer_input", text_cfg.vocab_size
    )

    # Text model
    self.language_model = Gemma4TextModel(
        text_cfg,
        quant_config,
        prefix=add_prefix("language_model", root),
    )

    # Logits processor for the multimodal model
    self.logits_processor = LogitsProcessor(text_cfg)

    self.post_init()
def prepare_attn_masks(
    self,
    forward_batch: ForwardBatch,
    input_ids: torch.Tensor,
    mask_dtype: torch.dtype,
):
    """Prepare bidirectional attention masks for image tokens.

    Gemma 4 uses bidirectional attention for image soft tokens
    during prefill. Following the HF implementation, bidirectional attention
    is only enabled within each individual image group (same-item
    tokens), not across items.
    Currently only the TritonAttnBackend supports this.

    For every sequence a causal mask of shape
    ``(extend_len, prefix_len + extend_len)`` is built, image spans fully
    inside the extend window are opened up, and the flattened masks plus
    their offsets are stored on the backend's ``forward_metadata``.

    TODO(kpham-sgl): Guard appropriately for gemma3_mm.py:prepare_attn_masks()
    """
    backend = forward_batch.attn_backend
    if not isinstance(backend, TritonAttnBackend):
        logger.warning_once(
            "Bidirectional attention for image tokens requires TritonAttnBackend. "
            "Falling back to causal attention, which may degrade image quality."
        )
        return
    assert forward_batch.forward_mode == ForwardMode.EXTEND

    device = input_ids.device
    num_seqs = forward_batch.batch_size

    flat_masks = []
    mask_indptr = torch.zeros(num_seqs + 1, dtype=torch.int32, device=device)
    # (sequence index, begin, end) of image spans cut by a chunk boundary.
    straddling_spans = []

    for seq_idx in range(num_seqs):
        new_len = forward_batch.extend_seq_lens[seq_idx]
        cached_len = forward_batch.extend_prefix_lens[seq_idx]
        window_end = cached_len + new_len

        # Causal baseline: row r (absolute position cached_len + r) may see
        # columns 0..cached_len + r.
        seq_mask = torch.ones(
            new_len,
            window_end,
            dtype=mask_dtype,
            device=device,
        ).tril(diagonal=cached_len)

        # HF only enables bidirectional attention for image tokens,
        # not video or audio (see create_causal_mask_mapping).
        seq_mm = forward_batch.mm_inputs[seq_idx]
        if seq_mm is not None:
            for mm_item in seq_mm.mm_items:
                if not mm_item.is_image():
                    continue
                for im_begin, im_end in mm_item.offsets:
                    # Note(kpham-sgl): bidirectional attention is applied only
                    # when the image span lies entirely inside the extend
                    # window; otherwise the span keeps causal attention.
                    # FIXME(kpham-sgl): workaround for chunked prefill, where
                    # a span may be cut by the window; the mask construction
                    # should become chunk-aware instead.
                    if im_begin >= cached_len and im_end < window_end:
                        seq_mask[
                            im_begin - cached_len : im_end + 1 - cached_len,
                            im_begin : im_end + 1,
                        ] = 1
                    elif im_end >= cached_len and im_begin < window_end:
                        straddling_spans.append((seq_idx, im_begin, im_end))

        flat_masks.append(seq_mask.flatten())
        mask_indptr[seq_idx + 1] = mask_indptr[seq_idx] + seq_mask.nelement()

    if straddling_spans:
        num_split_images = len(straddling_spans)
        logger.warning_once(
            f"{num_split_images} images are split across chunk boundaries. "
            "Below are the first 5 images that are split across chunk boundaries: "
        )
        for i, im_begin, im_end in straddling_spans[:5]:
            logger.warning_once(
                f"Image {i}:{im_begin}-{im_end} is split across chunk boundaries.\n",
            )
        logger.warning_once(
            "Those images will receive causal attention. Disable chunked prefill (--chunked-prefill-size=-1) for full bidirectional attention.",
        )

    if flat_masks:
        combined = torch.cat(flat_masks, dim=0)
        forward_batch.attn_backend.forward_metadata.mask_indptr = mask_indptr
        forward_batch.attn_backend.forward_metadata.custom_mask = combined
def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
    """Encode video frames through the vision tower and project them to LM space.

    Each video is (num_frames, num_patches, patch_pixels) with matching
    position_ids (num_frames, num_patches, 2). Frames are flattened into
    the batch dimension so each frame is encoded independently; the
    tower's pooler mask then selects the tokens that are embedded.

    Tensors that are 2-D/3-D with a last dimension equal to the text hidden
    size are treated as already-computed embeddings and passed through.

    Raises:
        ValueError: if a pixel tensor has no matching ``video_position_ids``.
    """
    encoder = self.vision_tower

    embedded_chunks = []
    for item in items:
        pixel_tensors = flatten_nested_list([item.feature])
        position_tensors = flatten_nested_list(
            [getattr(item, "video_position_ids", None)]
        )

        for idx, pixels in enumerate(pixel_tensors):
            looks_precomputed = (
                pixels.dim() in (2, 3)
                and pixels.shape[-1] == self.config.text_config.hidden_size
            )
            if looks_precomputed:
                embedded_chunks.append(pixels.to(self.language_model.device))
                continue

            if idx >= len(position_tensors) or position_tensors[idx] is None:
                raise ValueError(
                    f"pixel_values_videos[{idx}] has no matching video_position_ids."
                )
            coords = position_tensors[idx]

            # HF processor returns 4-D tensors
            # (num_videos, num_frames, num_patches, ...) — collapse to
            # 3-D (num_frames, num_patches, ...) so each frame is a
            # batch element for the vision tower.
            if pixels.dim() == 4:
                pixels = pixels.reshape(-1, *pixels.shape[-2:])
            if coords.dim() == 4:
                coords = coords.reshape(-1, *coords.shape[-2:])

            pixels = pixels.to(
                device=encoder.device, dtype=self.language_model.dtype()
            )
            coords = coords.to(device=encoder.device)

            pooled, pooler_mask = encoder(pixels, coords)

            # Keep only the tokens selected by the pooler mask, frame by frame.
            embedded_chunks.extend(
                self.embed_vision(
                    inputs_embeds=frame_tokens[frame_mask].unsqueeze(0)
                ).squeeze(0)
                for frame_tokens, frame_mask in zip(pooled, pooler_mask)
            )

    if not embedded_chunks:
        # Nothing to embed: return an empty (0, hidden_size) tensor.
        return torch.empty(
            0,
            self.language_model.config.hidden_size,
            device=next(self.parameters()).device,
            dtype=self.language_model.dtype(),
        )
    return torch.cat(embedded_chunks, dim=0)
@torch.no_grad()
def forward(
    self,
    input_ids: torch.LongTensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    input_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> LogitsProcessor:
    """Forward pass for multimodal Gemma4.

    Steps:
      1. Validate that exactly one of ``input_ids`` / ``input_embeds`` is given.
      2. Shift ``positions`` by one, in place.
      3. Compute per-layer inputs from ``input_ids`` with image/video/audio
         placeholder tokens replaced by the text pad token.
      4. During EXTEND with image inputs, prepare bidirectional image masks.
      5. Embed multimodal data and run the language model via
         ``general_mm_embed_routine``, then apply the logits processor.

    Raises:
        ValueError: if both or neither of ``input_ids`` and ``input_embeds``
            are provided.
    """
    if (input_ids is None) ^ (input_embeds is not None):
        # The parameter is named ``input_embeds``; name it correctly so the
        # message points callers at an argument that actually exists.
        raise ValueError(
            "You must specify exactly one of input_ids or input_embeds"
        )

    # NOTE(review): in-place shift mutates the caller's tensor; presumably
    # Gemma positions are 1-indexed — confirm before changing.
    positions += 1
    per_layer_inputs = None
    if input_ids is not None:
        # Multimodal placeholder ids are swapped for the pad id before the
        # per-layer-input lookup; ``input_ids`` itself is left untouched.
        ple_ids = input_ids.clone()
        pad_id = self.config.text_config.pad_token_id
        ple_ids[input_ids == self.config.image_token_id] = pad_id
        ple_ids[input_ids == self.config.video_token_id] = pad_id
        ple_ids[input_ids == self.config.audio_token_id] = pad_id
        per_layer_inputs = self.get_per_layer_inputs(ple_ids)

    # Prepare bidirectional attention masks for image tokens during prefill.
    # Gemma 4 uses bidirectional attention for image soft tokens.
    # Only TritonAttnBackend supports this; incompatible with CUDA Graph and
    # chunked prefill.
    if (
        forward_batch.forward_mode == ForwardMode.EXTEND
        and forward_batch.contains_image_inputs()
    ):
        self.prepare_attn_masks(
            forward_batch,
            input_ids,
            mask_dtype=torch.bool,
        )

    # Use general_mm_embed_routine for handling multimodal data
    hidden_states = general_mm_embed_routine(
        input_ids=input_ids,
        forward_batch=forward_batch,
        language_model=self.language_model,
        data_embedding_funcs={
            Modality.IMAGE: self.get_image_feature,
            Modality.VIDEO: self.get_video_feature,
            Modality.AUDIO: self.get_audio_feature,
        },
        positions=positions,
        per_layer_inputs=per_layer_inputs,
        **kwargs,
    )

    # Process hidden states through logits processor
    return self.logits_processor(
        input_ids, hidden_states, self.language_model.embed_tokens, forward_batch
    )
@staticmethod
def _remap_tower_name(name: str, params_dict: dict) -> str:
    """Remap a vision/audio tower checkpoint name to our module tree.

    Rules are tried in order and the first one that produces a name wins:

    1. **Fused QKV** — ``{q,k,v}_proj.<attr>`` under ``self_attn``/``attn``:
       ``weight``/``bias`` (optionally behind ``linear.``) become
       ``qkv.{proj}.{weight|bias}``; ``output_*`` becomes
       ``qkv.{q,k,v}_output_*``; ``input_*`` becomes the shared
       ``qkv.input_*``.

    2. **Fused GateUp** — ``{gate,up}_proj.<attr>`` under ``mlp``: same
       scheme with ``gate_up`` and the ``gate``/``up`` short prefix.

    3. **Clippable wrapper** — a name ending in ``.weight``/``.bias`` is
       redirected to ``<base>.linear.<attr>`` when that name exists in
       ``params_dict``.

    If no rule applies the name is returned unchanged.
    """
    cls = Gemma4ForConditionalGeneration
    direct_attrs = ("weight", "bias", "linear.weight", "linear.bias")

    # (pattern, fused sub-module name, projection -> short prefix)
    fused_rules = (
        (cls._RE_TOWER_QKV, "qkv", lambda proj: proj[0]),
        (cls._RE_TOWER_GATE_UP, "gate_up", lambda proj: proj.split("_")[0]),
    )
    for pattern, fused, shorten in fused_rules:
        hit = pattern.match(name)
        if not hit:
            continue
        parent, proj, tail = hit.groups()
        if tail in direct_attrs:
            leaf = tail.rsplit(".", 1)[-1]
            return f"{parent}.{fused}.{proj}.{leaf}"
        if tail.startswith("output_"):
            return f"{parent}.{fused}.{shorten(proj)}_{tail}"
        if tail.startswith("input_"):
            return f"{parent}.{fused}.{tail}"
        # Matched the pattern but not a known attribute: try later rules.

    # Clippable wrapper (.weight → .linear.weight)
    if name.endswith((".weight", ".bias")):
        stem, leaf = name.rsplit(".", 1)
        wrapped = f"{stem}.linear.{leaf}"
        if wrapped in params_dict:
            return wrapped

    return name
expert_params_mapping = [ + # (param_name, ckpt_weight_name, shard_ids) + # gate_up_proj is fused [E, 2*I, H] — chunk into w1 (gate) + w3 (up) + ("experts.w13_weight", "experts.gate_up_proj", ("w1", "w3")), + ("experts.w2_weight", "experts.down_proj", ("w2",)), + ] + + params_dict = dict(self.named_parameters()) + params_dict.update(dict(self.named_buffers())) + non_persistent_buffers: Set[str] = set() + for mod_name, mod in self.named_modules(): + for buf_name in getattr(mod, "_non_persistent_buffers_set", set()): + full = f"{mod_name}.{buf_name}" if mod_name else buf_name + non_persistent_buffers.add(full) + + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + if "embed_vision.embedding." in name or "embed_audio.embedding." in name: + continue + if self.audio_tower is None and ( + "audio_tower." in name or "embed_audio." in name + ): + continue + + name = re.sub(r"^model\.", "", name) + + # HF has router.per_expert_scale and experts.* on the decoder layer; + # remap into our moe.* subtree since Gemma4MoE owns both. + name = name.replace(".router.per_expert_scale", ".moe.per_expert_scale") + if ".experts." in name and ".moe.experts." not in name: + name = name.replace(".experts.", ".moe.experts.") + + # Remap audio tower checkpoint names to our module tree + if "audio_tower." in name: + name = self._remap_audio_tower_name(name) + + # Remap vision / audio tower names (fused QKV/GateUp, clippable wrappers) + if "vision_tower." in name or "audio_tower." in name: + name = self._remap_tower_name(name, params_dict) + + # attention_k_eq_v: full-attention layers have no v_proj in the + # checkpoint (K and V share weights). When we see a k_proj weight + # for one of these layers, load it into both the "k" and "v" shards + # of the fused QKV so the forward produces v_raw == k_raw. + should_dup_k_to_v = ( + ".k_proj." in name + and k_eq_v_layers + and "language_model." 
in name + and (m := re.search(r"layers\.(\d+)\.", name)) is not None + and int(m.group(1)) in k_eq_v_layers + ) + + # MoE expert weights checked first (gate_up_proj contains "up_proj" + # which would false-match the stacked dense MLP mapping). + orig_name = name + for param_name, weight_name, shard_ids in expert_params_mapping: + name = orig_name + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + for i in range(num_experts): + chunks = loaded_weight[i].chunk(len(shard_ids), dim=0) + for chunk, sid in zip(chunks, shard_ids): + weight_loader(param, chunk, name, sid, i) + break + else: + for param_name, weight_name, shard_id in self.stacked_params_mapping: + name = orig_name + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + if should_dup_k_to_v: + weight_loader(param, loaded_weight, "v") + break + else: + name = orig_name + if name.endswith(".bias") and name not in params_dict: + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + param_names = set(dict(self.named_parameters()).keys()) + buckets = { + logging.WARNING: ( + "Some weights are not initialized from checkpoints", + lambda p: p in param_names, + ), + logging.INFO: ( + "Persistent buffers not in checkpoint (using default init)", + lambda p: p not in param_names and p not in non_persistent_buffers, + ), + logging.DEBUG: ( + "Non-persistent buffers 
def get_hidden_dim(self, module_name, layer_idx):
    """Return ``(input_dim, output_dim)`` for a LoRA-targeted projection.

    Args:
        module_name: One of ``"qkv_proj"``, ``"o_proj"``, ``"gate_up_proj"``
            or ``"down_proj"``.
        layer_idx: Accepted for interface compatibility; not used here.

    Returns:
        A 2-tuple ``(input_dim, output_dim)`` derived from ``self.config``.

    Raises:
        NotImplementedError: If ``module_name`` is not one of the four
            supported projections.
        AssertionError: For the MLP projections, if
            ``config.intermediate_size`` holds more than one distinct value.
    """
    if module_name not in ("qkv_proj", "o_proj", "gate_up_proj", "down_proj"):
        raise NotImplementedError()

    cfg = self.config

    if module_name == "qkv_proj":
        # Fused output: all query heads plus one K and one V copy per KV head.
        fused_heads = cfg.num_attention_heads + cfg.num_key_value_heads * 2
        return cfg.hidden_size, cfg.head_dim * fused_heads

    if module_name == "o_proj":
        return cfg.head_dim * cfg.num_attention_heads, cfg.hidden_size

    # Remaining cases are the two MLP projections; both index
    # ``intermediate_size`` as a per-layer sequence and need it uniform.
    layer_sizes = cfg.intermediate_size
    assert len(set(layer_sizes)) == 1, (
        "Currently SGLang requires uniform intermediate size for all layers. "
        "Please file an issue if you need support for non-uniform intermediate sizes."
    )
    mlp_dim = layer_sizes[0]

    if module_name == "gate_up_proj":
        # gate and up are stacked, hence twice the intermediate size.
        return cfg.hidden_size, mlp_dim * 2
    return mlp_dim, cfg.hidden_size
class Gemma4VisionRotaryEmbedding(nn.Module):
    """Compute 2-D multidimensional RoPE cos/sin for patch positions.

    ``head_dim`` is divided evenly between the coordinate axes found in the
    last dimension of ``patch_positions``. Each axis gets a standard rotary
    table built from its own coordinate, and the per-axis tables are
    concatenated along the last dimension.
    """

    def __init__(self, config: Gemma4VisionConfig):
        super().__init__()
        self.head_dim = config.head_dim
        self.rope_theta: float = config.rope_parameters["rope_theta"]

    @torch.no_grad()
    def forward(
        self, x: torch.Tensor, patch_positions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the rotary tables for a batch of patch coordinates.

        Args:
            x: [batch, seq, hidden] – only used for device/dtype.
            patch_positions: [batch, num_patches, 2] – (x, y) coordinates.

        Returns:
            (cos, sin) each of shape [batch, num_patches, head_dim], cast to
            ``x.dtype``.
        """
        ndim = patch_positions.shape[-1]  # 2
        head_dim_per_dim = self.head_dim // ndim

        # The inverse frequencies depend only on head_dim_per_dim and
        # rope_theta, never on the axis, so build (and expand) them once
        # instead of once per axis.
        inv_freq = 1.0 / (
            self.rope_theta
            ** (
                torch.arange(
                    0, head_dim_per_dim, 2, device=x.device, dtype=torch.float
                )
                / head_dim_per_dim
            )
        )
        inv_freq_expanded = inv_freq[None, :, None].expand(
            patch_positions.shape[0], -1, 1
        )

        all_embs = []
        for d in range(ndim):
            dim_positions = patch_positions[:, :, d].float()
            # [batch, freqs, 1] @ [batch, 1, patches] -> [batch, patches, freqs]
            dim_freqs = (inv_freq_expanded @ dim_positions[:, None, :]).transpose(
                1, 2
            )
            # Duplicate so the table lines up with the rotate-half layout.
            all_embs.append(torch.cat((dim_freqs, dim_freqs), dim=-1))

        emb = torch.cat(all_embs, dim=-1)
        cos = emb.cos().to(dtype=x.dtype)
        sin = emb.sin().to(dtype=x.dtype)
        return cos, sin
class Gemma4VisionAttention(nn.Module):
    """Multi-head attention for the Gemma 4 vision encoder.

    QKV uses a fused ``ClippableQKVParallelLinear`` for efficient matmul with
    per-projection clip bounds. Output projection uses
    ``ClippableRowParallelLinear``. Q, K and V are each RMS-normalized per
    head, 2-D RoPE is applied to Q and K, and the attention kernel is chosen
    once at construction time by ``_select_backend``.
    """

    def __init__(
        self,
        config: Gemma4VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.head_dim = config.head_dim

        # Heads are split across the attention tensor-parallel group using
        # integer division.
        # NOTE(review): nothing here checks that the head counts divide evenly
        # by tp_size — confirm the linear layers enforce that.
        tp_size = get_attention_tp_size()
        self.num_heads_per_partition = config.num_attention_heads // tp_size
        self.num_kv_heads_per_partition = config.num_key_value_heads // tp_size

        self.qkv = ClippableQKVParallelLinear(
            hidden_size=config.hidden_size,
            head_size=config.head_dim,
            total_num_heads=config.num_attention_heads,
            total_num_kv_heads=config.num_key_value_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=prefix,
        )
        self.o_proj = ClippableRowParallelLinear(
            input_size=config.num_attention_heads * config.head_dim,
            output_size=config.hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )

        self.q_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        # The value norm is built with with_scale=False and scale_shift=0.0
        # (presumably a weight-free normalization — confirm in Gemma4RMSNorm).
        self.v_norm = Gemma4RMSNorm(
            self.head_dim, eps=config.rms_norm_eps, scale_shift=0.0, with_scale=False
        )

        backend = self._select_backend()
        # NOTE(review): softmax_scale is pinned to 1.0 both here and in
        # forward(); presumably the usual 1/sqrt(d) scaling is unnecessary
        # because q/k are normalized — confirm against the reference model.
        self.qkv_backend = QKV_BACKEND_IMPL[backend](
            head_dim=config.head_dim,
            num_heads=self.num_heads_per_partition,
            num_kv_heads=self.num_kv_heads_per_partition,
            dropout=0.0,
            flatten_batch=True,
            softmax_in_single_precision=False,
            softmax_scale=1.0,
        )

    @staticmethod
    def _select_backend() -> str:
        """Mirror VisionAttention._determine_attention_backend for consistency.

        Returns the ``mm_attention_backend`` server argument when it is set;
        otherwise picks a key of ``QKV_BACKEND_IMPL`` from the platform.
        """
        # Imported lazily inside the function rather than at module import.
        from sglang.srt.server_args import get_global_server_args

        override = get_global_server_args().mm_attention_backend
        if override is not None:
            return override
        if is_cuda():
            major, _ = get_device_capability()
            if major == 9:
                from sglang.srt.utils import is_blackwell_supported

                # NOTE(review): the Blackwell check only runs when the compute
                # capability major is exactly 9 — confirm this nesting matches
                # the method being mirrored.
                if is_blackwell_supported():
                    return "triton_attn"
                return "fa3"
            return "triton_attn"
        if is_hip():
            # ROCm: use triton_attn to avoid SDPA flatten_batch issues
            # with multi-image/video inputs
            return "triton_attn"
        return "sdpa"

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention over a padded batch of patch embeddings.

        Args:
            hidden_states: 3-D tensor unpacked as (bsz, seq_len, hidden).
            cos: Rotary cosine table, reshaped to (bsz * seq_len, 1, head_dim).
            sin: Rotary sine table, reshaped the same way as ``cos``.
            attention_mask: Optional per-token mask; assumed to be
                (bsz, seq_len) — TODO confirm against the caller.

        Returns:
            The output projection of the attention result.
        """
        bsz, seq_len, _ = hidden_states.shape

        q, k, v = self.qkv(hidden_states)

        # Flatten batch and sequence: (bsz * seq_len, local heads, head_dim).
        q = q.reshape(bsz * seq_len, self.num_heads_per_partition, self.head_dim)
        k = k.reshape(bsz * seq_len, self.num_kv_heads_per_partition, self.head_dim)
        v = v.reshape(bsz * seq_len, self.num_kv_heads_per_partition, self.head_dim)

        # Normalize each head vector independently (rows of length head_dim).
        q = self.q_norm(q.reshape(-1, self.head_dim)).reshape(q.shape)
        k = self.k_norm(k.reshape(-1, self.head_dim)).reshape(k.shape)
        v = self.v_norm(v.reshape(-1, self.head_dim)).reshape(v.shape)

        # The singleton middle axis lets cos/sin broadcast across heads.
        cos_flat = cos.reshape(bsz * seq_len, 1, self.head_dim)
        sin_flat = sin.reshape(bsz * seq_len, 1, self.head_dim)
        q = _apply_multidimensional_rope(q, cos_flat, sin_flat)
        k = _apply_multidimensional_rope(k, cos_flat, sin_flat)

        if attention_mask is not None:
            # Outer product of the token mask with itself, then a head axis:
            # for a (bsz, seq) mask this yields (bsz, 1, seq, seq).
            attn_mask_4d = (
                attention_mask.unsqueeze(-1) * attention_mask.unsqueeze(1)
            ).unsqueeze(1)
        else:
            attn_mask_4d = None

        output = self.qkv_backend.forward(
            q=q,
            k=k,
            v=v,
            cu_seqlens=None,
            bsz=bsz,
            seq_len=seq_len,
            attention_mask=attn_mask_4d,
            softmax_scale=1.0,
        )

        # Undo the batch flattening and merge heads before the output projection.
        output = rearrange(output, "(b s) h d -> b s (h d)", b=bsz)
        output = self.o_proj(output)
        return output
class Gemma4VisionTransformer(nn.Module):
    """Stack of ``Gemma4VisionEncoderLayer`` modules sharing one RoPE table."""

    def __init__(
        self,
        config: Gemma4VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.rotary_emb = Gemma4VisionRotaryEmbedding(config)

        # One encoder layer per hidden layer, each with its own weight prefix
        # of the form ``layers.<index>`` under ``prefix``.
        self.layers = nn.ModuleList()
        for layer_idx in range(config.num_hidden_layers):
            self.layers.append(
                Gemma4VisionEncoderLayer(
                    config,
                    layer_idx=layer_idx,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{layer_idx}", prefix),
                )
            )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor,
        patch_positions: torch.Tensor,
    ) -> torch.Tensor:
        """Run every encoder layer in order over the patch embeddings.

        Args:
            inputs_embeds: [batch, seq, hidden_size]
            attention_mask: [batch, seq] — True = valid token
            patch_positions: [batch, seq, 2]

        Returns:
            last_hidden_state: [batch, seq, hidden_size]
        """
        # The rotary tables are computed once and reused by every layer.
        rope_cos, rope_sin = self.rotary_emb(inputs_embeds, patch_positions)

        out = inputs_embeds
        for encoder_layer in self.layers:
            out = encoder_layer(out, rope_cos, rope_sin, attention_mask)
        return out
class Gemma4VisionPooler(nn.Module):
    """Average-pool patch embeddings down to a fixed number of soft tokens.

    Patches are grouped into ``k x k`` spatial kernels by their (x, y)
    positions, averaged within each kernel, and the result is scaled by
    ``sqrt(hidden_size)``.
    """

    def __init__(self, config: Gemma4VisionConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Scale factor applied to the pooled tokens in ``forward``.
        self.root_hidden_size = self.hidden_size**0.5

    def _avg_pool_by_positions(
        self, x: torch.Tensor, patch_positions: torch.Tensor, length: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Average patches that fall into the same ``k x k`` kernel.

        Args:
            x: Patch embeddings, indexed as [batch, num_patches, hidden].
            patch_positions: Per-patch (x, y) coordinates; negative values
                are clamped to 0 before use.
            length: Number of output tokens per batch element.

        Returns:
            ``(output, mask)``: ``output`` is [batch, length, hidden] and
            ``mask`` is [batch, length], True where at least one patch was
            assigned to that output token.

        Raises:
            ValueError: If ``num_patches`` is not ``k**2 * length`` for an
                integer ``k``.
        """
        input_seq_len = x.shape[1]
        k = int((input_seq_len // length) ** 0.5)
        k_squared = k**2
        if k_squared * length != input_seq_len:
            raise ValueError(
                f"Cannot pool {x.shape} to {length}: {k=}^2 times {length=} must be {input_seq_len}."
            )
        clamped_positions = patch_positions.clamp(min=0)
        # Width of the patch grid per batch element (largest x index + 1).
        max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1
        kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor")
        # Row-major flattening of the (kernel_x, kernel_y) pair.
        kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1]

        # Each patch contributes 1 / k^2 to exactly one output token.
        weights = F.one_hot(kernel_idxs.long(), length).float() / k_squared
        output = weights.transpose(1, 2).to(x.dtype) @ x
        mask = torch.logical_not((weights == 0).all(dim=1))
        return output, mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        patch_positions: torch.Tensor,
        padding_positions: torch.Tensor,
        output_length: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pool ``hidden_states`` to ``output_length`` tokens and rescale.

        Args:
            hidden_states: Patch embeddings, [batch, num_patches, hidden].
            patch_positions: Per-patch (x, y) coordinates.
            padding_positions: [batch, num_patches], True for padding patches.
            output_length: Required number of output tokens. A list or tuple
                is also accepted, in which case its first element is used.

        Returns:
            (pooled_hidden_states, mask) where mask is True for valid tokens.

        Raises:
            ValueError: If ``output_length`` is missing or exceeds the number
                of patches.
        """
        if output_length is None:
            raise ValueError("output_length is required for Gemma4VisionPooler")

        # Unwrap list/tuple inputs *before* the size check; comparing a list
        # with an int raises TypeError, which previously made this unwrapping
        # unreachable.
        length = output_length
        if isinstance(length, (list, tuple)):
            length = length[0]

        if length > hidden_states.shape[1]:
            raise ValueError(
                f"Cannot output more soft tokens (requested {length}) than there are patches"
                f" ({hidden_states.shape[1]}). Change the value of `num_soft_tokens` when processing."
            )

        if hidden_states.shape[1] == length:
            # No pooling needed. ``padding_positions`` is True for padding, so
            # invert it to honour the documented "True = valid" contract that
            # the pooled branch already follows.
            mask = torch.logical_not(padding_positions)
        else:
            hidden_states, mask = self._avg_pool_by_positions(
                hidden_states, patch_positions, length
            )
        hidden_states = hidden_states * self.root_hidden_size
        return hidden_states, mask
+ + Returns: + (hidden_states, pooler_mask) — hidden_states [batch, output_len, hidden], + pooler_mask [batch, output_len] True = valid. + """ + k2 = self.pooling_kernel_size * self.pooling_kernel_size + output_length = pixel_values.shape[-2] // k2 + + padding_positions = (pixel_position_ids == -1).all(dim=-1) + + inputs_embeds = self.patch_embedder( + pixel_values, pixel_position_ids, padding_positions + ) + + last_hidden = self.encoder( + inputs_embeds=inputs_embeds, + attention_mask=~padding_positions, + patch_positions=pixel_position_ids, + ) + + pooled, pooler_mask = self.pooler( + last_hidden, + pixel_position_ids, + padding_positions, + output_length=output_length, + ) + + if self.standardize: + pooled = (pooled - self.std_bias) * self.std_scale + + return pooled, pooler_mask diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 839d5b74e079..254e019ffcce 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -386,6 +386,7 @@ def process_mm_data( if audios: if self._processor.__class__.__name__ in { "Gemma3nProcessor", + "Gemma4Processor", "GlmAsrProcessor", "Qwen2AudioProcessor", "Qwen3OmniMoeProcessor", diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py new file mode 100644 index 000000000000..80bb37061358 --- /dev/null +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -0,0 +1,145 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class Gemma4SGLangProcessor(SGLangBaseProcessor):
    """SGLang multimodal processor for Gemma4 (image, video and audio inputs)."""

    models = [Gemma4ForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
        super().__init__(hf_config, server_args, _processor, *args, **kwargs)

        # Begin/end marker token ids are read straight from the HF config.
        self.IM_START_TOKEN_ID = hf_config.boi_token_id
        self.IM_END_TOKEN_ID = hf_config.eoi_token_id
        self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id
        self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id

        special_tokens = MultimodalSpecialTokens(
            image_token_id=hf_config.image_token_id,
            video_token_id=hf_config.video_token_id,
            audio_token_id=hf_config.audio_token_id,
        )
        self.mm_tokens = special_tokens.build(_processor)

        # Register the position-id outputs of the image and video processors
        # so that collect_mm_items_from_processor_output stores them on the
        # MultimodalDataItem.
        for attr_name, modality in (
            ("image_position_ids", Modality.IMAGE),
            ("video_position_ids", Modality.VIDEO),
        ):
            self.ATTR_NAME_TO_MODALITY[attr_name] = modality

    def _get_audio_pad_multiple(self) -> int:
        """Return the sample count that audio waveforms are padded up to.

        The HF processor estimates the audio token count with
        ceil(duration_ms / audio_ms_per_token), which can exceed what the SSCP
        convolutions actually emit by one token. Padding each waveform to a
        multiple of (hop_length * first_conv_stride) makes the two agree.
        See: gemma-4-eap-extras/examples/gemma-4-audio-examples.ipynb

        ``hop_length`` falls back to 160 when the processor exposes no
        feature extractor or the extractor has no such attribute.
        """
        feature_extractor = getattr(self._processor, "feature_extractor", None)
        hop_length = getattr(feature_extractor, "hop_length", 160)
        return hop_length * _SSCP_CONV_STRIDE_SIZES[0][0]

    def _video_decoder_to_tensor(self, vdw: VideoDecoderWrapper) -> torch.Tensor:
        """Sample frames from ``vdw`` into a (sampled_frames, C, H, W) tensor.

        SGLang's load_video yields a VideoDecoderWrapper, which the HF
        Gemma4VideoProcessor does not accept (it wants a torch.Tensor or
        np.ndarray). HF's uniform frame sampling is reproduced here so the
        whole video never has to be decoded into memory; resizing,
        patchifying and position ids are left to the HF video processor.

        The frame budget is the video processor's ``num_frames`` attribute,
        defaulting to 32 when unavailable.
        """
        frame_count = len(vdw)
        video_processor = getattr(self._processor, "video_processor", None)
        num_frames = getattr(video_processor, "num_frames", 32)

        if frame_count > num_frames:
            # Evenly spaced indices with a fractional step, truncated to int.
            step = frame_count / num_frames
            indices = torch.arange(0, frame_count, step).int().tolist()
        else:
            # Short clip: keep every frame.
            indices = list(range(frame_count))

        frames_np = vdw.get_frames_at(indices)  # (N, H, W, C)
        channels_first = torch.from_numpy(frames_np).permute(0, 3, 1, 2)
        return channels_first.contiguous()

    def process_mm_data(
        self, input_text, images=None, videos=None, audios=None, **kwargs
    ):
        """Normalize audio and video inputs, then defer to the base processor.

        Audio waveforms are zero-padded at the end to a multiple of
        ``_get_audio_pad_multiple()``. ``VideoDecoderWrapper`` videos are
        converted to frame tensors, and ``do_sample_frames`` defaults to False
        whenever videos are present.
        """
        if audios:
            pad_multiple = self._get_audio_pad_multiple()
            aligned = []
            for waveform in audios:
                waveform = np.asarray(waveform)
                # Samples missing to reach the next multiple (0 if aligned).
                shortfall = -len(waveform) % pad_multiple
                if shortfall:
                    waveform = np.pad(waveform, (0, shortfall), mode="constant")
                aligned.append(waveform)
            audios = aligned

        if videos:
            converted = []
            for video in videos:
                if isinstance(video, VideoDecoderWrapper):
                    video = self._video_decoder_to_tensor(video)
                converted.append(video)
            videos = converted
            kwargs.setdefault("do_sample_frames", False)

        return super().process_mm_data(
            input_text, images=images, videos=videos, audios=audios, **kwargs
        )

    async def process_mm_data_async(
        self,
        image_data: Optional[List[Union[str, bytes, Dict]]] = None,
        audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
        input_text: str = "",
        request_obj=None,
        *args,
        **kwargs,
    ):
        """Process multimodal data including images, video, and audio.

        Video inputs are taken from ``request_obj.video_data`` when a request
        object is supplied; otherwise no video is loaded.
        """
        video_data = request_obj.video_data if request_obj else None
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            video_data=video_data,
            audio_data=audio_data,
            multimodal_tokens=self.mm_tokens,
        )

        mm_items, input_ids, _ = self.process_and_combine_mm_data(
            base_output, self.mm_tokens
        )

        tokens = self.mm_tokens
        return MultimodalProcessorOutput(
            input_ids=input_ids.tolist(),
            mm_items=mm_items,
            im_token_id=tokens.image_token_id,
            video_token_id=tokens.video_token_id,
            audio_token_id=tokens.audio_token_id,
        )
class Gemma4Detector(BaseReasoningFormatDetector):
    """Gemma4 reasoning detector.

    The reasoning block opens with the ``<|channel>`` token immediately
    followed by the literal label ``thought\\n``. The label is stored in
    ``think_start_self_label`` so the base parser strips the start token and
    the label together as one unit.

    Args:
        stream_reasoning: Forwarded unchanged to the base detector.
        force_reasoning: Forwarded unchanged to the base detector.
        continue_final_message: Forwarded unchanged to the base detector.
        previous_content: Forwarded unchanged to the base detector.
    """

    def __init__(
        self,
        stream_reasoning: bool = True,
        force_reasoning: bool = False,
        continue_final_message: bool = False,
        previous_content: str = "",
    ):
        # The two positional arguments are the think-start and think-end tokens.
        # NOTE(review): the end token below is an empty string. An empty
        # string is contained in every text, so end-of-reasoning detection
        # cannot work with it; presumably the closing delimiter was lost when
        # this text was extracted — confirm against the upstream source.
        super().__init__(
            "<|channel>",
            "",
            force_reasoning=force_reasoning,
            stream_reasoning=stream_reasoning,
            continue_final_message=continue_final_message,
            previous_content=previous_content,
        )
        # Set after the base constructor, which initializes this label to "".
        self.think_start_self_label = "thought\n"
+ # Remap here so the rest of the stack sees a uniform convention. + text_config = config.text_config + global_head_dim = getattr(text_config, "global_head_dim", None) + global_kv_heads = getattr(text_config, "num_global_key_value_heads", None) + + swa_head_dim = text_config.head_dim + swa_kv_heads = text_config.num_key_value_heads + + text_config.swa_head_dim = swa_head_dim + text_config.swa_v_head_dim = swa_head_dim + text_config.swa_num_key_value_heads = swa_kv_heads + + if global_head_dim is not None: + text_config.head_dim = global_head_dim + if global_kv_heads is not None: + text_config.num_key_value_heads = global_kv_heads + + if not hasattr(text_config, "v_head_dim"): + text_config.v_head_dim = text_config.head_dim + if not hasattr(text_config, "swa_v_head_dim"): + text_config.swa_v_head_dim = text_config.swa_head_dim + if config.model_type == "longcat_flash": config.update({"architectures": ["LongcatFlashForCausalLM"]}) diff --git a/test/registered/unit/function_call/test_function_call_parser.py b/test/registered/unit/function_call/test_function_call_parser.py index c418b0866d0e..01aa99904072 100644 --- a/test/registered/unit/function_call/test_function_call_parser.py +++ b/test/registered/unit/function_call/test_function_call_parser.py @@ -6,6 +6,12 @@ from sglang.srt.function_call.core_types import StreamingParseResult from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector +from sglang.srt.function_call.gemma4_detector import ( + Gemma4Detector, + _parse_gemma4_args, + _parse_gemma4_array, + _parse_gemma4_value, +) from sglang.srt.function_call.gigachat3_detector import GigaChat3Detector from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector @@ -4008,5 +4014,340 @@ def test_streaming_multiple_tool_calls_char_by_char_separator(self): self.assertEqual(cities, ["NYC", 
"LA"]) +class TestGemma4Detector(unittest.TestCase): + def setUp(self): + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + ), + ) + ] + self.detector = Gemma4Detector() + + def test_detect_and_parse(self): + text = 'Some text before <|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.normal_text, "Some text before ") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "Tokyo") + + def test_parse_streaming_increment(self): + chunks = [ + "Some text ", + "before <|tool", + "_call>call:get_we", + "ather{location:<|", # codespell:ignore + '"|>Tokyo<|"|>} after", + ] + + all_results = [] + for chunk in chunks: + res = self.detector.parse_streaming_increment(chunk, self.tools) + all_results.append(res) + + combined_normal_text = "".join(r.normal_text for r in all_results) + self.assertEqual(combined_normal_text, "Some text before after") + + found_name = False + found_params = False + for res in all_results: + for call in res.calls: + if call.name == "get_weather": + found_name = True + if call.parameters: + params = json.loads(call.parameters) + if params == {"location": "Tokyo"}: + found_params = True + + self.assertTrue(found_name) + self.assertTrue(found_params) + + def test_nested_array_streaming(self): + # Additional coverage for complex structure + chunks = [ + '<|tool_call>call:get_weather{location:<|"', + '|>New York<|"|>,nested:[1, 2, {inner:<|"|>', + 'val<|"|>}]}', + ] + + all_results = [] + for chunk in chunks: + res = 
self.detector.parse_streaming_increment(chunk, self.tools) + all_results.append(res) + + found_params = False + for res in all_results: + for call in res.calls: + if call.parameters: + params = json.loads(call.parameters) + if "location" in params and params["location"] == "New York": + if "nested" in params and params["nested"] == [ + 1, + 2, + {"inner": "val"}, + ]: + found_params = True + + self.assertTrue(found_params) + + def test_has_tool_call(self): + self.assertTrue( + self.detector.has_tool_call( + '<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' + ) + ) + self.assertFalse(self.detector.has_tool_call("no tool call here")) + + def test_detect_and_parse_no_tool_call(self): + text = "This is plain text without any tool calls." + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(result.normal_text, text) + self.assertEqual(len(result.calls), 0) + + def test_detect_and_parse_tool_index(self): + text = '<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].tool_index, 0) + self.assertEqual(result.calls[0].name, "get_weather") + + def test_detect_and_parse_unknown_tool_index(self): + text = '<|tool_call>call:unknown_func{arg:<|"|>val<|"|>}' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].tool_index, -1) + + def test_detect_and_parse_nested_object(self): + text = '<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,details:{temp:25,unit:<|"|>celsius<|"|>}}' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "Tokyo") + self.assertIsInstance(params["details"], dict) + self.assertEqual(params["details"]["temp"], 25) + self.assertEqual(params["details"]["unit"], "celsius") 
+ + def test_detect_and_parse_multiple_calls(self): + extra_tools = self.tools + [ + Tool( + type="function", + function=Function( + name="get_time", + description="Get current time", + parameters={ + "type": "object", + "properties": {"timezone": {"type": "string"}}, + }, + ), + ) + ] + text = ( + 'Some text <|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' + ' more text <|tool_call>call:get_time{timezone:<|"|>UTC<|"|>}' + ) + result = self.detector.detect_and_parse(text, extra_tools) + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[1].name, "get_time") + self.assertEqual(result.normal_text, "Some text ") + + def test_parse_gemma4_args_empty(self): + self.assertEqual(_parse_gemma4_args(""), {}) + self.assertEqual(_parse_gemma4_args(" "), {}) + + def test_parse_gemma4_args_booleans(self): + result = _parse_gemma4_args("flag:true,other:false") + self.assertIs(result["flag"], True) + self.assertIs(result["other"], False) + + def test_parse_gemma4_args_numbers(self): + result = _parse_gemma4_args("count:42,ratio:3.14") + self.assertEqual(result["count"], 42) + self.assertAlmostEqual(result["ratio"], 3.14) + + def test_parse_gemma4_args_string_with_colon(self): + result = _parse_gemma4_args('url:<|"|>http://example.com<|"|>') + self.assertEqual(result["url"], "http://example.com") + + def test_parse_gemma4_args_nested_object(self): + result = _parse_gemma4_args('outer:{inner:<|"|>val<|"|>,num:5}') + self.assertIsInstance(result["outer"], dict) + self.assertEqual(result["outer"]["inner"], "val") + self.assertEqual(result["outer"]["num"], 5) + + def test_parse_gemma4_array_mixed_types(self): + result = _parse_gemma4_array('<|"|>hello<|"|>, 42, true, {key:<|"|>val<|"|>}') + self.assertEqual(result[0], "hello") + self.assertEqual(result[1], 42) + self.assertIs(result[2], True) + self.assertIsInstance(result[3], dict) + self.assertEqual(result[3]["key"], "val") + + def 
test_parse_gemma4_value_types(self): + self.assertIs(_parse_gemma4_value("true"), True) + self.assertIs(_parse_gemma4_value("false"), False) + self.assertEqual(_parse_gemma4_value("42"), 42) + self.assertAlmostEqual(_parse_gemma4_value("3.14"), 3.14) + self.assertEqual(_parse_gemma4_value("hello"), "hello") + self.assertEqual(_parse_gemma4_value(""), "") + + def _collect_streaming(self, chunks): + """Helper: feed chunks and collect normal text + tool calls by index.""" + normal_text = "" + tool_calls_by_index = {} + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + normal_text += result.normal_text + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + return normal_text, tool_calls_by_index + + def test_streaming_multiple_tool_calls(self): + """Test streaming with two consecutive tool calls.""" + extra_tools = self.tools + [ + Tool( + type="function", + function=Function( + name="get_time", + description="Get current time", + parameters={ + "type": "object", + "properties": {"timezone": {"type": "string"}}, + }, + ), + ) + ] + chunks = [ + '<|tool_call>call:get_weather{location:<|"|>', + 'Tokyo<|"|>}', + ' <|tool_call>call:get_time{timezone:<|"|>', + 'UTC<|"|>}', + ] + normal_text = "" + tool_calls_by_index = {} + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, extra_tools) + normal_text += result.normal_text + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if 
call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertEqual(len(tool_calls_by_index), 2) + self.assertEqual(tool_calls_by_index[0]["name"], "get_weather") + self.assertEqual(tool_calls_by_index[1]["name"], "get_time") + params0 = json.loads(tool_calls_by_index[0]["parameters"]) + params1 = json.loads(tool_calls_by_index[1]["parameters"]) + self.assertEqual(params0["location"], "Tokyo") + self.assertEqual(params1["timezone"], "UTC") + + def test_streaming_very_small_chunks(self): + """Test streaming with character-by-character chunks.""" + full_text = '<|tool_call>call:get_weather{location:<|"|>Rome<|"|>}' + chunks = list(full_text) + + normal_text, tool_calls = self._collect_streaming(chunks) + + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + params = json.loads(tool_calls[0]["parameters"]) + self.assertEqual(params["location"], "Rome") + + def test_streaming_empty_args(self): + """Test streaming a tool call with no arguments.""" + chunks = ["<|tool_call>call:get_weather{}", ""] + normal_text, tool_calls = self._collect_streaming(chunks) + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + + def test_streaming_text_between_tool_calls(self): + """Test streaming with normal text interleaved between two different tool calls.""" + extra_tools = self.tools + [ + Tool( + type="function", + function=Function( + name="get_time", + description="Get current time", + parameters={ + "type": "object", + "properties": {"timezone": {"type": "string"}}, + }, + ), + ) + ] + chunks = [ + "Hello! 
", + '<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}', + " Let me also check ", + '<|tool_call>call:get_time{timezone:<|"|>UTC<|"|>}', + ] + normal_text = "" + tool_calls_by_index = {} + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, extra_tools) + normal_text += result.normal_text + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + self.assertIn("Hello!", normal_text) + self.assertIn("Let me also check", normal_text) + self.assertEqual(len(tool_calls_by_index), 2) + self.assertEqual(tool_calls_by_index[0]["name"], "get_weather") + self.assertEqual(tool_calls_by_index[1]["name"], "get_time") + params0 = json.loads(tool_calls_by_index[0]["parameters"]) + params1 = json.loads(tool_calls_by_index[1]["parameters"]) + self.assertEqual(params0["location"], "Paris") + self.assertEqual(params1["timezone"], "UTC") + + if __name__ == "__main__": unittest.main() diff --git a/test/registered/unit/parser/test_reasoning_parser.py b/test/registered/unit/parser/test_reasoning_parser.py index 5b9d623d51b7..8f05d7903e9b 100644 --- a/test/registered/unit/parser/test_reasoning_parser.py +++ b/test/registered/unit/parser/test_reasoning_parser.py @@ -5,6 +5,7 @@ from sglang.srt.parser.reasoning_parser import ( BaseReasoningFormatDetector, DeepSeekR1Detector, + Gemma4Detector, Glm45Detector, KimiDetector, KimiK2Detector, @@ -586,6 +587,141 @@ def test_force_nonempty_content_no_thinking_tokens(self): self.assertEqual(result.reasoning_text, "") +class TestGemma4Detector(CustomTestCase): + def setUp(self): + self.detector = Gemma4Detector() + + def test_init(self): + """Test Gemma4Detector initialization.""" + 
class TestGemma4Detector(CustomTestCase):
    """Tests for the Gemma 4 reasoning detector (``<|channel>thought\\n`` ... end token)."""

    # End-of-thinking delimiter. It was absent from every literal in the source,
    # which left inputs that cannot be split into reasoning/answer and a test
    # that read an unassigned ``result2``.
    # NOTE(review): restored as the mirror image of "<|channel>"; confirm against
    # Gemma4Detector.think_end_token.
    THINK_END = "<channel|>"

    def setUp(self):
        """Create a fresh detector for every test."""
        self.detector = Gemma4Detector()

    def test_init(self):
        """Test Gemma4Detector initialization."""
        self.assertEqual(self.detector.think_start_token, "<|channel>")
        self.assertEqual(self.detector.think_end_token, self.THINK_END)
        self.assertEqual(self.detector.think_start_self_label, "thought\n")
        self.assertFalse(self.detector._in_reasoning)
        self.assertTrue(self.detector.stream_reasoning)

    def test_detect_and_parse_complete_reasoning(self):
        """Test parsing complete Gemma4 reasoning block (think_start_self_label is stripped)."""
        text = (
            "<|channel>thought\nLet me think about this"
            f"{self.THINK_END}The answer is 42."
        )
        result = self.detector.detect_and_parse(text)
        self.assertEqual(result.reasoning_text, "Let me think about this")
        self.assertEqual(result.normal_text, "The answer is 42.")

    def test_detect_and_parse_without_thinking(self):
        """Test parsing without thinking (enable_thinking=False case)."""
        text = "Direct answer without thinking."
        result = self.detector.detect_and_parse(text)
        self.assertEqual(result.normal_text, text)
        self.assertEqual(result.reasoning_text, "")

    def test_detect_and_parse_reasoning_only(self):
        """Test parsing when output is all reasoning (no end token yet)."""
        text = "<|channel>thought\nStill thinking..."
        result = self.detector.detect_and_parse(text)
        self.assertEqual(result.reasoning_text, "Still thinking...")
        self.assertEqual(result.normal_text, "")

    def test_streaming_complete_flow(self):
        """Test streaming parse of Gemma4 reasoning flow."""
        chunks = [
            "<|channel>",
            "thought\nreasoning content",
            self.THINK_END,
            "final answer",
        ]
        all_reasoning = ""
        all_normal = ""
        for chunk in chunks:
            result = self.detector.parse_streaming_increment(chunk)
            all_reasoning += result.reasoning_text
            all_normal += result.normal_text
        self.assertIn("reasoning content", all_reasoning)
        self.assertIn("final answer", all_normal)

    def test_streaming_full_start_sequence(self):
        """Test streaming with the full start sequence (token + self_label)."""
        # Gemma4 start sequence is "<|channel>thought\n", not just "<|channel>"
        result = self.detector.parse_streaming_increment("<|channel>thought\n")
        self.assertEqual(result.normal_text, "")
        self.assertEqual(result.reasoning_text, "")
        self.assertTrue(self.detector._in_reasoning)

        result = self.detector.parse_streaming_increment("reasoning content")
        self.assertEqual(result.reasoning_text, "reasoning content")
        self.assertEqual(result.normal_text, "")

    def test_streaming_partial_start_buffered(self):
        """Test that partial start sequence is buffered."""
        # "<|channel>" alone is a prefix of "<|channel>thought\n", so it's buffered
        result = self.detector.parse_streaming_increment("<|channel>")
        self.assertEqual(result.normal_text, "")
        self.assertEqual(result.reasoning_text, "")

    def test_streaming_end_token_mid_chunk(self):
        """Test end token arriving in the same chunk as reasoning content."""
        self.detector.parse_streaming_increment("<|channel>thought\n")
        result = self.detector.parse_streaming_increment(
            f"some reasoning{self.THINK_END}the answer"
        )
        self.assertEqual(result.reasoning_text, "some reasoning")
        self.assertEqual(result.normal_text, "the answer")
        self.assertFalse(self.detector._in_reasoning)

    def test_streaming_split_end_token(self):
        """Test end token split across two chunks."""
        self.detector.parse_streaming_increment("<|channel>thought\n")
        self.detector.parse_streaming_increment("reasoning content")

        # Cut the end token in the middle so neither chunk contains it whole.
        half = len(self.THINK_END) // 2
        self.detector.parse_streaming_increment(self.THINK_END[:half])
        result2 = self.detector.parse_streaming_increment(
            self.THINK_END[half:] + "final answer"
        )
        self.assertFalse(self.detector._in_reasoning)
        self.assertIn("final answer", result2.normal_text)

    def test_streaming_self_label_split_across_chunks(self):
        """Test self_label ('thought\\n') arriving separately from start token."""
        result1 = self.detector.parse_streaming_increment("<|channel>")
        self.assertEqual(result1.reasoning_text, "")
        self.assertEqual(result1.normal_text, "")

        self.detector.parse_streaming_increment("thought\n")
        self.assertTrue(self.detector._in_reasoning)

        result3 = self.detector.parse_streaming_increment("reasoning here")
        self.assertEqual(result3.reasoning_text, "reasoning here")

    def test_streaming_force_reasoning(self):
        """Test streaming with force_reasoning=True (no start token needed)."""
        detector = Gemma4Detector(force_reasoning=True)

        result1 = detector.parse_streaming_increment("reasoning content")
        self.assertEqual(result1.reasoning_text, "reasoning content")
        self.assertEqual(result1.normal_text, "")

        result2 = detector.parse_streaming_increment(f"{self.THINK_END}the answer")
        self.assertFalse(detector._in_reasoning)
        self.assertIn("the answer", result2.normal_text)

    def test_streaming_multiple_reasoning_chunks(self):
        """Test reasoning content arriving in many small chunks."""
        self.detector.parse_streaming_increment("<|channel>thought\n")

        all_reasoning = ""
        for chunk in ["Think", "ing ", "step ", "by ", "step."]:
            result = self.detector.parse_streaming_increment(chunk)
            all_reasoning += result.reasoning_text
            self.assertEqual(result.normal_text, "")
        self.assertEqual(all_reasoning, "Thinking step by step.")

    def test_force_reasoning(self):
        """Test Gemma4Detector with force_reasoning=True."""
        detector = Gemma4Detector(force_reasoning=True)
        text = f"This should be reasoning{self.THINK_END}The answer."
        result = detector.detect_and_parse(text)
        self.assertEqual(result.reasoning_text, "This should be reasoning")
        self.assertEqual(result.normal_text, "The answer.")
def test_gemma4_complete_response(self):
    """Test complete Gemma4 response parsing (think_start_self_label stripped)."""
    parser = ReasoningParser("gemma4")
    # The end-of-thinking delimiter was missing from this literal, so the
    # expected reasoning/answer split below could not be produced.
    # NOTE(review): "<channel|>" is restored as the mirror of "<|channel>";
    # confirm against the Gemma4 detector's think_end_token.
    text = (
        "<|channel>thought\nI need to solve x + 2 = 5. Subtracting 2: x = 3."
        "<channel|>The answer is x = 3."
    )
    reasoning, normal = parser.parse_non_stream(text)
    self.assertIn("x = 3", reasoning)
    self.assertNotIn("thought\n", reasoning)
    self.assertEqual(normal, "The answer is x = 3.")


def test_gemma4_streaming_scenario(self):
    """Test Gemma4 streaming scenario."""
    parser = ReasoningParser("gemma4")
    chunks = [
        "<|channel>",
        "thought\nLet me analyze.",
        " Multiple factors.",
        # Was an empty string in the source; without the end token the final
        # chunk could never be classified as normal text.
        # NOTE(review): confirm the literal against think_end_token.
        "<channel|>",
        "The solution is 42.",
    ]
    all_reasoning = ""
    all_normal = ""
    for chunk in chunks:
        reasoning, normal = parser.parse_stream_chunk(chunk)
        all_reasoning += reasoning
        all_normal += normal
    self.assertIn("analyze", all_reasoning)
    self.assertIn("Multiple factors", all_reasoning)
    self.assertIn("42", all_normal)