diff --git a/.codespellrc b/.codespellrc index 808a344b4e6f..5b14597698f4 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,3 +1,3 @@ [codespell] -ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS +ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS, ather skip = *.json,*.jsonl,*.patch,*.txt diff --git a/.github/workflows/nightly-test-amd-rocm720.yml b/.github/workflows/nightly-test-amd-rocm720.yml index 14929952ebd6..272972077769 100644 --- a/.github/workflows/nightly-test-amd-rocm720.yml +++ b/.github/workflows/nightly-test-amd-rocm720.yml @@ -621,7 +621,7 @@ jobs: echo "$(<github_summary.md)" >> $GITHUB_STEP_SUMMARY || true exit ${TEST_EXIT_CODE:-0} - # 8-GPU Qwen 3.5 (Accuracy) ROCm 7.2 + # 8-GPU Qwen 3.5 (Accuracy + Performance combined) ROCm 7.2 nightly-8-gpu-qwen35-rocm720: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-qwen35-rocm720,')) runs-on: linux-mi325-8gpu-sglang @@ -653,6 +653,18 @@ jobs: echo "$(<github_summary.md)" >> $GITHUB_STEP_SUMMARY || true exit ${TEST_EXIT_CODE:-0} + - name: Performance Test ROCm 7.2 (8-GPU Qwen 3.5 FP8) + timeout-minutes: 120 + continue-on-error: true + run: | + > github_summary.md # Clear summary file + bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \ + -e SGLANG_USE_AITER=1 \ + -e GITHUB_STEP_SUMMARY="/sglang-checkout/github_summary.md" \ + python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-qwen35-fp8 --nightly --timeout-per-file 5400 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$? 
+ echo "$(<github_summary.md)" >> $GITHUB_STEP_SUMMARY || true + exit ${TEST_EXIT_CODE:-0} + # 8-GPU GLM-5 (Accuracy) ROCm 7.2 nightly-8-gpu-glm5-rocm720: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-glm5-rocm720,')) @@ -1219,7 +1231,7 @@ jobs: echo "$(<github_summary.md)" >> $GITHUB_STEP_SUMMARY || true exit ${TEST_EXIT_CODE:-0} - # MI35x 8-GPU Qwen 3.5 (Accuracy) ROCm 7.2 + # MI35x 8-GPU Qwen 3.5 (Accuracy + Performance combined) ROCm 7.2 nightly-8-gpu-mi35x-qwen35-rocm720: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-qwen35-rocm720,')) runs-on: linux-mi35x-gpu-8 @@ -1252,6 +1264,18 @@ jobs: echo "$(<github_summary.md)" >> $GITHUB_STEP_SUMMARY || true exit ${TEST_EXIT_CODE:-0} + - name: Performance Test MI35x ROCm 7.2 (8-GPU Qwen 3.5 FP8) + timeout-minutes: 120 + continue-on-error: true + run: | + > github_summary.md # Clear summary file + bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \ + -e SGLANG_USE_AITER=1 \ + -e GITHUB_STEP_SUMMARY="/sglang-checkout/github_summary.md" \ + python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-mi35x-qwen35-fp8 --nightly --timeout-per-file 5400 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$? 
+ echo "$(<github_summary.md)" >> $GITHUB_STEP_SUMMARY || true + exit ${TEST_EXIT_CODE:-0} + nightly-8-gpu-mi35x-glm5-rocm720: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-glm5-rocm720,')) runs-on: linux-mi35x-gpu-8 diff --git a/.github/workflows/nightly-test-amd.yml b/.github/workflows/nightly-test-amd.yml index 64cca74d7e0f..702ec1d94085 100644 --- a/.github/workflows/nightly-test-amd.yml +++ b/.github/workflows/nightly-test-amd.yml @@ -624,7 +624,7 @@ jobs: echo "$(<github_summary.md)" >> $GITHUB_STEP_SUMMARY || true exit ${TEST_EXIT_CODE:-0} - # 8-GPU Qwen 3.5 (Accuracy) + # 8-GPU Qwen 3.5 (Accuracy + Performance combined) nightly-8-gpu-qwen35: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-qwen35,')) runs-on: linux-mi325-8gpu-sglang @@ -656,6 +656,18 @@ jobs: echo "$(<github_summary.md)" >> $GITHUB_STEP_SUMMARY || true exit ${TEST_EXIT_CODE:-0} + - name: Performance Test (8-GPU Qwen 3.5 FP8) + timeout-minutes: 120 + continue-on-error: true + run: | + > github_summary.md # Clear summary file + bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \ + -e SGLANG_USE_AITER=1 \ + -e GITHUB_STEP_SUMMARY="/sglang-checkout/github_summary.md" \ + python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-qwen35-fp8 --nightly --timeout-per-file 5400 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$? 
+ echo "$(<github_summary.md)" >> $GITHUB_STEP_SUMMARY || true + exit ${TEST_EXIT_CODE:-0} + nightly-8-gpu-glm5: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-glm5,')) runs-on: linux-mi325-8gpu-sglang @@ -1224,7 +1236,7 @@ jobs: echo "$(<github_summary.md)" >> $GITHUB_STEP_SUMMARY || true exit ${TEST_EXIT_CODE:-0} - # MI35x 8-GPU Qwen 3.5 (Accuracy) + # MI35x 8-GPU Qwen 3.5 (Accuracy + Performance combined) nightly-8-gpu-mi35x-qwen35: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-qwen35,')) runs-on: linux-mi35x-gpu-8 @@ -1257,6 +1269,18 @@ jobs: echo "$(<github_summary.md)" >> $GITHUB_STEP_SUMMARY || true exit ${TEST_EXIT_CODE:-0} + - name: Performance Test MI35x (8-GPU Qwen 3.5 FP8) + timeout-minutes: 120 + continue-on-error: true + run: | + > github_summary.md # Clear summary file + bash scripts/ci/amd/amd_ci_exec.sh -w /sglang-checkout/test \ + -e SGLANG_USE_AITER=1 \ + -e GITHUB_STEP_SUMMARY="/sglang-checkout/github_summary.md" \ + python3 run_suite.py --hw amd --suite nightly-perf-8-gpu-mi35x-qwen35-fp8 --nightly --timeout-per-file 5400 ${{ inputs.continue_on_error && '--continue-on-error' || '' }} || TEST_EXIT_CODE=$? 
+ echo "$(<github_summary.md)" >> $GITHUB_STEP_SUMMARY || true + exit ${TEST_EXIT_CODE:-0} + nightly-8-gpu-mi35x-glm5: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && (!(inputs.job_filter || inputs.job_select) || (inputs.job_filter || inputs.job_select) == 'all' || contains(format(',{0},', inputs.job_filter || inputs.job_select), ',nightly-8-gpu-mi35x-glm5,')) runs-on: linux-mi35x-gpu-8 diff --git a/benchmark/asr/README.md b/benchmark/asr/README.md index 0acbf1c30fae..5c16490e9262 100644 --- a/benchmark/asr/README.md +++ b/benchmark/asr/README.md @@ -6,6 +6,8 @@ This benchmark evaluates the performance and accuracy (Word Error Rate - WER) of - `openai/whisper-large-v3` - `openai/whisper-large-v3-turbo` +- `Qwen/Qwen3-ASR-1.7B` +- `Qwen/Qwen3-ASR-0.6B` ## Setup diff --git a/benchmark/kernels/fused_moe_triton/common_utils.py b/benchmark/kernels/fused_moe_triton/common_utils.py index 37a9607b6014..d08d2bb75d83 100644 --- a/benchmark/kernels/fused_moe_triton/common_utils.py +++ b/benchmark/kernels/fused_moe_triton/common_utils.py @@ -134,6 +134,10 @@ def get_model_config( topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size hidden_size = getattr(config, "moe_latent_size", None) or hidden_size + elif architecture == "Gemma4ForConditionalGeneration": + E = config.num_experts // ep_size + topk = config.top_k_experts + intermediate_size = config.moe_intermediate_size else: # Default: Mixtral E = config.num_local_experts // ep_size diff --git a/benchmark/mmlu/bench_hf.py b/benchmark/mmlu/bench_hf.py new file mode 100644 index 000000000000..c76a18db685b --- /dev/null +++ b/benchmark/mmlu/bench_hf.py @@ -0,0 +1,151 @@ +""" +Usage: +python3 bench_hf.py --model-path meta-llama/Llama-2-7b-hf --data-dir data --ntrain 5 +""" + +import argparse +import json +import os +import time + +import numpy as np +import pandas as pd +import torch +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +choices = 
["A", "B", "C", "D"] + + +def format_subject(subject): + l = subject.split("_") + s = "" + for entry in l: + s += " " + entry + return s + + +def format_example(df, idx, include_answer=True): + prompt = df.iloc[idx, 0] + k = df.shape[1] - 2 + for j in range(k): + prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1]) + prompt += "\nAnswer:" + if include_answer: + prompt += " {}\n\n".format(df.iloc[idx, k + 1]) + return prompt + + +def gen_prompt(train_df, subject, k=-1): + prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format( + format_subject(subject) + ) + if k == -1: + k = train_df.shape[0] + for i in range(k): + prompt += format_example(train_df, i) + return prompt + + +@torch.no_grad() +def main(args): + print(f"Loading model: {args.model_path}") + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained( + args.model_path, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + device_map="auto", + ).eval() + + subjects = sorted( + [ + f.split("_test.csv")[0] + for f in os.listdir(os.path.join(args.data_dir, "test")) + if "_test.csv" in f + ] + ) + + all_cors = [] + num_requests = 0 + total_latency = 0 + + for subject in tqdm(subjects[: args.nsub]): + dev_df = pd.read_csv( + os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None + )[: args.ntrain] + test_df = pd.read_csv( + os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None + ) + + k = args.ntrain + few_shot_examples = gen_prompt(dev_df, subject, k) + while len(tokenizer.encode(few_shot_examples)) > 1536: + k -= 1 + if k < 0: + break + few_shot_examples = gen_prompt(dev_df, subject, k) + + preds = [] + labels = [] + tic = time.perf_counter() + + for i in range(test_df.shape[0]): + prompt_end = format_example(test_df, i, include_answer=False) + prompt = few_shot_examples + prompt_end + + input_ids = tokenizer.encode(prompt, 
return_tensors="pt").to(model.device) + output_ids = model.generate( + input_ids, + max_new_tokens=1, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + ) + + output_str = tokenizer.decode( + output_ids[0][input_ids.shape[-1] :], skip_special_tokens=True + ) + preds.append(output_str.strip()[0] if len(output_str.strip()) > 0 else "") + labels.append(test_df.iloc[i, test_df.shape[1] - 1]) + + latency = time.perf_counter() - tic + total_latency += latency + + cors = [pred == label for pred, label in zip(preds, labels)] + all_cors.append(cors) + num_requests += len(test_df) + + print( + f"Subject: {subject}, Accuracy: {np.mean(cors):.3f}, Latency: {latency:.3f}s" + ) + + weighted_acc = np.mean(np.concatenate(all_cors)) + print(f"Total Latency: {total_latency:.3f}s") + print(f"Average Accuracy: {weighted_acc:.3f}") + + if args.output: + with open(args.output, "a") as fout: + value = { + "task": "mmlu", + "backend": "hf", + "model": args.model_path, + "latency": round(total_latency, 3), + "accuracy": round(weighted_acc, 3), + "num_requests": num_requests, + "other": { + "nsub": args.nsub, + "ntrain": args.ntrain, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, required=True) + parser.add_argument("--ntrain", type=int, default=5) + parser.add_argument("--data-dir", type=str, default="data") + parser.add_argument("--nsub", type=int, default=60) + parser.add_argument("--output", type=str, help="Output file path") + args = parser.parse_args() + main(args) diff --git a/docs/advanced_features/hisparse_guide.md b/docs/advanced_features/hisparse_guide.md new file mode 100644 index 000000000000..4882d64f6b99 --- /dev/null +++ b/docs/advanced_features/hisparse_guide.md @@ -0,0 +1,111 @@ +# HiSparse: Hierarchical Sparse Attention + +HiSparse reduces per-request GPU memory consumption during the decode phase by maintaining only a small "hot" KV buffer on GPU 
while keeping complete KV data in CPU pinned memory. Combined with PD disaggregation, it enables significantly higher decode concurrency. + +> **Prerequisites**: HiSparse only works with models that use **DeepSeek Sparse Attention (DSA)** architectures (e.g., DeepSeek-V3.2, GLM-5). These models natively select a subset of tokens for attention, making it possible to keep only the top-k KV on GPU while storing the full KV in host memory — without accuracy loss. Additionally, HiSparse currently requires **PD disaggregation mode** and is enabled on the **decode instance** only. + +## Why HiSparse? + +In long-context LLM inference, each decoding request holds a full-length KV cache on GPU, limiting the number of concurrent requests a decode instance can serve. HiSparse addresses this by: + +- **Reducing GPU memory per request**: Each request occupies only a fixed-size device buffer (e.g., 4K tokens) instead of the full sequence length. +- **On-demand swap-in**: A CUDA kernel dynamically loads the top-k most relevant KV entries from host memory based on attention scores. +- **Transparent to prefill**: HiSparse is entirely a decode-side optimization; the prefill instance requires no changes. + +## Design Overview + +### Decode Workflow + +Each decode step follows this flow: + +1. **Forward decode** — generate the next token +2. **Top-k selection** — select the most relevant token positions via attention scores +3. **Swap-in** — the CUDA kernel loads top-k KV entries from host to device buffer: + - *Short sequences* (`seq_len ≤ device_buffer_size`): fast path, all KV already in buffer + - *Long sequences*: hit detection → LRU reordering → miss handling (host → device copy) +4. **Decode attention** — compute attention using the top-k device locations +5. 
**Eager backup** — asynchronously copy the previous token's KV from device to host + +### PD Disaggregation Integration (Direct-to-Host) + +In PD disaggregation mode, the prefill instance transfers KV cache directly into the decode instance's host pool via RDMA, bypassing the GPU entirely on the decode side. This eliminates the transient GPU memory spike during KV transfer and removes the staging DMA step. + +``` +Prefill GPU ──RDMA──▶ Decode Host Pool (CPU pinned memory) + │ + ▼ + alloc device buffer (4K tokens) + │ + ▼ + swap-in kernel (on-demand top-k) +``` + +## Server Arguments + +| Argument | Type / Default | Description | +|----------|---------------|-------------| +| `--enable-hisparse` | flag; default: disabled | Enable HiSparse on the decode instance | +| `--hisparse-config` | JSON string | Configuration for HiSparse (see below) | + +### HiSparse Config Parameters + +Pass as a JSON string via `--hisparse-config`: + +| Parameter | Type / Default | Description | +|-----------|---------------|-------------| +| `top_k` | int | Number of top-k KV entries (token positions) selected for attention at each decode step | +| `device_buffer_size` | int | Number of token slots in the per-request GPU device buffer | +| `host_to_device_ratio` | int | Ratio of logical pool size to device pool size, determining host memory capacity | + +Example: `--hisparse-config='{"top_k": 2048, "device_buffer_size": 4096, "host_to_device_ratio": 5}'` + +## Deployment + +HiSparse currently requires **PD disaggregation mode** and is enabled only on the **decode instance**. 
+ +### Prefill Instance + +```bash +python3 -m sglang.launch_server \ + --model-path /path/to/model \ + --trust-remote-code \ + --port 8000 --host 0.0.0.0 \ + --context-length 81920 \ + --chunked-prefill-size 65536 \ + --tp-size 8 --dp-size 8 --enable-dp-attention \ + --page-size 64 \ + --mem-fraction-static 0.85 \ + --disaggregation-mode prefill \ + --disaggregation-ib-device mlx5_0,mlx5_1,mlx5_2,mlx5_3 \ + --nnodes 1 --node-rank 0 +``` + +### Decode Instance (with HiSparse) + +```bash +python3 -m sglang.launch_server \ + --model-path /path/to/model \ + --trust-remote-code \ + --port 8000 --host 0.0.0.0 \ + --context-length 81920 \ + --tp-size 8 --dp-size 8 --enable-dp-attention \ + --page-size 64 \ + --mem-fraction-static 0.85 \ + --kv-cache-dtype bfloat16 \ + --nsa-decode-backend flashmla_sparse \ + --disaggregation-mode decode \ + --disaggregation-ib-device mlx5_0,mlx5_1,mlx5_2,mlx5_3 \ + --dist-init-addr 127.0.0.1:5757 \ + --nnodes 1 --node-rank 0 \ + --enable-hisparse \ + --hisparse-config='{"top_k": 2048, "device_buffer_size": 6144, "host_to_device_ratio": 5}' +``` + +### Key Notes + +- The prefill instance does not need `--enable-hisparse`; it is unaware of HiSparse. +- On the decode instance, the following flags are **required** for HiSparse: + - `--kv-cache-dtype bfloat16` — currently only bfloat16 KV cache is supported (more dtypes planned). + - `--nsa-decode-backend flashmla_sparse` — currently only `flashmla_sparse` backend is supported. + - `--enable-hisparse` — enables HiSparse. + - `--hisparse-config` — HiSparse configuration (top_k, device_buffer_size, host_to_device_ratio). diff --git a/docs/basic_usage/deepseek_v32.md b/docs/basic_usage/deepseek_v32.md index 4cc27ffd38cc..db1d0e71ea7a 100644 --- a/docs/basic_usage/deepseek_v32.md +++ b/docs/basic_usage/deepseek_v32.md @@ -468,3 +468,9 @@ python -m sglang.launch_server \ ``` For the Decode nodes, it is recommended to use the **EP mode**. 
+ +## HiSparse: Hierarchical Sparse Attention for DSA (experimental) + +HiSparse reduces per-request GPU memory during decode by keeping only a small "hot" KV buffer on GPU while storing complete KV data in CPU pinned memory. A CUDA kernel dynamically swaps in the top-k most relevant KV entries from host memory on each decode step. This enables significantly higher decode concurrency for long-context DSA models. + +HiSparse currently requires PD disaggregation mode and is enabled on the decode instance only. For detailed design, configuration, and deployment instructions, see the [HiSparse Guide](../advanced_features/hisparse_guide.md). diff --git a/docs/supported_models/text_generation/multimodal_language_models.md b/docs/supported_models/text_generation/multimodal_language_models.md index 77020fa09d95..b8648590cb12 100644 --- a/docs/supported_models/text_generation/multimodal_language_models.md +++ b/docs/supported_models/text_generation/multimodal_language_models.md @@ -51,9 +51,38 @@ in the GitHub search bar. | **Ernie4.5-VL** | `baidu/ERNIE-4.5-VL-28B-A3B-PT` | Baidu's vision-language models(28B,424B). Support image and video comprehension, and also support thinking. | | | **JetVLM** | | JetVLM is an vision-language model designed for high-performance multimodal understanding and generation tasks built upon Jet-Nemotron. | Coming soon | | **Step3-VL** (10B) | `stepfun-ai/Step3-VL-10B` | StepFun's lightweight open-source 10B parameter VLM for multimodal intelligence, excelling in visual perception, complex reasoning, and human alignment. | | +| **Qwen3-ASR** (0.6B, 1.7B) | `Qwen/Qwen3-ASR-1.7B` | Alibaba's automatic speech recognition models supporting 52 languages. Served via the `/v1/audio/transcriptions` endpoint. | | | **Qwen3-Omni** | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | Alibaba's omni-modal MoE model. 
Currently supports the **Thinker** component (multimodal understanding for text, images, audio, and video), while the **Talker** component (audio generation) is not yet supported. | | | **LFM2-VL** | `LiquidAI/LFM2.5-VL-1.6B` | Liquid AI's vision-language model combining a SigLip2 vision encoder (NaFlex variable-resolution) with the LFM2 hybrid attention + short convolution language model. Supports multi-image inputs. | | +## Audio Transcription + +SGLang supports audio-only ASR models via the OpenAI-compatible `/v1/audio/transcriptions` endpoint. Upload an audio file and receive a transcription. + +### Launch Command + +```shell +sglang serve \ + --model-path Qwen/Qwen3-ASR-1.7B \ + --served-model-name qwen3-asr \ + --trust-remote-code \ + --host 0.0.0.0 --port 30000 +``` + +### Example Request + +```bash +curl http://localhost:30000/v1/audio/transcriptions \ + -F file=@audio.wav \ + -F model=qwen3-asr \ + -F response_format=verbose_json +``` + +| Model Family | Example Identifier | Notes | +|--------------|--------------------|-------| +| **Whisper** | `openai/whisper-large-v3` | OpenAI's speech recognition model. | +| **Qwen3-ASR** (0.6B, 1.7B) | `Qwen/Qwen3-ASR-1.7B` | Use `--trust-remote-code`. Supports 52 languages. | + ## Video Input Support SGLang supports video input for Vision-Language Models (VLMs), enabling temporal reasoning tasks such as video question answering, captioning, and holistic scene understanding. Video clips are decoded, key frames are sampled, and the resulting tensors are batched together with the text prompt, allowing multimodal inference to integrate visual and linguistic context. 
diff --git a/python/pyproject_other.toml b/python/pyproject_other.toml index 3261e5f30a88..1cabc32c131b 100755 --- a/python/pyproject_other.toml +++ b/python/pyproject_other.toml @@ -134,6 +134,8 @@ srt_mps = [ "torchao==0.9.0", "torchaudio==2.9.1", "torchvision", + "mlx", + "mlx-lm", ] diffusion_mps = [ diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 850270308d2b..86a0fc15b287 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -404,6 +404,19 @@ def get_chat_template_by_model_path(model_path): ) ) +register_chat_template( + ChatTemplate( + name="gemma-4-it", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ("", ""), + "user": ("<|turn>user\n", "\n"), + "assistant": ("<|turn>assistant\n", "\n"), + }, + style=ChatTemplateStyle.PLAIN, + ) +) + register_chat_template( ChatTemplate( name="dbrx-instruct", @@ -611,8 +624,10 @@ def match_chat_yi(model_path: str): @register_chat_template_matching_function -def match_gemma_it(model_path: str): - if re.search(r"gemma.*it", model_path, re.IGNORECASE): +def match_gemma(model_path: str): + if re.search(r"gemma-4.*it", model_path, re.IGNORECASE): + return "gemma-4-it" + if re.search(r"(gemma.*it)|(gemma-3)", model_path, re.IGNORECASE): return "gemma-it" @@ -636,12 +651,6 @@ def match_granite_instruct(model_path: str): return "granite-3-instruct" -@register_chat_template_matching_function -def match_gemma3_instruct(model_path: str): - if re.search(r"gemma-3", model_path, re.IGNORECASE): - return "gemma-it" - - @register_chat_template_matching_function def match_internvl_chat(model_path: str): if re.search(r"internvl2_5", model_path, re.IGNORECASE): diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index b43b835c4880..dff1f81ff21a 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -23,6 +23,7 @@ from sglang.srt.configs.nemotron_h import 
NemotronHConfig from sglang.srt.configs.olmo3 import Olmo3Config from sglang.srt.configs.qwen3_5 import Qwen3_5Config, Qwen3_5MoeConfig +from sglang.srt.configs.qwen3_asr import Qwen3ASRConfig from sglang.srt.configs.qwen3_next import Qwen3NextConfig from sglang.srt.configs.step3_vl import ( Step3TextConfig, @@ -63,4 +64,5 @@ "JetNemotronConfig", "JetVLMConfig", "Step3p5Config", + "Qwen3ASRConfig", ] diff --git a/python/sglang/srt/configs/linear_attn_model_registry.py b/python/sglang/srt/configs/linear_attn_model_registry.py new file mode 100644 index 000000000000..33fdae8f0783 --- /dev/null +++ b/python/sglang/srt/configs/linear_attn_model_registry.py @@ -0,0 +1,72 @@ +"""Registry for linear attention hybrid models (softmax + linear attention). + +External models can register themselves without modifying SGLang core files: + + from sglang.srt.configs.linear_attn_model_registry import ( + register_linear_attn_model, LinearAttnModelSpec, + ) + + register_linear_attn_model(LinearAttnModelSpec( + config_class=MyLinearAttnConfig, + backend_class_name="sglang.srt.layers.attention.linear.kda_backend.KDAAttnBackend", + arch_names=["MyLinearAttnForCausalLM"], + uses_mamba_radix_cache=True, + support_mamba_cache=True, + )) +""" + +from __future__ import annotations + +import importlib +import logging +from dataclasses import dataclass, field +from typing import Any, Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class LinearAttnModelSpec: + """Specification for a hybrid (softmax + linear attention) model.""" + + config_class: type + backend_class_name: str # fully-qualified class name, lazily imported + arch_names: list[str] = field(default_factory=list) + uses_mamba_radix_cache: bool = True + support_mamba_cache: bool = True + support_mamba_cache_extra_buffer: bool = False + unwrap_text_config: bool = False # call get_text_config() before isinstance check + + +_LINEAR_ATTN_MODEL_REGISTRY: list[LinearAttnModelSpec] = [] + + +def 
register_linear_attn_model(spec: LinearAttnModelSpec) -> None: + _LINEAR_ATTN_MODEL_REGISTRY.append(spec) + logger.info( + "Registered linear attn model: config=%s, backend=%s, archs=%s", + spec.config_class.__name__, + spec.backend_class_name.rsplit(".", 1)[-1], + spec.arch_names, + ) + + +def get_linear_attn_config(hf_config: Any) -> Optional[tuple[LinearAttnModelSpec, Any]]: + for spec in _LINEAR_ATTN_MODEL_REGISTRY: + config = hf_config.get_text_config() if spec.unwrap_text_config else hf_config + if isinstance(config, spec.config_class): + return spec, config + return None + + +def get_linear_attn_spec_by_arch(arch_name: str) -> Optional[LinearAttnModelSpec]: + for spec in _LINEAR_ATTN_MODEL_REGISTRY: + if arch_name in spec.arch_names: + return spec + return None + + +def import_backend_class(dotted_name: str) -> type: + module_path, class_name = dotted_name.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 89e90516ef12..450ea6b12c2b 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -196,8 +196,16 @@ def __init__( self.is_image_understandable_model = enable_multimodal and hasattr( self.hf_config, "vision_config" ) - self.is_audio_understandable_model = enable_multimodal and hasattr( - self.hf_config, "audio_config" + + # Models expose audio_config at different nesting levels: + # - top-level audio_config: e.g. 
Qwen2Audio + # - thinker_config.audio_config: Qwen3-Omni, Qwen3-ASR (nested thinker arch) + # - is_audio_model(): Whisper, Qwen3-ASR (architecture-based fallback) + # TODO: Handle this more robustly by standardizing the config structure in the future + self.is_audio_understandable_model = enable_multimodal and ( + hasattr(self.hf_config, "audio_config") + or hasattr(getattr(self.hf_config, "thinker_config", None), "audio_config") + or is_audio_model(self.hf_config.architectures) ) self.is_multimodal_chunked_prefill_supported = ( @@ -376,6 +384,8 @@ def _derive_hybrid_model(self): self.is_hybrid_swa_compress = self.hf_config.architectures[0] in [ "MiMoV2FlashForCausalLM", "MiMoV2MTP", + "Gemma4ForCausalLM", + "Gemma4ForConditionalGeneration", ] def _derive_context_length(self, context_length: int): @@ -433,7 +443,7 @@ def _derive_model_shapes(self): self.swa_v_head_dim = getattr( self.hf_text_config, "swa_v_head_dim", - self.v_head_dim, + self.swa_head_dim, ) # FIXME: temporary special judge for MLA architecture if ( @@ -1301,6 +1311,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal "Ernie4_5_VLMoeForConditionalGeneration", "Gemma3ForConditionalGeneration", "Gemma3nForConditionalGeneration", + "Gemma4ForConditionalGeneration", "Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration", "GlmOcrForConditionalGeneration", @@ -1329,6 +1340,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal "Qwen3VLMoeForConditionalGeneration", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration", + "Qwen3ASRForConditionalGeneration", "Qwen3OmniMoeForConditionalGeneration", "KimiVLForConditionalGeneration", "InternVLChatModel", @@ -1377,6 +1389,7 @@ def is_multimodal_model(model_architectures: List[str]): def is_audio_model(model_architectures: List[str]): models = [ "WhisperForConditionalGeneration", + "Qwen3ASRForConditionalGeneration", ] return any(model in model_architectures for model 
in models) @@ -1447,6 +1460,8 @@ def is_hybrid_swa_model(model_architectures: List[str]): "MiMoV2MTP", "Step3p5ForCausalLM", "Step3p5MTP", + "Gemma4ForCausalLM", + "Gemma4ForConditionalGeneration", } return any(arch in hybrid_swa_archs for arch in model_architectures) @@ -1464,7 +1479,7 @@ def get_hybrid_layer_ids( i for i in range(num_hidden_layers) if (i + 1) % 4 == 0 ] elif "GptOssForCausalLM" in model_architectures: - layer_types = getattr(hf_text_config, "layer_types", None) + layer_types = getattr(hf_text_config, "layer_types", []) swa_attention_layer_ids = [ i for i, x in enumerate(layer_types) if x == "sliding_attention" ] @@ -1497,6 +1512,17 @@ def get_hybrid_layer_ids( elif "Step3p5MTP" in model_architectures: swa_attention_layer_ids = [0] full_attention_layer_ids = [] + elif ( + "Gemma4ForCausalLM" in model_architectures + or "Gemma4ForConditionalGeneration" in model_architectures + ): + layer_types = getattr(hf_text_config, "layer_types", []) + swa_attention_layer_ids = [ + i for i, x in enumerate(layer_types) if x == "sliding_attention" + ] + full_attention_layer_ids = [ + i for i, x in enumerate(layer_types) if x == "full_attention" + ] else: swa_attention_layer_ids = None full_attention_layer_ids = None diff --git a/python/sglang/srt/configs/qwen3_asr.py b/python/sglang/srt/configs/qwen3_asr.py new file mode 100644 index 000000000000..048eb2d9704d --- /dev/null +++ b/python/sglang/srt/configs/qwen3_asr.py @@ -0,0 +1,172 @@ +import torch +from transformers import ( + AutoConfig, + AutoFeatureExtractor, + AutoTokenizer, + PretrainedConfig, + ProcessorMixin, +) + +from sglang.srt.configs.qwen3_omni import Qwen3OmniMoeAudioEncoderConfig +from sglang.srt.multimodal.customized_mm_processor_utils import ( + register_customized_processor, +) +from sglang.utils import logger + + +class Qwen3ASRThinkerConfig(PretrainedConfig): + model_type = "qwen3_asr_thinker" + sub_configs = { + "audio_config": Qwen3OmniMoeAudioEncoderConfig, + } + + def __init__( + self, + 
audio_config=None, + text_config=None, + audio_token_id=151676, + audio_start_token_id=151669, + audio_end_token_id=151670, + **kwargs, + ): + super().__init__(**kwargs) + + if isinstance(audio_config, dict): + audio_config = Qwen3OmniMoeAudioEncoderConfig(**audio_config) + elif audio_config is None: + audio_config = Qwen3OmniMoeAudioEncoderConfig() + self.audio_config = audio_config + + if isinstance(text_config, dict): + from transformers.models.qwen3.configuration_qwen3 import ( + Qwen3Config as HFQwen3Config, + ) + + text_config = HFQwen3Config(**text_config) + elif text_config is None: + raise ValueError( + "Qwen3ASRThinkerConfig requires a text_config dict with " + "model parameters (hidden_size, num_attention_heads, etc.). " + "Got None." + ) + + self.text_config = text_config + + self.audio_token_id = audio_token_id + self.audio_start_token_id = audio_start_token_id + self.audio_end_token_id = audio_end_token_id + + +class Qwen3ASRConfig(PretrainedConfig): + model_type = "qwen3_asr" + sub_configs = { + "thinker_config": Qwen3ASRThinkerConfig, + } + + def __init__(self, thinker_config=None, **kwargs): + super().__init__(**kwargs) + if thinker_config is None: + thinker_config = {} + logger.info( + "thinker_config is None. " + "Initializing Qwen3-ASR thinker with default values" + ) + if isinstance(thinker_config, dict): + self.thinker_config = Qwen3ASRThinkerConfig(**thinker_config) + else: + self.thinker_config = thinker_config + + def get_text_config(self, decoder=False) -> PretrainedConfig: + return self.thinker_config.text_config + + +class Qwen3ASRProcessor(ProcessorMixin): + """Minimal composite processor: WhisperFeatureExtractor + Qwen2Tokenizer. + + AutoProcessor.from_pretrained() for Qwen3-ASR returns just a tokenizer, + but SGLang's multimodal pipeline needs a processor that handles audio. 
+ """ + + attributes = ["feature_extractor", "tokenizer"] + feature_extractor_class = "WhisperFeatureExtractor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, feature_extractor=None, tokenizer=None, **kwargs): + super().__init__(feature_extractor=feature_extractor, tokenizer=tokenizer) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + trust_remote_code = kwargs.pop("trust_remote_code", True) + feature_extractor = AutoFeatureExtractor.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + ) + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + ) + return cls(feature_extractor=feature_extractor, tokenizer=tokenizer) + + def _get_feat_extract_output_lengths(self, input_lengths): + if not isinstance(input_lengths, torch.Tensor): + input_lengths = torch.tensor(input_lengths) + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + + def __call__(self, text=None, audio=None, audio_kwargs=None, **kwargs): + inputs = {} + if audio is not None: + audio_kwargs = audio_kwargs or {} + audio_inputs = self.feature_extractor( + audio, + sampling_rate=self.feature_extractor.sampling_rate, + return_attention_mask=True, + return_tensors=kwargs.get("return_tensors"), + **audio_kwargs, + ) + inputs["input_features"] = audio_inputs["input_features"] + if "attention_mask" in audio_inputs: + inputs["feature_attention_mask"] = audio_inputs["attention_mask"] + + if text is not None: + text_inputs = self.tokenizer( + text, + return_tensors=kwargs.get("return_tensors"), + padding=kwargs.get("padding", False), + ) + input_ids = text_inputs["input_ids"] + + # Expand the single <|audio_pad|> placeholder in the prompt to N + # copies, where N is the audio encoder's output length for this clip. 
+ # Without this, the model only sees 1 audio token for hundreds of + # feature frames and can't align audio embeddings with token positions. + if audio is not None and "feature_attention_mask" in inputs: + audio_pad_id = self.tokenizer.convert_tokens_to_ids("<|audio_pad|>") + feat_lengths = inputs["feature_attention_mask"].sum(dim=-1) + audio_token_counts = self._get_feat_extract_output_lengths(feat_lengths) + expanded = [] + for seq_idx in range(input_ids.shape[0]): + ids = input_ids[seq_idx].tolist() + audio_idx = 0 + new_ids = [] + for tid in ids: + if tid == audio_pad_id and audio_idx < len(audio_token_counts): + n = int(audio_token_counts[audio_idx].item()) + new_ids.extend([audio_pad_id] * n) + audio_idx += 1 + else: + new_ids.append(tid) + expanded.append(new_ids) + max_len = max(len(s) for s in expanded) + pad_id = self.tokenizer.pad_token_id or 0 + padded = [s + [pad_id] * (max_len - len(s)) for s in expanded] + input_ids = torch.tensor(padded, dtype=torch.long) + + inputs["input_ids"] = input_ids + return inputs + + +AutoConfig.register("qwen3_asr", Qwen3ASRConfig) +AutoConfig.register("qwen3_asr_thinker", Qwen3ASRThinkerConfig) +register_customized_processor(Qwen3ASRProcessor)(Qwen3ASRConfig) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 72b9f400aafd..5d42150d0563 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -476,8 +476,8 @@ def _get_feat_extract_output_lengths(self, feature_lens): if self.model_type in ["qwen2_audio", "qwen2_5_omni"]: input_length = (feature_lens - 1) // 2 + 1 return (input_length - 2) // 2 + 1 - # qwen3_omni_moe - elif self.model_type == "qwen3_omni_moe": + # qwen3_asr / qwen3_omni_moe (same audio encoder architecture) + elif self.model_type in ["qwen3_asr", "qwen3_omni_moe"]: input_lengths_leave = feature_lens % 100 feat_lengths = (input_lengths_leave - 1) // 2 + 1 output_lengths = ( 
diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 38a4d15cf048..f84353f6dc12 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -477,25 +477,35 @@ def send_kvcache_slice( # Get configuration from kv_args local_tp_rank_in_group = self.kv_args.engine_rank % prefill_tp_size dst_tp_rank_in_group = decode_tp_rank % decode_tp_size - num_kv_heads = self.kv_args.kv_head_num - - # Calculate head distribution - src_heads_per_rank = num_kv_heads - dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size src_kv_item_len = self.kv_args.kv_item_lens[0] page_size = self.kv_args.page_size + # Use total KV head count (not per-rank) for correct head distribution. + # Per-rank kv_head_num is max(1, total//tp) which loses info when total < tp. + total_kv_heads = getattr(self.kv_args, "total_kv_head_num", 0) + if total_kv_heads <= 0: + total_kv_heads = self.kv_args.kv_head_num * prefill_tp_size + + src_heads_per_rank = max(1, total_kv_heads // prefill_tp_size) + dst_heads_per_rank = max(1, total_kv_heads // decode_tp_size) + bytes_per_head_slice_to_send = ( dst_kv_item_len // page_size // dst_heads_per_rank ) + # GQA replication: how many prefill ranks share the same KV head + src_replication = max(1, prefill_tp_size // total_kv_heads) + # Determine which heads to send if prefill_tp_size > decode_tp_size: # Multiple prefill ranks to one decode rank src_head_start_offset = 0 num_heads_to_send = src_heads_per_rank - dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank + unique_head_idx = local_tp_rank_in_group // src_replication + dst_head_start_offset = ( + unique_head_idx * src_heads_per_rank + ) % dst_heads_per_rank else: # Send KVCache from 1 prefill instance to multiple decode instances src_head_start_offset = ( @@ -748,7 +758,9 @@ def add_transfer_request( assert len(chunked_dst_kv_indice) == len(kv_indices) assert req.agent_name 
in self.decode_kv_args_table - notif = f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.pp_rank}" + notif = ( + f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.engine_rank}" + ) decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): @@ -789,7 +801,7 @@ def add_transfer_request( dst_info.dst_state_data_ptrs, req.dst_state_indices, dst_info.gpu_id, - f"{req.room}_state_{self.kv_args.pp_rank}", + f"{req.room}_state_{self.kv_args.engine_rank}", decode_tp_size, ) if state_xfer_handle is not None: diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 6c255d123fb0..6bfd678f9629 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -120,6 +120,11 @@ def __init__( and hasattr(self.tokenizer_manager.model_config.hf_config, "model_type") and self.tokenizer_manager.model_config.hf_config.model_type == "gpt_oss" ) + self.is_gemma4 = ( + hasattr(self.tokenizer_manager.model_config, "hf_config") + and hasattr(self.tokenizer_manager.model_config.hf_config, "model_type") + and self.tokenizer_manager.model_config.hf_config.model_type == "gemma4" + ) self.use_dpsk_v32_encoding = self._use_dpsk_v32_encoding() @@ -331,7 +336,7 @@ def _process_messages( ) -> MessageProcessingResult: """Process chat messages and apply chat template""" # GptOss model needs to keep special tokens for harmony parsing - if self.is_gpt_oss: + if self.is_gpt_oss or self.is_gemma4: request.skip_special_tokens = False self._patch_mistral_skip_special_tokens(request) @@ -1285,12 +1290,18 @@ def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool: """ if not self.reasoning_parser: return False - if self.reasoning_parser in ["deepseek-v3"]: + + if self.reasoning_parser == "deepseek-v3": # Models that require explicit enable thinking (thinking=True) 
return ( request.chat_template_kwargs is not None and request.chat_template_kwargs.get("thinking") is True ) + if self.reasoning_parser == "gemma4": + return ( + request.chat_template_kwargs is not None + and request.chat_template_kwargs.get("enable_thinking") is True + ) if self.reasoning_parser in ["kimi_k2"]: # Models that thinking by default, and can be disabled by setting thinking=False return ( diff --git a/python/sglang/srt/entrypoints/openai/serving_transcription.py b/python/sglang/srt/entrypoints/openai/serving_transcription.py index bfbad1e0d321..1040122b2e15 100644 --- a/python/sglang/srt/entrypoints/openai/serving_transcription.py +++ b/python/sglang/srt/entrypoints/openai/serving_transcription.py @@ -50,12 +50,22 @@ TIMESTAMP_BASE_TOKEN_ID = 50365 # <|0.00|> TIMESTAMP_BASE_OFFSET = 0.02 # Each token step = 0.02 seconds +_QWEN3_ASR_TEXT_TAG = "<asr_text>" + + +def _detect_model_family(model_config) -> str: + archs = getattr(getattr(model_config, "hf_config", None), "architectures", []) or [] + if "Qwen3ASRForConditionalGeneration" in archs: + return "qwen3_asr" + return "whisper" + class OpenAIServingTranscription(OpenAIServingBase): """Handler for /v1/audio/transcriptions requests""" def __init__(self, tokenizer_manager: TokenizerManager): super().__init__(tokenizer_manager) + self._model_family = _detect_model_family(tokenizer_manager.model_config) def _request_id_prefix(self) -> str: return "trsc-" @@ -71,6 +81,27 @@ def _convert_to_internal_request( raw_request: Request = None, ) -> tuple[GenerateReqInput, TranscriptionRequest]: """Convert transcription request to internal format.""" + if self._model_family == "qwen3_asr": + prompt = ( + "<|im_start|>user\n" + "<|audio_start|><|audio_pad|><|audio_end|>" + "<|im_end|>\n" + "<|im_start|>assistant\n" + ) + sampling_params = { + "temperature": request.temperature, + "max_new_tokens": 1024, + } + adapted_request = GenerateReqInput( + text=prompt, + audio_data=request.audio_data, + sampling_params=sampling_params, 
+ stream=request.stream, + modalities=["audio"], + routing_key=self.extract_routing_key(raw_request), + ) + return adapted_request, request + # Build sampling params - include language for WhisperProcessor sampling_params = { "temperature": request.temperature, @@ -232,6 +263,8 @@ async def _handle_non_streaming_request( return self.create_error_response(str(e)) text = ret.get("text", "") + if self._model_family == "qwen3_asr": + text = _postprocess_qwen3_asr(text) usage = TranscriptionUsage(seconds=int(math.ceil(request.audio_duration_s))) # Build response based on format @@ -239,15 +272,22 @@ async def _handle_non_streaming_request( return Response(content=text, media_type="text/plain") if request.response_format == "verbose_json": - output_ids = ret.get("output_ids", []) - tokenizer = self.tokenizer_manager.tokenizer - parsed_text, segments = self._parse_segments(output_ids, tokenizer) - + if self._model_family == "whisper": + output_ids = ret.get("output_ids", []) + tokenizer = self.tokenizer_manager.tokenizer + parsed_text, segments = self._parse_segments(output_ids, tokenizer) + return TranscriptionVerboseResponse( + language=request.language or "en", + duration=round(request.audio_duration_s, 2), + text=parsed_text or text, + segments=segments, + usage=usage, + ) return TranscriptionVerboseResponse( - language=request.language or "en", + language=request.language, duration=round(request.audio_duration_s, 2), - text=parsed_text or text, - segments=segments, + text=text, + segments=[], usage=usage, ) @@ -324,3 +364,13 @@ async def _generate_transcription_stream( yield f"data: {error}\n\n" yield "data: [DONE]\n\n" + + +# TODO (adityavaid): refactor model-specific postprocessing into a plugin/adapter mechanism. 
+def _postprocess_qwen3_asr(text: str) -> str: + if not text: + return "" + if _QWEN3_ASR_TEXT_TAG in text: + _, text_part = text.rsplit(_QWEN3_ASR_TEXT_TAG, 1) + return text_part.strip() + return text.strip() diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py index ca066e196d0f..84196d8cb057 100644 --- a/python/sglang/srt/function_call/function_call_parser.py +++ b/python/sglang/srt/function_call/function_call_parser.py @@ -14,6 +14,7 @@ from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector from sglang.srt.function_call.deepseekv31_detector import DeepSeekV31Detector from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector +from sglang.srt.function_call.gemma4_detector import Gemma4Detector from sglang.srt.function_call.gigachat3_detector import GigaChat3Detector from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector @@ -69,6 +70,7 @@ class FunctionCallParser: "interns1": InternlmDetector, "hermes": HermesDetector, "gigachat3": GigaChat3Detector, + "gemma4": Gemma4Detector, } def __init__(self, tools: List[Tool], tool_call_parser: str): diff --git a/python/sglang/srt/function_call/gemma4_detector.py b/python/sglang/srt/function_call/gemma4_detector.py new file mode 100644 index 000000000000..2b4b9e05a16b --- /dev/null +++ b/python/sglang/srt/function_call/gemma4_detector.py @@ -0,0 +1,445 @@ +import json +import logging +from typing import List, Optional + +from sglang.srt.entrypoints.openai.protocol import Tool +from sglang.srt.function_call.base_format_detector import BaseFormatDetector +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + ToolCallItem, + _GetInfoFunc, +) + +logger = logging.getLogger(__name__) + +# Gemma4 special tokens for tool calls +TOOL_CALL_START = "<|tool_call>" +TOOL_CALL_END = "<tool_call|>" +STRING_DELIM = '<|"|>' + + 
+def _parse_gemma4_value(value_str: str) -> object: + """Parse a single Gemma4 value (after key:) into a Python object.""" + value_str = value_str.strip() + if not value_str: + return value_str + + # Boolean + if value_str == "true": + return True + if value_str == "false": + return False + + # Number (int or float) + try: + if "." in value_str: + return float(value_str) + return int(value_str) + except ValueError: + pass + + # Bare string (no <|"|> delimiters) + return value_str + + +def _parse_gemma4_array(arr_str: str) -> list: + """Parse a Gemma4 array content string into a Python list.""" + items: list = [] + i = 0 + n = len(arr_str) + + while i < n: + while i < n and arr_str[i] in (" ", ",", "\n", "\t"): + i += 1 + if i >= n: + break + + # String element + if arr_str[i : i + len(STRING_DELIM)] == STRING_DELIM: + i += len(STRING_DELIM) + end_pos = arr_str.find(STRING_DELIM, i) + if end_pos == -1: + items.append(arr_str[i:]) + break + items.append(arr_str[i:end_pos]) + i = end_pos + len(STRING_DELIM) + + # Nested object + elif arr_str[i] == "{": + depth = 1 + obj_start = i + 1 + i += 1 + while i < n and depth > 0: + if arr_str[i : i + len(STRING_DELIM)] == STRING_DELIM: + i += len(STRING_DELIM) + next_delim = arr_str.find(STRING_DELIM, i) + i = next_delim + len(STRING_DELIM) if next_delim != -1 else n + continue + if arr_str[i] == "{": + depth += 1 + elif arr_str[i] == "}": + depth -= 1 + i += 1 + items.append(_parse_gemma4_args(arr_str[obj_start : i - 1])) + + # Nested array + elif arr_str[i] == "[": + depth = 1 + sub_start = i + 1 + i += 1 + while i < n and depth > 0: + if arr_str[i] == "[": + depth += 1 + elif arr_str[i] == "]": + depth -= 1 + i += 1 + items.append(_parse_gemma4_array(arr_str[sub_start : i - 1])) + + # Bare value + else: + val_start = i + while i < n and arr_str[i] not in (",", "]"): + i += 1 + items.append(_parse_gemma4_value(arr_str[val_start:i])) + + return items + + +def _parse_gemma4_args(args_str: str) -> dict: + """Parse Gemma4's 
custom key:value format into a Python dict.""" + if not args_str or not args_str.strip(): + return {} + + result: dict = {} + i = 0 + n = len(args_str) + + while i < n: + # Skip whitespace and commas + while i < n and args_str[i] in (" ", ",", "\n", "\t"): + i += 1 + if i >= n: + break + + # Parse key (unquoted, ends at ':') + key_start = i + while i < n and args_str[i] != ":": + i += 1 + if i >= n: + break + key = args_str[key_start:i].strip() + i += 1 # skip ':' + + # Parse value + if i >= n: + result[key] = "" + break + + # Skip whitespace after ':' + while i < n and args_str[i] in (" ", "\n", "\t"): + i += 1 + if i >= n: + result[key] = "" + break + + # String value: <|"|>...<|"|> + if args_str[i : i + len(STRING_DELIM)] == STRING_DELIM: + i += len(STRING_DELIM) + val_start = i + end_pos = args_str.find(STRING_DELIM, i) + if end_pos == -1: + # Unterminated string — take rest + result[key] = args_str[val_start:] + break + result[key] = args_str[val_start:end_pos] + i = end_pos + len(STRING_DELIM) + + # Nested object: {...} + elif args_str[i] == "{": + depth = 1 + obj_start = i + 1 + i += 1 + while i < n and depth > 0: + if args_str[i : i + len(STRING_DELIM)] == STRING_DELIM: + # Skip over string contents + i += len(STRING_DELIM) + next_delim = args_str.find(STRING_DELIM, i) + if next_delim == -1: + i = n + else: + i = next_delim + len(STRING_DELIM) + continue + if args_str[i] == "{": + depth += 1 + elif args_str[i] == "}": + depth -= 1 + i += 1 + result[key] = _parse_gemma4_args(args_str[obj_start : i - 1]) + + # Array: [...] 
+ elif args_str[i] == "[": + depth = 1 + arr_start = i + 1 + i += 1 + while i < n and depth > 0: + if args_str[i : i + len(STRING_DELIM)] == STRING_DELIM: + i += len(STRING_DELIM) + next_delim = args_str.find(STRING_DELIM, i) + if next_delim == -1: + i = n + else: + i = next_delim + len(STRING_DELIM) + continue + if args_str[i] == "[": + depth += 1 + elif args_str[i] == "]": + depth -= 1 + i += 1 + arr_content = args_str[arr_start : i - 1] + result[key] = _parse_gemma4_array(arr_content) + + # Bare value (number, boolean, etc.) + else: + val_start = i + while i < n and args_str[i] not in (",", "}", "]"): + i += 1 + result[key] = _parse_gemma4_value(args_str[val_start:i]) + + return result + + +def _find_matching_brace(text: str) -> int: + """Find index of matching '}' in text, respecting STRING_DELIM and nesting. + + Assumes text starts just after the opening '{'. + Returns index of closing brace, or -1 if not found (incomplete). + """ + depth = 1 + i = 0 + n = len(text) + delim_len = len(STRING_DELIM) + while i < n and depth > 0: + if text[i : i + delim_len] == STRING_DELIM: + i += delim_len + next_delim = text.find(STRING_DELIM, i) + if next_delim == -1: + return -1 + i = next_delim + delim_len + continue + if text[i] == "{": + depth += 1 + elif text[i] == "}": + depth -= 1 + i += 1 + return (i - 1) if depth == 0 else -1 + + +class Gemma4Detector(BaseFormatDetector): + def __init__(self): + super().__init__() + self.tool_call_start_token = TOOL_CALL_START + self.tool_call_end_token = TOOL_CALL_END + + # Streaming state + self.parsed_pos: int = 0 + self.is_inside_tool_call: bool = False + self.current_func_name: Optional[str] = None + self._tool_indices: Optional[dict] = None + + @staticmethod + def _extract_tool_calls(text: str) -> list: + """Extract (func_name, args_str) pairs using brace-balanced parsing.""" + results = [] + search_from = 0 + while True: + start = text.find(TOOL_CALL_START, search_from) + if start == -1: + break + end = text.find(TOOL_CALL_END, 
start) + if end == -1: + break + inner = text[start + len(TOOL_CALL_START) : end] + if inner.startswith("call:"): + brace = inner.find("{") + if brace != -1: + func_name = inner[5:brace] + args_content = inner[brace + 1 :] + match_idx = _find_matching_brace(args_content) + args_str = ( + args_content[:match_idx] if match_idx != -1 else args_content + ) + results.append((func_name, args_str)) + search_from = end + len(TOOL_CALL_END) + return results + + def has_tool_call(self, text: str) -> bool: + return self.tool_call_start_token in text + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + if self.tool_call_start_token not in text: + return StreamingParseResult(normal_text=text) + + calls = [] + try: + matches = self._extract_tool_calls(text) + if not matches: + return StreamingParseResult(normal_text=text) + + tool_indices = self._get_tool_indices(tools) + for func_name, args_str in matches: + arguments = _parse_gemma4_args(args_str) + calls.append( + ToolCallItem( + tool_index=tool_indices.get(func_name, -1), + name=func_name, + parameters=json.dumps(arguments, ensure_ascii=False), + ) + ) + + # Content = text before first tool call + content_end = text.find(self.tool_call_start_token) + normal_text = text[:content_end] if content_end > 0 else "" + + return StreamingParseResult(normal_text=normal_text, calls=calls) + + except (ValueError, IndexError, TypeError, KeyError) as e: + logger.error(f"Error in detect_and_parse: {e}", exc_info=True) + return StreamingParseResult(normal_text=text) + + def parse_streaming_increment( + self, new_text: str, tools: List[Tool] + ) -> StreamingParseResult: + self._buffer += new_text + + if not self._buffer: + return StreamingParseResult() + + calls = [] + normal_text_chunks = [] + if self._tool_indices is None: + self._tool_indices = self._get_tool_indices(tools) + + try: + while True: + current_slice = self._buffer[self.parsed_pos :] + if not current_slice: + break + + if not 
self.is_inside_tool_call: + # Outside tool call block + next_start = current_slice.find(self.tool_call_start_token) + if next_start == -1: + # Check for partial match at the end + partial_len = self._ends_with_partial_token( + current_slice, self.tool_call_start_token + ) + if partial_len > 0: + text_to_append = current_slice[:-partial_len] + if text_to_append: + normal_text_chunks.append(text_to_append) + self.parsed_pos += len(text_to_append) + break + else: + normal_text_chunks.append(current_slice) + self.parsed_pos += len(current_slice) + continue + elif next_start == 0: + self.parsed_pos += len(self.tool_call_start_token) + self.is_inside_tool_call = True + continue + else: + normal_text_chunks.append(current_slice[:next_start]) + self.parsed_pos += next_start + continue + else: + # Inside tool call block + + # Check for TOOL_CALL_END first + if current_slice.startswith(self.tool_call_end_token): + self.parsed_pos += len(self.tool_call_end_token) + self.is_inside_tool_call = False + self.current_func_name = None + continue + + if not self.current_func_name: + # Skip leading whitespace + if current_slice[0] in (" ", "\n", "\t"): + self.parsed_pos += 1 + continue + + if current_slice.startswith("call:"): + brace_pos = current_slice.find("{") + if brace_pos != -1: + func_name = current_slice[5:brace_pos] + self.current_tool_id += 1 + self.current_func_name = func_name + self.current_tool_name_sent = True + + calls.append( + ToolCallItem( + tool_index=self._tool_indices.get( + func_name, -1 + ), + name=func_name, + parameters="", + ) + ) + self.parsed_pos += brace_pos + 1 + continue + else: + # Incomplete call:name{ + break + else: + # Check for partial matches + if "call:".startswith( + current_slice + ) or self.tool_call_end_token.startswith(current_slice): + break + + # Unexpected content, skip + self.parsed_pos += 1 + continue + else: + # Parsing arguments (looking for balancing }) + match_idx = _find_matching_brace(current_slice) + if match_idx != -1: + 
args_str = current_slice[:match_idx] + arguments = _parse_gemma4_args(args_str) + + calls.append( + ToolCallItem( + tool_index=self._tool_indices.get( + self.current_func_name, -1 + ), + parameters=json.dumps( + arguments, ensure_ascii=False + ), + ) + ) + self.parsed_pos += match_idx + 1 + self.current_func_name = None + continue + else: + # Incomplete arguments block + break + + except (ValueError, IndexError, TypeError, KeyError) as e: + logger.error(f"Error in parse_streaming_increment: {e}", exc_info=True) + # Reset parser state to prevent corruption + self.is_inside_tool_call = False + self.current_func_name = None + self._buffer = "" + self.parsed_pos = 0 + + if self.parsed_pos > 0: + self._buffer = self._buffer[self.parsed_pos :] + self.parsed_pos = 0 + + normal_text = "".join(normal_text_chunks) if normal_text_chunks else "" + return StreamingParseResult(calls=calls, normal_text=normal_text) + + def supports_structural_tag(self) -> bool: + return False + + def structure_info(self) -> _GetInfoFunc: + raise NotImplementedError diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py index 2353c15993fd..0a5920575c0f 100644 --- a/python/sglang/srt/layers/attention/attention_registry.py +++ b/python/sglang/srt/layers/attention/attention_registry.py @@ -1,6 +1,11 @@ import logging from typing import TYPE_CHECKING +from sglang.srt.configs.linear_attn_model_registry import ( + get_linear_attn_config, + import_backend_class, +) + logger = logging.getLogger(__name__) @@ -225,9 +230,17 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac elif runner.hybrid_lightning_config is not None: linear_attn_backend = LightningAttentionBackend(runner) else: - raise ValueError( - "Expected hybrid GDN or NemotronH models, but got unknown model." 
- ) + spec_result = get_linear_attn_config(runner.model_config.hf_config) + if spec_result is not None: + spec, _ = spec_result + BackendClass = import_backend_class(spec.backend_class_name) + linear_attn_backend = BackendClass(runner) + else: + raise ValueError( + "Expected hybrid GDN or NemotronH models, but got unknown model. " + "If this is a custom hybrid model, use register_linear_attn_model() " + "from sglang.srt.configs.linear_attn_model_registry." + ) full_attn_layers = cfg.full_attention_layer_ids return HybridLinearAttnBackend( full_attn_backend, linear_attn_backend, full_attn_layers diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 18ed55572cfe..c1f1f48fb789 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -12,6 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.radix_attention import AttentionType +from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices from sglang.srt.utils import ( @@ -51,6 +52,8 @@ class ForwardMetadata: window_kv_indices: torch.Tensor window_num_kv_splits: torch.Tensor window_kv_offsets: torch.Tensor + # Separate attn_logits for SWA layers when v_head_dim differs + swa_attn_logits: Optional[torch.Tensor] = None class TritonAttnBackend(AttentionBackend): @@ -94,16 +97,30 @@ def __init__( self.num_kv_head = model_runner.model_config.get_num_kv_heads( get_attention_tp_size() ) - if ( + # The decode triton kernel derives attn_lse offsets from attn_logits + # strides via integer division by v_head_dim (the "// Lv" trick in + # _fwd_kernel_stage1/stage2), so attn_logits.shape[-1] must exactly + # 
match the layer's v_head_dim. For hybrid SWA models where SWA and + # full-attention layers use different v_head_dim (e.g. Gemma 4: + # swa=256, full=512), we allocate a second buffer for SWA layers. + full_v_head_dim = model_runner.model_config.v_head_dim + swa_v_head_dim = model_runner.model_config.swa_v_head_dim + if self.sliding_window_size is not None and swa_v_head_dim != full_v_head_dim: + self.v_head_dim = full_v_head_dim + self.swa_v_head_dim = swa_v_head_dim + elif ( model_runner.hybrid_gdn_config is not None or model_runner.kimi_linear_config is not None + or model_runner.linear_attn_model_spec is not None ): # For hybrid linear models, layer_id = 0 may not be full attention self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim() + self.swa_v_head_dim = None else: self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[ -1 ] + self.swa_v_head_dim = None self.max_context_len = model_runner.model_config.context_len self.device = model_runner.device self.device_core_count = get_device_core_count(model_runner.gpu_id) @@ -242,6 +259,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): window_kv_indices = None window_num_kv_splits = None window_kv_offsets = None + swa_attn_logits = None spec_info = forward_batch.spec_info if forward_batch.forward_mode.is_decode_or_idle(): @@ -290,6 +308,14 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): dtype=torch.float32, device=self.device, ) + if self.swa_v_head_dim is not None: + swa_attn_logits = torch.empty( + (bs, self.num_head, self.max_kv_splits, self.swa_v_head_dim), + dtype=torch.float32, + device=self.device, + ) + else: + swa_attn_logits = None attn_lse = torch.empty( (bs, self.num_head, self.max_kv_splits), dtype=torch.float32, @@ -436,6 +462,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): window_kv_indices, window_num_kv_splits, window_kv_offsets, + swa_attn_logits=swa_attn_logits, ) def init_cuda_graph_state( @@ -450,6 +477,19 @@ def 
init_cuda_graph_state( dtype=torch.float32, device=self.device, ) + if self.swa_v_head_dim is not None: + self.cuda_graph_swa_attn_logits = torch.zeros( + ( + max_num_tokens, + self.num_head, + self.max_kv_splits, + self.swa_v_head_dim, + ), + dtype=torch.float32, + device=self.device, + ) + else: + self.cuda_graph_swa_attn_logits = None self.cuda_graph_attn_lse = torch.zeros( (max_num_tokens, self.num_head, self.max_kv_splits), dtype=torch.float32, @@ -520,6 +560,7 @@ def init_forward_metadata_capture_cuda_graph( window_kv_indices = None window_num_kv_splits = None window_kv_offsets = None + swa_attn_logits = None if forward_mode.is_decode_or_idle(): if spec_info is None: @@ -558,6 +599,7 @@ def init_forward_metadata_capture_cuda_graph( kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices attn_logits = self.cuda_graph_attn_logits + swa_attn_logits = self.cuda_graph_swa_attn_logits attn_lse = self.cuda_graph_attn_lse max_extend_len = None num_kv_splits = self.cuda_graph_num_kv_splits @@ -659,6 +701,7 @@ def init_forward_metadata_capture_cuda_graph( window_kv_indices, window_num_kv_splits, window_kv_offsets, + swa_attn_logits=swa_attn_logits, ) def init_forward_metadata_replay_cuda_graph( @@ -819,26 +862,37 @@ def forward_extend( else: o = torch.empty_like(q) - # Save KV cache first (must do this before unified kernel) - if save_kv_cache: - if ( - self.use_mla or layer.k_scale is None - ): # Triton MLA currently doesn't support quantized kv cache - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, - forward_batch.out_cache_loc, - k, - v, - ) - else: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, - forward_batch.out_cache_loc, - k.clone(), # cloned to protect k,v from in-place mutation in set_kv_buffer - v.clone(), - layer.k_scale, - layer.v_scale, - ) + if k is None and v is None: + pool = forward_batch.token_to_kv_pool + cache_loc = forward_batch.out_cache_loc + if isinstance(pool, SWAKVPool) and pool.layers_mapping[layer.layer_id][1]: + 
cache_loc = pool.translate_loc_from_full_to_swa(cache_loc) + k_buffer, v_buffer = pool.get_kv_buffer(layer.layer_id) + k = k_buffer[cache_loc] + v = v_buffer[cache_loc] + elif k is None or v is None: + raise ValueError("Both k and v should be None or not None") + else: + # Save KV cache first (must do this before unified kernel) + if save_kv_cache: + if ( + self.use_mla or layer.k_scale is None + ): # Triton MLA currently doesn't support quantized kv cache + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, + forward_batch.out_cache_loc, + k, + v, + ) + else: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, + forward_batch.out_cache_loc, + k.clone(), # cloned to protect k,v from in-place mutation in set_kv_buffer + v.clone(), + layer.k_scale, + layer.v_scale, + ) logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap) @@ -1089,6 +1143,16 @@ def forward_decode( k_descale = 1.0 v_descale = 1.0 + # Select the correctly-sized attn_logits buffer for this layer. + # The triton kernel's // Lv stride trick requires attn_logits.shape[-1] + # to exactly match the layer's v_head_dim. 
+ attn_logits = self.forward_metadata.attn_logits + if ( + self.forward_metadata.swa_attn_logits is not None + and layer.v_head_dim == self.swa_v_head_dim + ): + attn_logits = self.forward_metadata.swa_attn_logits + self.decode_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), @@ -1096,7 +1160,7 @@ def forward_decode( o.view(-1, layer.tp_q_head_num, layer.v_head_dim), kv_indptr, kv_indices, - self.forward_metadata.attn_logits, + attn_logits, self.forward_metadata.attn_lse, self.forward_metadata.num_kv_splits, self.max_kv_splits, diff --git a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py index ac0fc72af140..a50b89787f2a 100644 --- a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py @@ -168,13 +168,14 @@ def _fwd_kernel( def context_attention_fwd( - q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True + q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True, sm_scale=None ): """ q, k, v: [b * s, head, head_dim] b_start_loc: [b] b_seq_len: [b] out: [b * s, head, head_dim] + sm_scale: softmax scale, defaults to 1/sqrt(head_dim) """ if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8: BLOCK = 128 @@ -183,7 +184,8 @@ def context_attention_fwd( Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - sm_scale = 1.0 / (Lq**0.5) + if sm_scale is None: + sm_scale = 1.0 / (Lq**0.5) batch, head = b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 3fd45aac0101..23dba24584e9 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -167,6 +167,7 @@ def __init__( dropout: float = 0.0, flatten_batch: bool = False, 
softmax_in_single_precision: bool = False, + softmax_scale: float | None = None, **kwargs, ): super().__init__() @@ -176,7 +177,11 @@ def __init__( self.flatten_batch = flatten_batch self.softmax_in_single_precision = softmax_in_single_precision self.dropout = dropout - self.scale = 1.0 / math.sqrt(self.head_size) + self.scale = ( + softmax_scale + if softmax_scale is not None + else 1.0 / math.sqrt(self.head_size) + ) @staticmethod @lru_cache(maxsize=128) @@ -242,6 +247,7 @@ def forward( bsz: int, cu_seqlens: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" @@ -298,6 +304,7 @@ def forward( attn_mask=attention_mask, dropout_p=self.dropout, is_causal=False, + scale=self.scale, ) # [b, h, s, head_size] --> [b * s, h, head_size] @@ -329,11 +336,13 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" Args: cu_seqlens: [b] + softmax_scale: override softmax scale (default 1/sqrt(head_dim)) Returns: [b * s, h, head_size] """ @@ -354,6 +363,7 @@ def forward( cu_seqlens[1], cu_seqlens[2], is_causal=False, + sm_scale=softmax_scale, ) else: cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device) @@ -372,6 +382,7 @@ def forward( seq_lens.to(q.device), max_seqlen, is_causal=False, + sm_scale=softmax_scale, ) return output @@ -398,6 +409,7 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" @@ -416,6 +428,7 @@ def forward( cu_seqlens_k=cu_seqlens[0], max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, ) else: cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device) @@ -431,6 +444,7 @@ def forward( cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, + 
softmax_scale=softmax_scale, ) return output @@ -453,6 +467,7 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" @@ -482,6 +497,7 @@ def forward( cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, ver=4, ) @@ -508,6 +524,7 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" @@ -583,7 +600,7 @@ def forward( raise RuntimeError("offset + len out of bounds; packed indptr is wrong") _, _, head_size = q.shape - scale = head_size**-0.5 + scale = softmax_scale if softmax_scale is not None else head_size**-0.5 output, _ = cudnn_batch_prefill_with_kv_cache( q, @@ -635,6 +652,7 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device=q.device) @@ -651,6 +669,7 @@ def forward( cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, ) @@ -672,6 +691,7 @@ def forward( cu_seqlens: torch.Tensor | SingletonCache | None, bsz: int, seq_len: int, + softmax_scale: Optional[float] = None, **kwargs, ) -> torch.Tensor: r""" @@ -684,7 +704,6 @@ def forward( if "output_ws" not in kwargs: raise RuntimeError("output_ws should be prepared for npu-graph mode") output = kwargs["output_ws"] - # graph mode: runner already passes seq_lens (int32 on CPU) seq_len_arg = cu_seqlens else: cu_seqlens = resolve_seqlens(cu_seqlens, bsz, seq_len, device="cpu") @@ -697,12 +716,14 @@ def forward( _, num_heads, head_size = q.shape num_kv_heads = k.shape[1] + scale_value = softmax_scale if softmax_scale is not None else head_size**-0.5 + torch_npu._npu_flash_attention_unpad( query=q, key=k, value=v, seq_len=seq_len_arg, - 
scale_value=head_size**-0.5, + scale_value=scale_value, num_heads=num_heads, num_kv_heads=num_kv_heads, out=output, @@ -744,6 +765,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, dropout: float = 0.0, softmax_in_single_precision: bool = False, + softmax_scale: Optional[float] = None, flatten_batch: bool = False, prefix: str = "", proj_bias: bool = True, @@ -808,6 +830,7 @@ def __init__( self.customized_position_embedding_applier = ( customized_position_embedding_applier ) + self.softmax_scale = softmax_scale self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend]( head_dim=self.head_size, num_heads=self.num_attention_heads_per_partition, @@ -815,6 +838,7 @@ def __init__( dropout=dropout, flatten_batch=flatten_batch, softmax_in_single_precision=softmax_in_single_precision, + softmax_scale=softmax_scale, use_data_parallel=use_data_parallel, workspace_buffer=workspace_buffer, ) @@ -1116,6 +1140,7 @@ def forward( sequence_lengths=sequence_lengths, max_seqlen=max_seqlen, output_ws=attn_output_ws, + softmax_scale=self.softmax_scale, ) assert output.dim() == 3, output.shape diff --git a/python/sglang/srt/layers/clippable_linear.py b/python/sglang/srt/layers/clippable_linear.py new file mode 100644 index 000000000000..a253bb42197a --- /dev/null +++ b/python/sglang/srt/layers/clippable_linear.py @@ -0,0 +1,283 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""TP-sharded linear wrappers with per-tensor activation clamping. + +Used by the Gemma 4 vision and audio encoders. Each wrapper owns a parallel +linear and four scalar clip buffers (``input_min/max``, ``output_min/max``) +that default to ±inf (no-op) and are populated from the checkpoint. + +For fused projections (QKV, GateUp), input bounds are shared (the checkpoint +stores identical copies per projection — last write wins during loading) and +output bounds are per-projection. +""" + +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.utils import add_prefix + +_INF = float("inf") + + +class ClippableRowParallelLinear(nn.Module): + """``RowParallelLinear`` with input/output activation clamping. + + Checkpoint weight at ``.weight`` is remapped to ``.linear.weight`` + by the model's ``load_weights``. 
+ """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.linear = RowParallelLinear( + input_size=input_size, + output_size=output_size, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("linear", prefix), + ) + self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.clamp(x, self.input_min, self.input_max) + x, _ = self.linear(x) + x = torch.clamp(x, self.output_min, self.output_max) + return x + + +class ClippableColumnParallelLinear(nn.Module): + """``ColumnParallelLinear`` with input/output activation clamping.""" + + def __init__( + self, + input_size: int, + output_size: int, + *, + bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.linear = ColumnParallelLinear( + input_size=input_size, + output_size=output_size, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("linear", prefix), + ) + self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.clamp(x, self.input_min, self.input_max) + x, _ = self.linear(x) + x = torch.clamp(x, self.output_min, self.output_max) + return x + + +class ClippableQKVParallelLinear(nn.Module): + """Fused QKV projection with per-projection activation clamping. 
+ + Owns a single ``QKVParallelLinear`` for the fused matmul. Clip bounds + are stored as flat buffers: shared ``input_min/max`` (applied before the + matmul) and per-projection ``q/k/v_output_min/max`` (applied after split). + """ + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: int, + *, + bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + tp_size = get_attention_tp_size() + self.q_size = (total_num_heads // tp_size) * head_size + self.kv_size = (total_num_kv_heads // tp_size) * head_size + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=head_size, + total_num_heads=total_num_heads, + total_num_kv_heads=total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.q_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.q_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.k_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.k_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.v_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.v_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + + def forward( + self, hidden_states: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + x = torch.clamp(hidden_states, self.input_min, self.input_max) + qkv, _ = self.qkv_proj(x) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = torch.clamp(q, self.q_output_min, self.q_output_max) + k = torch.clamp(k, self.k_output_min, self.k_output_max) + v = torch.clamp(v, self.v_output_min, self.v_output_max) + return q, k, v + + +class 
ClippableGLUParallelLinear(nn.Module): + """Fused linear + GLU gating with correct TP sharding. + + Used by the audio encoder's ``LightConv1d``, where a single linear + projects to ``[hidden * 2]`` and GLU splits into value/gate halves. + A plain ``ColumnParallelLinear`` is *incorrect* here under TP because it + shards the output contiguously, mixing value and gate across ranks. + This wrapper uses ``MergedColumnParallelLinear`` to shard each half + independently, then applies GLU (``value * sigmoid(gate)``) on each + rank's correctly-paired shard. + + Output clamping is applied once *after* the GLU gate, using a single + ``output_min/max`` pair (matching the checkpoint layout). + + The checkpoint stores a single fused ``[hidden * 2, input]`` weight. + A custom ``weight_loader`` on the inner param automatically splits it + into value (first half) and gate (second half) shards, so no special + handling is needed in the model's ``load_weights``. + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + *, + bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + tp_size = get_attention_tp_size() + self.proj_size = hidden_size // tp_size + + self.linear = MergedColumnParallelLinear( + input_size=input_size, + output_sizes=[hidden_size, hidden_size], + bias=bias, + quant_config=quant_config, + prefix=add_prefix("linear", prefix), + ) + + # The checkpoint has a single fused weight; MergedColumnParallelLinear + # expects per-shard loading. Wrap the original weight_loader so that + # a call *without* shard_id (the generic load_weights path) splits + # automatically. 
+ orig_loader = self.linear.weight.weight_loader + + def _fused_weight_loader(param, loaded_weight, loaded_shard_id=None): + if loaded_shard_id is not None: + return orig_loader(param, loaded_weight, loaded_shard_id) + half = loaded_weight.shape[0] // 2 + orig_loader(param, loaded_weight[:half], 0) + orig_loader(param, loaded_weight[half:], 1) + + self.linear.weight.weight_loader = _fused_weight_loader + + self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.clamp(x, self.input_min, self.input_max) + merged, _ = self.linear(x) + value, gate = merged.split([self.proj_size, self.proj_size], dim=-1) + x = value * torch.sigmoid(gate) + x = torch.clamp(x, self.output_min, self.output_max) + return x + + +class ClippableGateUpParallelLinear(nn.Module): + """Fused gate/up projection with per-projection activation clamping. + + Used by the MLP layers in the vision/audio encoders. Owns a single + ``MergedColumnParallelLinear`` for the fused matmul and returns the + two projections separately so the caller can apply its own activation + (e.g. ``SiLU(gate) * up``). + + Output clamping is applied *per-projection before* the caller's + activation, using separate ``gate_output_min/max`` and + ``up_output_min/max`` bounds. 
+ """ + + def __init__( + self, + input_size: int, + intermediate_size: int, + *, + bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + tp_size = get_attention_tp_size() + self.proj_size = intermediate_size // tp_size + + self.gate_up_proj = MergedColumnParallelLinear( + input_size=input_size, + output_sizes=[intermediate_size, intermediate_size], + bias=bias, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + ) + self.input_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.input_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.gate_output_min = nn.parameter.Buffer( + torch.tensor(-_INF), persistent=False + ) + self.gate_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + self.up_output_min = nn.parameter.Buffer(torch.tensor(-_INF), persistent=False) + self.up_output_max = nn.parameter.Buffer(torch.tensor(_INF), persistent=False) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x = torch.clamp(x, self.input_min, self.input_max) + gate_up, _ = self.gate_up_proj(x) + gate, up = gate_up.split([self.proj_size, self.proj_size], dim=-1) + gate = torch.clamp(gate, self.gate_output_min, self.gate_output_max) + up = torch.clamp(up, self.up_output_min, self.up_output_max) + return gate, up diff --git a/python/sglang/srt/layers/gemma4_fused_ops.py b/python/sglang/srt/layers/gemma4_fused_ops.py new file mode 100644 index 000000000000..5f227db82853 --- /dev/null +++ b/python/sglang/srt/layers/gemma4_fused_ops.py @@ -0,0 +1,79 @@ +"""Fused triton kernels for Gemma4 decoder layer operations. + +Fuses standard RMSNorm + residual-add (+ optional scalar multiply) into +a single kernel pass to reduce kernel launch overhead. 
+""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _gemma_rmsnorm_residual_kernel( + X_ptr, + W_ptr, + Residual_ptr, + Scalar_ptr, + Out_ptr, + stride_x, + stride_r, + stride_o, + N, + eps, + HAS_SCALAR: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Fused kernel: out = rmsnorm(x, w) + residual [* scalar] + + When HAS_SCALAR is True, also multiplies by a scalar loaded from Scalar_ptr. + """ + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_SIZE) + mask = cols < N + + x = tl.load(X_ptr + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + w = tl.load(W_ptr + cols, mask=mask, other=0.0).to(tl.float32) + r = tl.load(Residual_ptr + row * stride_r + cols, mask=mask, other=0.0).to( + tl.float32 + ) + + var = tl.sum(x * x, axis=0) / N + rrms = tl.rsqrt(var + eps) + out = x * rrms * w + r + + if HAS_SCALAR: + scalar = tl.load(Scalar_ptr).to(tl.float32) + out = out * scalar + + tl.store(Out_ptr + row * stride_o + cols, out.to(x.dtype), mask=mask) + + +def gemma_rmsnorm_residual_scalar( + x: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + scalar: torch.Tensor, + eps: float = 1e-6, +) -> torch.Tensor: + """Fused (rmsnorm(x) + residual) * scalar.""" + assert x.dim() == 2 and x.stride(-1) == 1, "Expected contiguous 2D input" + M, N = x.shape + BLOCK_SIZE = triton.next_power_of_2(N) + out = torch.empty_like(x) + + _gemma_rmsnorm_residual_kernel[(M,)]( + x, + weight, + residual, + scalar, + out, + x.stride(0), + residual.stride(0), + out.stride(0), + N, + eps, + HAS_SCALAR=True, + BLOCK_SIZE=BLOCK_SIZE, + ) + return out diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index e4960bdb42d6..0db6675e648f 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -182,6 +182,11 @@ def forward_cuda( ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if x.numel() == 0: return x + # sgl_kernel rmsnorm requires 2D input; reshape 
higher-rank tensors + needs_reshape = x.dim() != 2 and residual is None + if needs_reshape: + original_shape = x.shape + x = x.contiguous().reshape(-1, original_shape[-1]) if self.variance_size_override is not None: return self.forward_native(x, residual, post_residual_addition) if is_batch_invariant_mode_enabled(): @@ -205,6 +210,8 @@ def forward_cuda( fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) return x, residual out = rmsnorm(x, self.weight.data, self.variance_epsilon) + if needs_reshape: + out = out.reshape(original_shape) return out def forward_npu( @@ -458,6 +465,10 @@ def _forward_impl( residual: Optional[torch.Tensor] = None, post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + needs_reshape = x.dim() != 2 and residual is None + if needs_reshape: + original_shape = x.shape + x = x.contiguous().reshape(-1, original_shape[-1]) if residual is not None: if post_residual_addition is not None: residual = residual + post_residual_addition @@ -466,6 +477,8 @@ def _forward_impl( ) return x, residual out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon) + if needs_reshape: + out = out.reshape(original_shape) return out def forward_native( @@ -631,3 +644,88 @@ def forward_npu(self, x): def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class Gemma4RMSNorm(MultiPlatformOp): + def __init__( + self, + dim: int, + eps: float = 1e-6, + scale_shift: float = 0.0, + with_scale: bool = True, + ): + super().__init__() + self.with_scale = with_scale + + if self.with_scale: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.register_buffer("weight", torch.ones(dim), persistent=False) + + self.eps = eps + self.scale_shift = scale_shift + + def __repr__(self): + dim = self.weight.shape[0] + return ( + f"{self.__class__.__name__}(dim={dim}, eps={self.eps}, " + f"with_scale={self.with_scale}, scale_shift={self.scale_shift})" + ) + + def 
_norm(self, x): + mean_squared = x.pow(2).mean(-1, keepdim=True) + self.eps + return x * torch.pow(mean_squared, -0.5) + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + normed_output = self._norm(x.float()) + if self.with_scale: + normed_output = normed_output * (self.weight.float() + self.scale_shift) + return normed_output.type_as(x) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + if x.numel() == 0: + return x + needs_reshape = x.dim() != 2 + if needs_reshape: + original_shape = x.shape + x = x.contiguous().reshape(-1, original_shape[-1]) + if self.with_scale and self.scale_shift == 1.0: + # gemma_rmsnorm: norm(x) * (1 + weight) + out = gemma_rmsnorm(x, self.weight.data, self.eps) + else: + # rmsnorm: norm(x) * weight + # with_scale=False → weight is ones → norm(x) * 1 = norm(x) + # scale_shift=0.0 → standard RMSNorm without +1 shift + out = rmsnorm(x, self.weight.data, self.eps) + + if needs_reshape: + out = out.reshape(original_shape) + return out + + def forward_hip(self, x: torch.Tensor) -> torch.Tensor: + # sgl_kernel's gemma_rmsnorm is not available on ROCm; + # delegate to the pure-PyTorch implementation. 
+ return self.forward_native(x) + + +class RMSNormWithoutScale(MultiPlatformOp): + def __init__(self, hidden_size: int, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward_native(self, x): + orig_dtype = x.dtype + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + return x.to(orig_dtype) + + def forward_cuda(self, x): + return self.forward_native(x) + + def extra_repr(self): + return f"{self.hidden_size}, eps={self.eps}" diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_B200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_B200.json new file mode 100644 index 000000000000..f0eb57ab8dc0 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + 
"num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000000..60adcf03cea9 --- /dev/null +++ 
b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_B200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_B200.json new file mode 100644 index 000000000000..8ff7c371dab5 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000000..48b07c17d5b7 --- /dev/null +++ 
b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + 
"BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 892dcebaea81..bb6691814cdd 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -956,8 +956,9 @@ def _post_process_topk_ids( topk_ids=topk_ids, ) if _is_cuda: - topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) - _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) + topk_ids = _biased_grouped_topk_postprocess( + topk_ids, expert_location_dispatch_info, num_token_non_padded + ) if num_fused_shared_experts > 0 and _use_aiter: M, N = router_logits.shape diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 94f9a1375c14..5fce65159a59 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -1,7 +1,10 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING, List, Optional +logger = logging.getLogger(__name__) + import torch import torch.nn.functional as F from torch.nn.parameter import Parameter @@ -469,28 +472,37 @@ def forward_cuda( topk_weights = torch.ones_like( topk_weights, dtype=torch.float32 ) # topk_weights must be FP32 (float32) - output = fused_moe( - x, - 
layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=( - ActivationType.Silu - if moe_runner_config.activation == "silu" - else ActivationType.Gelu - ), - expert_mask=layer.expert_mask_gpu, - ) - return StandardCombineInput(hidden_states=output) - else: - quant_info = TritonMoeQuantInfo( - w13_weight=layer.w13_weight, - w2_weight=layer.w2_weight, - b13=getattr(layer, "w13_weight_bias", None), - b2=getattr(layer, "w2_weight_bias", None), - ) - return self.runner.run(dispatch_output, quant_info) + try: + output = fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=( + ActivationType.Silu + if moe_runner_config.activation == "silu" + else ActivationType.Gelu + ), + expert_mask=layer.expert_mask_gpu, + ) + return StandardCombineInput(hidden_states=output) + except RuntimeError as e: + # AITER CK fused_moe may not support all GEMM dimensions + # (e.g. Gemma4 MoE with 128 experts × 704 intermediate size). + # Fall through to Triton MoE runner below. + logger.warning_once( + f"AITER CK fused_moe failed ({e}), " + "falling back to Triton MoE runner." 
+ ) + + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, + b13=getattr(layer, "w13_weight_bias", None), + b2=getattr(layer, "w2_weight_bias", None), + ) + return self.runner.run(dispatch_output, quant_info) def forward_cpu( self, diff --git a/python/sglang/srt/layers/rotary_embedding/base.py b/python/sglang/srt/layers/rotary_embedding/base.py index 99a3f11ca05f..2ccfdddfc94d 100644 --- a/python/sglang/srt/layers/rotary_embedding/base.py +++ b/python/sglang/srt/layers/rotary_embedding/base.py @@ -106,6 +106,15 @@ def __init__( ) self.position_cos, self.position_sin = None, None + def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None: + # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) + # is expensive, so avoid calling it if possible + if ( + self.cos_sin_cache.device != query.device + or self.cos_sin_cache.dtype != query.dtype + ): + self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: """Compute the inverse frequency.""" # NOTE(woosuk): To exactly match the HF implementation, we need to diff --git a/python/sglang/srt/layers/rotary_embedding/factory.py b/python/sglang/srt/layers/rotary_embedding/factory.py index 27e28577c96e..d058ea08abb6 100644 --- a/python/sglang/srt/layers/rotary_embedding/factory.py +++ b/python/sglang/srt/layers/rotary_embedding/factory.py @@ -21,6 +21,7 @@ DynamicNTKAlphaRotaryEmbedding, DynamicNTKScalingRotaryEmbedding, FourierRotaryEmbedding, + Gemma4RotaryEmbedding, Llama3RotaryEmbedding, Phi3LongRoPEScaledRotaryEmbedding, ) @@ -326,6 +327,15 @@ def get_rope( long_factor, **extra_kwargs, ) + elif scaling_type == "proportional": + rotary_emb = Gemma4RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") _ROPE_DICT[key] = rotary_emb diff --git 
a/python/sglang/srt/layers/rotary_embedding/rope_variant.py b/python/sglang/srt/layers/rotary_embedding/rope_variant.py index 28aaae598bc8..2fe9d5da280d 100644 --- a/python/sglang/srt/layers/rotary_embedding/rope_variant.py +++ b/python/sglang/srt/layers/rotary_embedding/rope_variant.py @@ -866,3 +866,66 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: sin = freqs.sin() cache = torch.cat((cos, sin), dim=-1) return cache + + +class Gemma4RotaryEmbedding(RotaryEmbedding): + """Gemma4-specific RoPE with cross-mixing. + + Instead of rotating the first `rotary_dim` dimensions contiguously, + splits the head into two halves and applies rotation across both. + + For a head_dim of D and rotary_dim of R: + - Standard RoPE rotates: [0, R) + - Gemma4 RoPE rotates: [0, R/2) cross-mixed with [D/2, D/2 + R/2) + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + # Store angles before calling super().__init__ + # rotary_dim is already scaled by partial_rotary_factor in get_rope + # For Gemma4: head_size=512, partial_rotary_factor=0.25 -> rotary_dim=128 + self.rope_angles = rotary_dim // 2 # Number of rotation angles per half + self.nope_angles = (head_size // 2) - self.rope_angles # Non-rotated per half + + super().__init__( + head_size, + head_size, + max_position_embeddings, + base, + is_neox_style, + dtype, + ) + + def _compute_inv_freq(self, base: float) -> torch.Tensor: + """Compute frequencies only for the rotated dimensions. + + Non-rotated dims are padded with 0.0 to produce identity rotation. 
+ """ + freq_exponents = ( + torch.arange(0, 2 * self.rope_angles, 2, dtype=torch.float) / self.head_size + ) + inv_freq = 1.0 / (base**freq_exponents) + + # Zero-pad for non-rotated dims (identity rotation: cos=1, sin=0) + if self.nope_angles > 0: + inv_freq = torch.cat( + [ + inv_freq, + torch.zeros(self.nope_angles, dtype=torch.float), + ] + ) + return inv_freq + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", rope_angles={self.rope_angles}, nope_angles={self.nope_angles}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index bdf1010dd523..36c55826d821 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -725,9 +725,14 @@ def init_cache_with_memory_pool(self): # Hybrid memory pool self.is_hybrid_swa = self.tp_worker.is_hybrid_swa + _spec = self.tp_worker.model_runner.linear_attn_model_spec + _registry_needs_mamba = ( + _spec.uses_mamba_radix_cache if _spec is not None else False + ) self.is_hybrid_ssm = ( self.tp_worker.model_runner.hybrid_gdn_config is not None or self.tp_worker.model_runner.mamba2_config is not None + or _registry_needs_mamba ) self.sliding_window_size = None diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index 96b0e3844914..80f24e6dfa56 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -1,5 +1,4 @@ import logging -import weakref from typing import Dict, List, Optional, Tuple import torch @@ -306,7 +305,7 @@ def __init__( self.clear() self._kvcache = kvcache - self._kvcache.register_mapping(weakref.proxy(self.full_to_swa_index_mapping)) + self._kvcache.register_mapping(self.full_to_swa_index_mapping) def 
available_size(self): return min( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a59742b94354..669cab133c49 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -50,6 +50,7 @@ Qwen3NextConfig, ) from sglang.srt.configs.device_config import DeviceConfig +from sglang.srt.configs.linear_attn_model_registry import get_linear_attn_config from sglang.srt.configs.load_config import LoadConfig, LoadFormat from sglang.srt.configs.model_config import AttentionArch, ModelConfig, ModelImpl from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp @@ -1890,14 +1891,30 @@ def kimi_linear_config(self): return config return None + def _get_linear_attn_registry_result(self): + if not hasattr(self, "_linear_attn_registry_cache"): + self._linear_attn_registry_cache = get_linear_attn_config( + self.model_config.hf_config + ) + return self._linear_attn_registry_cache + + @property + def linear_attn_model_spec(self): + result = self._get_linear_attn_registry_result() + return result[0] if result else None + @property def mambaish_config(self): - return ( + existing = ( self.mamba2_config or self.hybrid_gdn_config or self.kimi_linear_config or self.hybrid_lightning_config ) + if existing: + return existing + result = self._get_linear_attn_registry_result() + return result[1] if result else None def configure_kv_cache_dtype(self): if self.server_args.kv_cache_dtype == "auto": diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py index b1935d21c462..6f0ba95a44ff 100644 --- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py @@ -609,7 +609,7 @@ def replay_prepare( buffers.input_ids[num_tokens:static_num_tokens].zero_() 
buffers.positions[num_tokens:static_num_tokens].zero_() if self.is_multimodal: - buffers.input_embeds[:, num_tokens:static_num_tokens].zero_() + buffers.input_embeds[num_tokens:static_num_tokens].zero_() if forward_batch.mrope_positions is not None: buffers.mrope_positions[:, num_tokens:static_num_tokens].zero_() diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py index 0481cae0eeba..6a38e7ebad9a 100644 --- a/python/sglang/srt/models/gemma3_causal.py +++ b/python/sglang/srt/models/gemma3_causal.py @@ -95,11 +95,12 @@ def __init__( ) if hidden_activation != "gelu_pytorch_tanh": raise ValueError( - "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " + f"{self.__class__.__name__} uses `gelu_pytorch_tanh` as the hidden activation " "function. Please set `hidden_activation` to " "`gelu_pytorch_tanh`." ) self.act_fn = GeluAndMul() + self.prefix = prefix def forward(self, x: torch.Tensor) -> torch.Tensor: gate_up, _ = self.gate_up_proj(x) diff --git a/python/sglang/srt/models/gemma4_audio.py b/python/sglang/srt/models/gemma4_audio.py new file mode 100644 index 000000000000..db825165fe29 --- /dev/null +++ b/python/sglang/srt/models/gemma4_audio.py @@ -0,0 +1,873 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SGLang-native TP-sharded audio encoder for Gemma 4. 
+ +Architecture: Conformer-based USM (Universal Speech Model) with SSCP convolution +projection. Adapted from gemma3n_audio.py with Gemma 4 specific changes: + - Activation clamping (clippable linears) on all conformer linears + - per_dim_key_scale in attention + - LayerNorm (not CumulativeGroupNorm) in SSCP convolution blocks + - Semicausal SSCP padding + - Mask propagation through SSCP + - Output projection (hidden_size -> output_proj_dims) +""" + +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import Gemma4AudioConfig + +from sglang.srt.layers.clippable_linear import ( + ClippableColumnParallelLinear, + ClippableGLUParallelLinear, + ClippableQKVParallelLinear, + ClippableRowParallelLinear, +) +from sglang.srt.layers.dp_attention import ( + get_attention_tp_rank, + get_attention_tp_size, +) +from sglang.srt.layers.layernorm import Gemma4RMSNorm +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.utils import add_prefix, make_layers, set_weight_attrs + +# SSCP convolution constants (no longer in config.json, never varied across models) +_SSCP_INPUT_FEAT_SIZE = 128 +_SSCP_CONV_KERNEL_SIZES = ((3, 3), (3, 3)) +_SSCP_CONV_STRIDE_SIZES = ((2, 2), (2, 2)) + +# --------------------------------------------------------------------------- +# Relative Position Embedding +# --------------------------------------------------------------------------- + + +class Gemma4AudioRelativePositionEmbedding(nn.Module): + def __init__( + self, + config: Gemma4AudioConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + tp_size = get_attention_tp_size() + total_num_heads = config.num_attention_heads + self.channels = config.hidden_size + self.head_dim = self.channels // total_num_heads + 
self.num_heads = total_num_heads // tp_size + self.max_backward = max(0, config.attention_context_left - 1) + self.max_forward = config.attention_context_right + + self.pos_proj = ColumnParallelLinear( + self.channels, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("pos_proj", prefix), + ) + + min_timescale = 1.0 + max_timescale = 1.0e4 + num_timescales = self.channels // 2 + log_timescale_increment = math.log( + float(max_timescale) / float(min_timescale) + ) / max(num_timescales - 1, 1) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales) * -log_timescale_increment + ) + self.register_buffer( + "inv_timescales", + inv_timescales.float().unsqueeze(0).unsqueeze(0), + persistent=False, + ) + + def _get_timing_signal_1d_pos( + self, position: torch.Tensor, dtype: torch.dtype + ) -> torch.Tensor: + assert position.ndim == 2 + position = position.float().unsqueeze(-1) + scaled_time = position * self.inv_timescales.to( + device=position.device, dtype=torch.float32 + ) + timing_signal = torch.cat( + [torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1 + ) + return timing_signal.type(dtype) + + def _relative_shift( + self, + term_bd_before_shift: torch.Tensor, + batch_size: int, + num_heads: int, + num_query_blocks: int, + query_block_size: int, + key_context_size: int, + max_span_plus_1: int, + ) -> torch.Tensor: + pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1 + padding_tuple = (0, pad_amount_last_dim) + + term_bd_padded = F.pad(term_bd_before_shift, padding_tuple) + term_bd_reshaped = term_bd_padded.reshape( + ( + batch_size, + num_heads, + num_query_blocks, + query_block_size * (key_context_size + 1), + ) + ) + term_bd_sliced = term_bd_reshaped[ + :, :, :, : query_block_size * key_context_size + ] + term_bd_shifted = term_bd_sliced.reshape( + ( + batch_size, + num_heads, + num_query_blocks, + query_block_size, + key_context_size, + ) + ) + return term_bd_shifted + + def forward(self, 
queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor: + batch_size, num_query_blocks, query_block_size, num_heads, head_dim = ( + queries.shape + ) + _, _, key_context_size, _, _ = keys.shape + + pos_indices = torch.arange( + self.max_backward, -self.max_forward - 1, -1, device=queries.device + ).unsqueeze(0) + max_span_plus_1 = pos_indices.shape[1] + + sin_emb_timing_signal = self._get_timing_signal_1d_pos( + pos_indices, dtype=queries.dtype + ) + # pos_proj is a ColumnParallelLinear (no implicit dtype promotion); + # project in weight dtype, then cast back to queries' dtype for the matmuls. + projected_sin_emb, _ = self.pos_proj( + sin_emb_timing_signal.to(self.pos_proj.weight.dtype) + ) + projected_sin_emb = projected_sin_emb.to(queries.dtype) + sin_emb = projected_sin_emb.reshape( + 1, max_span_plus_1, self.num_heads, self.head_dim + ).squeeze(0) + + queries_p = queries.permute(0, 3, 1, 2, 4) + keys_p_t = keys.permute(0, 3, 1, 4, 2) + term_ac = torch.matmul(queries_p, keys_p_t) + + q_permuted = queries.permute(0, 3, 1, 2, 4) + s_permuted = sin_emb.permute(1, 2, 0) + q_reshaped = q_permuted.reshape( + batch_size, num_heads, num_query_blocks * query_block_size, head_dim + ) + term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted) + term_bd_unshifed = term_bd_unshifed_matmul.reshape( + batch_size, + num_heads, + num_query_blocks, + query_block_size, + max_span_plus_1, + ) + + term_bd_shifted = self._relative_shift( + term_bd_unshifed, + batch_size, + num_heads, + num_query_blocks, + query_block_size, + key_context_size, + max_span_plus_1, + ) + + return term_ac + term_bd_shifted + + +# --------------------------------------------------------------------------- +# Local Dot-Product Attention (with per_dim_key_scale) +# --------------------------------------------------------------------------- + + +class Gemma4AudioAttention(nn.Module): + def __init__( + self, + config: Gemma4AudioConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: 
str = "", + ): + super().__init__() + self.config = config + + tp_size = get_attention_tp_size() + total_num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_dim = self.hidden_size // total_num_heads + self.num_heads = total_num_heads // tp_size + + self.chunk_size = config.attention_chunk_size + self.max_future_horizon = config.attention_context_right + self.max_past_horizon = max(0, config.attention_context_left - 1) + self.attention_logits_soft_cap = config.attention_logit_cap + self.context_size = ( + self.chunk_size + self.max_past_horizon + self.max_future_horizon + ) + + self.relative_position_embedding = Gemma4AudioRelativePositionEmbedding( + config, + quant_config, + prefix=add_prefix("relative_position_embedding", prefix), + ) + self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,))) + + self.qkv = ClippableQKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=total_num_heads, + total_num_kv_heads=total_num_heads, + bias=False, + quant_config=quant_config, + prefix=prefix, + ) + + self.q_scale = (self.head_dim**-0.5) / math.log(2) + self.k_scale = math.log(1 + math.e) / math.log(2) + + self.register_buffer( + "softcap", + torch.tensor(self.attention_logits_soft_cap).float(), + persistent=False, + ) + + # ------ block / context helpers (identical to Gemma3n) ------------------ + + def _pad_dim1( + self, x: torch.Tensor, dim10_val: int, dim11_val: int + ) -> torch.Tensor: + padding_tuple = [0] * x.ndim * 2 + dim_idx_from_end = x.ndim - 2 + start_idx_for_dim = 2 * dim_idx_from_end + padding_tuple[start_idx_for_dim] = dim10_val + padding_tuple[start_idx_for_dim + 1] = dim11_val + return F.pad(x, tuple(padding_tuple)) + + def _convert_to_block(self, x: torch.Tensor) -> torch.Tensor: + shape = x.shape + b, t = shape[:2] + num_blocks = (t + self.chunk_size - 1) // self.chunk_size + if (padding_len := num_blocks * self.chunk_size - t) > 0: + x = self._pad_dim1(x, 0, 
padding_len) + permute_dims = (b, num_blocks, self.chunk_size) + shape[2:] + return x.reshape(permute_dims).contiguous() + + def _extract_block_context(self, x: torch.Tensor) -> torch.Tensor: + pad_left = self.max_past_horizon + pad_right = self.max_future_horizon + self.chunk_size - 1 + x = self._pad_dim1(x, pad_left, pad_right) + frame_len = self.context_size + frame_step = self.chunk_size + x_unfolded = x.unfold(dimension=1, size=frame_len, step=frame_step) + if x.ndim > 2 and x_unfolded.ndim > 3: + x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2) + return x_unfolded.contiguous() + + # ------ forward --------------------------------------------------------- + + def forward( + self, + x: torch.Tensor, + mask: torch.BoolTensor, + causal_valid_mask: torch.BoolTensor, + ) -> torch.Tensor: + q, k, v = self.qkv(x) + qkv_shape = (*x.shape[:-1], self.num_heads, self.head_dim) + query_states = q.float().reshape(qkv_shape).contiguous() + key_states = k.float().reshape(qkv_shape).contiguous() + value_states = v.float().reshape(qkv_shape).contiguous() + + per_dim_scale_sp = F.softplus(self.per_dim_scale) + broadcast_shape = (1, 1, 1, self.head_dim) + query_states = ( + query_states * self.q_scale * per_dim_scale_sp.view(broadcast_shape) + ) + + key_states = key_states * self.k_scale + + batch_size, q_time = query_states.shape[:2] + + query_blocks = self._convert_to_block(query_states) + key_blocks = self._extract_block_context(key_states) + value_blocks = self._extract_block_context(value_states) + num_query_blocks = query_blocks.shape[1] + + original_valid_mask = ~mask + extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask) + + if ( + extracted_valid_mask_blocks.ndim == 4 + and extracted_valid_mask_blocks.shape[0] == batch_size + and extracted_valid_mask_blocks.shape[1] == num_query_blocks + and extracted_valid_mask_blocks.shape[2] + * extracted_valid_mask_blocks.shape[3] + == self.context_size + ): + extracted_valid_mask_blocks = 
extracted_valid_mask_blocks.reshape( + batch_size, num_query_blocks, self.context_size + ) + + condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze( + 1 + ).unsqueeze(-2) + condition_from_causality = ( + causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0) + ) + + final_condition_for_where = torch.logical_and( + condition_from_input_validity, + condition_from_causality.to(condition_from_input_validity.device), + ) + + logits = self.relative_position_embedding(query_blocks, key_blocks) + + softcap_val = self.softcap.to(logits.device) + logits = logits / softcap_val + logits = torch.tanh(logits) + logits = logits * softcap_val + + logits = torch.where( + final_condition_for_where, + logits, + self.config.attention_invalid_logits_value, + ) + + probabilities = F.softmax(logits, dim=-1, dtype=torch.float32).to( + dtype=value_blocks.dtype + ) + + b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape + h_dim = value_blocks.shape[-1] + prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim) + v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim) + result_bmm = torch.bmm(prob_bun, v_bun) + context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute( + 0, 1, 3, 2, 4 + ) + context_vectors = context_vectors.reshape( + batch_size, + num_query_blocks * self.chunk_size, + self.num_heads, + self.head_dim, + ) + context_vectors = context_vectors[:, :q_time] + return context_vectors + + +# --------------------------------------------------------------------------- +# SSCP (Sub-Sample Convolution Projection) +# --------------------------------------------------------------------------- + + +class Gemma4AudioSSCPConvBlock(nn.Module): + """Single 2D conv block with LayerNorm and semicausal padding.""" + + def __init__( + self, + config: Gemma4AudioConfig, + idx: int, + input_freq_dim: int, + ): + super().__init__() + self.config = config + + conv_channels = config.subsampling_conv_channels + in_channels = 1 
if idx == 0 else conv_channels[idx - 1] + out_channels = conv_channels[idx] + kernel_t, kernel_f = _SSCP_CONV_KERNEL_SIZES[idx] + stride_t, stride_f = _SSCP_CONV_STRIDE_SIZES[idx] + self.time_stride = stride_t + + # Semicausal padding (hardcoded — streaming is not supported) + pad_t_top = kernel_t // 2 + pad_t_bottom = kernel_t // 2 + + pad_f_left = 1 + pad_f_right = 1 + + self.manual_padding = (pad_f_left, pad_f_right, pad_t_top, pad_t_bottom) + + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(kernel_t, kernel_f), + stride=(stride_t, stride_f), + padding=(0, 0), + bias=False, + ) + + f_in_padded = input_freq_dim + pad_f_left + pad_f_right + self.f_out_conv = (f_in_padded - kernel_f) // stride_f + 1 + + self.norm = nn.LayerNorm( + [out_channels], + eps=config.rms_norm_eps, + elementwise_affine=True, + bias=False, + ) + self.activation = nn.ReLU() + + def forward( + self, audio_encodings: torch.Tensor, audio_mel_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + mask_for_fill = audio_mel_mask.unsqueeze(1).unsqueeze(-1) + audio_encodings = audio_encodings.masked_fill(mask_for_fill, 0.0) + + audio_encodings_padded = F.pad( + audio_encodings, self.manual_padding, mode="constant", value=0.0 + ).to(self.conv.weight.dtype) + audio_encodings_conv = self.conv(audio_encodings_padded) + + output_mask = audio_mel_mask[:, :: self.time_stride][ + :, : audio_encodings_conv.shape[2] + ] + + x = audio_encodings_conv.permute(0, 2, 3, 1) + x_normed = self.norm(x) + audio_encodings_normed = x_normed.permute(0, 3, 1, 2).contiguous() + return self.activation(audio_encodings_normed), output_mask + + +class Gemma4AudioSubSampleConvProjection(nn.Module): + def __init__( + self, + config: Gemma4AudioConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + conv_channels = config.subsampling_conv_channels + + current_f = _SSCP_INPUT_FEAT_SIZE + 
calculated_f_out_dims = [] + + for i in range(2): + kernel_h, kernel_w = _SSCP_CONV_KERNEL_SIZES[i] + stride_h, stride_w = _SSCP_CONV_STRIDE_SIZES[i] + + pad_f_left = 1 + pad_f_right = 1 + f_in_padded = current_f + pad_f_left + pad_f_right + f_out = (f_in_padded - kernel_w) // stride_w + 1 + calculated_f_out_dims.append(f_out) + current_f = f_out + + self.conv_0 = Gemma4AudioSSCPConvBlock( + idx=0, + input_freq_dim=_SSCP_INPUT_FEAT_SIZE, + config=config, + ) + self.conv_1 = Gemma4AudioSSCPConvBlock( + idx=1, + input_freq_dim=calculated_f_out_dims[0], + config=config, + ) + + final_c_out = conv_channels[-1] + final_f_out = calculated_f_out_dims[-1] + self.input_proj_in_features = final_c_out * final_f_out + + self.input_proj_linear = RowParallelLinear( + self.input_proj_in_features, + config.hidden_size, + bias=False, + input_is_parallel=False, + quant_config=quant_config, + prefix=add_prefix("input_proj_linear", prefix), + ) + + def forward( + self, audio_encodings: torch.Tensor, audio_mel_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + audio_encodings_reshaped = audio_encodings.unsqueeze(1) + x, mask = self.conv_0(audio_encodings_reshaped, audio_mel_mask) + x, mask = self.conv_1(x, mask) + b, c_out, t_out, f_out = x.shape + x_permuted = x.permute(0, 2, 3, 1).contiguous() + output_flattened = x_permuted.reshape(b, t_out, f_out * c_out) + output, _ = self.input_proj_linear(output_flattened) + return output, mask + + +# --------------------------------------------------------------------------- +# Conformer Blocks +# --------------------------------------------------------------------------- + + +class Gemma4AudioConformerAttention(nn.Module): + def __init__( + self, + config: Gemma4AudioConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.post_in_features = config.hidden_size + + self.register_buffer( + "gradient_clipping", + torch.tensor(config.gradient_clipping), + 
persistent=False, + ) + + self.pre_attn_norm = Gemma4RMSNorm(config.hidden_size, scale_shift=0.0) + self.attn = Gemma4AudioAttention( + config, quant_config, prefix=add_prefix("attn", prefix) + ) + self.post = ClippableRowParallelLinear( + self.post_in_features, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("post", prefix), + ) + self.post_norm = Gemma4RMSNorm(config.hidden_size, scale_shift=0.0) + + def forward( + self, + audio_encodings: torch.Tensor, + audio_mel_mask: torch.BoolTensor, + causal_valid_mask: torch.BoolTensor, + ) -> torch.Tensor: + audio_encodings_input_to_attn = audio_encodings + audio_encodings = torch.clamp( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + audio_encodings_norm = self.pre_attn_norm(audio_encodings) + audio_encodings_attn_out = self.attn( + audio_encodings_norm, audio_mel_mask, causal_valid_mask + ) + + b, t, num_heads, head_dim = audio_encodings_attn_out.shape + audio_encodings_reshaped = audio_encodings_attn_out.reshape( + b, t, num_heads * head_dim + ).to(dtype=audio_encodings_input_to_attn.dtype) + + audio_encodings = self.post(audio_encodings_reshaped) + audio_encodings = torch.clamp( + audio_encodings, -self.gradient_clipping, self.gradient_clipping + ) + return audio_encodings_input_to_attn + self.post_norm(audio_encodings) + + +class Gemma4AudioConformerFeedForward(nn.Module): + def __init__( + self, + config: Gemma4AudioConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + self.register_buffer( + "gradient_clipping", + torch.tensor(config.gradient_clipping), + persistent=False, + ) + + self.pre_layer_norm = Gemma4RMSNorm(config.hidden_size, scale_shift=0.0) + self.ffw_layer_1 = ClippableColumnParallelLinear( + config.hidden_size, + config.hidden_size * 4, + bias=False, + quant_config=quant_config, + prefix=add_prefix("ffw_layer_1", prefix), + ) + self.ffw_layer_2 = 
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
    """Clipped feed-forward sub-layer with a scaled residual connection.

    Pipeline: clamp -> ``pre_layer_norm`` -> ``ffw_layer_1`` -> SiLU ->
    ``ffw_layer_2`` -> clamp -> ``post_layer_norm``. The result is
    multiplied by ``post_layer_scale`` and added to the original,
    unclamped input.
    """
    clip = self.gradient_clipping

    hidden = torch.clamp(audio_encodings, -clip, clip)
    hidden = self.ffw_layer_1(self.pre_layer_norm(hidden))
    hidden = self.ffw_layer_2(F.silu(hidden))
    hidden = self.post_layer_norm(torch.clamp(hidden, -clip, clip))

    return audio_encodings + (hidden * self.post_layer_scale)
def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
    """Light-convolution sub-layer: norm, GLU projection, causal depthwise conv.

    After ``pre_layer_norm`` and ``linear_start`` the tensor's axes 1 and 2
    are swapped, the (new) last axis is left-padded by ``causal_padding``
    and ``depthwise_conv1d`` is applied. The axes are swapped back, the
    values are clamped, normalised by ``conv_norm``, passed through SiLU
    and ``linear_end``, and finally added to the original input.
    """
    clip = self.gradient_clipping
    hidden = self.linear_start(self.pre_layer_norm(audio_encodings))

    # Pad only on the left of the convolved axis so that, with padding=0 in
    # the conv itself, an output position never sees later positions.
    channels_first = F.pad(hidden.permute(0, 2, 1), (self.causal_padding, 0))
    hidden = self.depthwise_conv1d(channels_first).permute(0, 2, 1)

    hidden = self.conv_norm(torch.clamp(hidden, -clip, clip))
    hidden = self.linear_end(F.silu(hidden))
    return hidden + audio_encodings
def forward(
    self,
    audio_encodings: torch.Tensor,
    audio_mel_mask: torch.BoolTensor,
    causal_valid_mask: torch.BoolTensor,
) -> torch.Tensor:
    """Apply one conformer block.

    Order of sub-layers: ``ffw_layer_start`` -> ``attention`` ->
    ``lconv1d`` -> ``ffw_layer_end`` -> clamp -> ``norm``. Positions where
    ``audio_mel_mask`` is True are multiplied by zero before ``lconv1d``.
    """
    hidden = self.ffw_layer_start(audio_encodings)
    hidden = self.attention(hidden, audio_mel_mask, causal_valid_mask)

    # Invert the mask (True marks positions to drop) into a 0/1 multiplier
    # broadcast over the feature axis.
    keep = (~audio_mel_mask).unsqueeze(-1).to(hidden.dtype)
    hidden = self.lconv1d(hidden * keep)

    hidden = self.ffw_layer_end(hidden)
    clip = self.gradient_clipping
    return self.norm(torch.clamp(hidden, -clip, clip))
def forward(
    self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
) -> Tuple[torch.Tensor, torch.BoolTensor]:
    """Encode a batch of mel spectrograms.

    Args:
        audio_mel: [batch, num_frames, mel_bins]
        audio_mel_mask: [batch, num_frames], True = padding

    Returns:
        audio_encodings: [batch, reduced_frames, hidden_size/output_proj_dims]
        audio_mel_mask: [batch, reduced_frames], True = padding
    """
    hidden, mask = self.subsample_conv_projection(audio_mel, audio_mel_mask)

    for conformer_block in self.conformer:
        hidden = conformer_block(hidden, mask, self.causal_valid_mask)

    if self.output_proj is not None:
        hidden, _ = self.output_proj(hidden)

    # Make the mask's time axis agree with the features: extend it with
    # True (padding) when it is too short, cut it when it is too long.
    missing = hidden.shape[1] - mask.shape[1]
    if missing > 0:
        mask = F.pad(mask, (0, missing), value=True)
    elif missing < 0:
        mask = mask[:, : hidden.shape[1]]

    # Zero every position flagged as padding.
    hidden = hidden.masked_fill(mask.unsqueeze(-1), 0.0)
    return hidden, mask
# Aligned with HF's implementation, using sliding window inclusive with the last token
# SGLang assumes exclusive
def get_attention_sliding_window_size(config):
    """Convert the configured (inclusive) sliding window to an exclusive size.

    Args:
        config: Object exposing an optional ``sliding_window`` attribute.

    Returns:
        ``config.sliding_window - 1``, or ``None`` when the attribute is
        missing or ``None`` (previously this raised ``TypeError`` /
        ``AttributeError``).
    """
    sliding_window = getattr(config, "sliding_window", None)
    if sliding_window is None:
        # No window configured: report "no sliding window" instead of
        # failing on ``None - 1``.
        return None
    return sliding_window - 1
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Returns raw router logits [T, E].

    The input is normalised by ``self.norm``, multiplied element-wise by
    the cached fused scale (built on first use through ``fuse_scale``) and
    projected by ``self.proj``.
    """
    normed = self.norm(x)

    if self._fused_scale is None:
        # First call: build the cached ``scale * root_size`` product.
        self.fuse_scale()

    scaled = normed * self._fused_scale.to(normed.dtype)
    logits, _ = self.proj(scaled)
    return logits
+ """ + + def __init__( + self, + hidden_size: int, + layer_id: int, + config: Gemma4TextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_id = layer_id + self.hidden_size = hidden_size + self.num_experts = config.num_experts + self.tp_size = get_tensor_model_parallel_world_size() + + # Per-expert output scale folded into routing weights so that + # MoE's fused kernel computes: Σ_e (expert_e * w_e * scale_e) + self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts)) + + # Capture param directly to avoid closing over self in the routing closure. + per_expert_scale = self.per_expert_scale + + def routing_function( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, # always True for Gemma4; softmax identity only holds when renormalizing + ) -> tuple[torch.Tensor, torch.Tensor]: + # softmax(all)[topk] / sum(softmax(all)[topk]) = softmax(topk_logits), + # so we softmax only the top-k logits (fewer kernel launches). 
def forward(
    self, hidden_states: torch.Tensor, router_logits: torch.Tensor
) -> torch.Tensor:
    """Dispatch tokens to the experts using externally computed router logits.

    ``hidden_states`` must be 2-D; its two sizes are recorded up front and
    the expert output is viewed back to that same shape.
    """
    num_tokens, hidden_dim = hidden_states.shape
    routed = self.experts(hidden_states, self.topk(hidden_states, router_logits))
    return routed.view(num_tokens, hidden_dim)
config.num_key_value_heads + ) + else: + self.total_num_kv_heads = config.num_key_value_heads + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + + hidden_size = config.hidden_size + self.head_dim = head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), + ) + + self.q_norm = Gemma4RMSNorm( + self.head_dim, + eps=config.rms_norm_eps, + ) + self.k_norm = Gemma4RMSNorm( + self.head_dim, + eps=config.rms_norm_eps, + ) + self.v_norm = Gemma4RMSNorm( + self.head_dim, eps=config.rms_norm_eps, scale_shift=0.0, with_scale=False + ) + + if layer_type in config.rope_parameters: + rope_parameters = dict(config.rope_parameters[layer_type]) + else: + rope_parameters = dict( + rope_type="default", + rope_theta=10000.0, + ) + + # KV sharing logic + num_kv_shared_layers = getattr(config, "num_kv_shared_layers", 0) + first_kv_shared_layer_idx = config.num_hidden_layers - num_kv_shared_layers + self.is_kv_shared_layer = ( + layer_id >= first_kv_shared_layer_idx and num_kv_shared_layers > 0 + ) + + self.kv_shared_layer_index = None + if num_kv_shared_layers > 0 and self.layer_id >= first_kv_shared_layer_idx: + prev_layers = config.layer_types[:first_kv_shared_layer_idx] + current_layer_type = config.layer_types[self.layer_id] + if current_layer_type not in prev_layers: + raise ValueError( + f"KV sharing layer {self.layer_id} has type '{current_layer_type}' " + 
f"but no matching type found in layers 0..{first_kv_shared_layer_idx - 1}. " + f"Available types: {set(prev_layers)}" + ) + self.kv_shared_layer_index = ( + len(prev_layers) - 1 - prev_layers[::-1].index(current_layer_type) + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_parameters.get("rope_theta", 10000.0), + rope_scaling={"rope_type": rope_parameters.get("rope_type", "default")}, + partial_rotary_factor=rope_parameters.get("partial_rotary_factor", 1.0), + is_neox_style=True, + ) + + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + 1, # scaling factor + num_kv_heads=self.num_kv_heads, + layer_id=( + self.kv_shared_layer_index if self.is_kv_shared_layer else self.layer_id + ), + logit_cap=0.0, + sliding_window_size=self.sliding_window, + quant_config=quant_config, + prefix=add_prefix("attn", prefix), + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + **kwargs, + ): + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + q = q.flatten(-2, -1) + + # Check if we should use shared KV cache + if self.is_kv_shared_layer and self.kv_shared_layer_index is not None: + # For KV shared layers, we skip K/V computation and normalization + # The RadixAttention will handle retrieving shared KV from cache + k = None + v = None + else: + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + k = self.k_norm(k) + + v = v.unflatten(-1, (self.num_kv_heads, self.head_dim)) + v = self.v_norm(v) + + # Apply rotary embedding + if k is not None: + k = k.flatten(-2, -1) + q, k = self.rotary_emb(positions, q, k) + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + else: + # Rotary embedding requires a key input; use zeros since KV is shared from another layer + dummy_k = 
torch.zeros_like(q[:, : self.kv_size]) + q, _ = self.rotary_emb(positions, q, dummy_k) + + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + attn_output = self.attn( + q, + k, + v, + forward_batch=forward_batch, + save_kv_cache=not self.is_kv_shared_layer, + ) + if attn_output.dim() == 3: + attn_output = attn_output.flatten(-2, -1) + output, _ = self.o_proj(attn_output) + + return output + + +class Gemma4DecoderLayer(nn.Module): + def __init__( + self, + layer_id: int, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.hidden_size_per_layer_input = ( + getattr(config, "hidden_size_per_layer_input", None) or 0 + ) + + self.layer_id = layer_id + + # Gemma 4 uses different head dimensions for sliding vs full attention + layer_type = config.layer_types[layer_id] + self.is_full_attention = layer_type == "full_attention" + if self.is_full_attention: + head_dim = config.head_dim # following sglang naming + else: + head_dim = getattr(config, "swa_head_dim", config.head_dim) + + self.self_attn = Gemma4Attention( + layer_id=layer_id, + config=config, + max_position_embeddings=config.max_position_embeddings, + head_dim=head_dim, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + ) + + first_kv_shared_layer_idx = config.num_hidden_layers - getattr( + config, "num_kv_shared_layers", 0 + ) + is_kv_shared_layer = self.layer_id >= first_kv_shared_layer_idx > 0 + use_double_wide_mlp = ( + getattr(config, "use_double_wide_mlp", False) and is_kv_shared_layer + ) + layer_intermediate_size = config.intermediate_size * ( + 2 if use_double_wide_mlp else 1 + ) + + self.mlp = Gemma4MLP( + hidden_size=self.hidden_size, + intermediate_size=layer_intermediate_size, + hidden_activation=config.hidden_activation, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + + self.input_layernorm = RMSNorm(config.hidden_size, 
eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + # Per-Layer Embedding (PLE) components — present in each decoder layer + if self.hidden_size_per_layer_input > 0: + # Gate: projects hidden_states → per-layer dim for gating + self.per_layer_input_gate = ReplicatedLinear( + self.hidden_size, + self.hidden_size_per_layer_input, + bias=False, + quant_config=quant_config, + prefix=add_prefix("per_layer_input_gate", prefix), + ) + # Projection: projects gated per-layer input back → hidden size + self.per_layer_projection = ReplicatedLinear( + self.hidden_size_per_layer_input, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("per_layer_projection", prefix), + ) + self.post_per_layer_input_norm = Gemma4RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + else: + self.per_layer_input_gate = None + self.per_layer_projection = None + self.post_per_layer_input_norm = None + + # Parallel MoE + self.enable_moe_block = getattr(config, "enable_moe_block", False) + if self.enable_moe_block: + self.router = Gemma4Router( + config, + quant_config=quant_config, + prefix=add_prefix("router", prefix), + ) + self.moe = Gemma4MoE( + hidden_size=self.hidden_size, + layer_id=layer_id, + config=config, + quant_config=quant_config, + prefix=add_prefix("moe", prefix), + ) + + self.post_feedforward_layernorm_1 = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm_2 = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm_2 = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + else: + self.router = None + self.moe = None + self.post_feedforward_layernorm_1 = None + self.post_feedforward_layernorm_2 = None + 
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    per_layer_input: Optional[torch.Tensor],
    forward_batch: ForwardBatch,
    **kwargs,
) -> tuple[
    torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
]:
    """Run one decoder layer: attention, feed-forward (dense or dense + MoE), PLE.

    Args:
        positions: Forwarded unchanged to ``self.self_attn``.
        hidden_states: Layer input; also used as the first residual.
        per_layer_input: Per-layer embedding input for this layer, or
            ``None`` to skip the per-layer-embedding step.
        forward_batch: Forwarded unchanged to ``self.self_attn``.
        **kwargs: Accepted but not used inside this method.

    Returns:
        ``(hidden_states, None)`` — the second element is always ``None``.
    """
    # Gemma4 residual pattern following JAX implementation:
    # 1. input_norm(x) -> attn -> post_attn_norm -> ADD residual
    # 2. pre_ff_norm -> mlp -> post_ff_norm -> ADD residual
    #
    # Optimization: instead of adding the residual and then normalising in
    # two steps, pre_feedforward_layernorm is called with
    # (hidden_states, residual) and returns both the normalised tensor and
    # the updated residual.
    residual = hidden_states

    # Apply input layernorm
    hidden_states = self.input_layernorm(hidden_states)
    hidden_states = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
        forward_batch=forward_batch,
    )
    hidden_states = self.post_attention_layernorm(hidden_states)

    if self.enable_moe_block:
        # Fuse: hidden_states + residual -> residual; pre_ff_norm(residual) -> hidden_states
        # Also need raw (unfused) residual for router and pre_ff_norm_2
        hidden_states, residual = self.pre_feedforward_layernorm(
            hidden_states, residual
        )
        # For MoE: router and pre_ff_norm_2 need the unfused residual
        # (which is now updated to post_attn_out + old_residual)
        moe_input = residual

        # Dense MLP branch
        hidden_states_1 = self.mlp(hidden_states)
        hidden_states_1 = self.post_feedforward_layernorm_1(hidden_states_1)

        # MoE branch: router sees residual (= post_attn_out + old_residual)
        router_logits = self.router(moe_input)
        hidden_states_2 = self.pre_feedforward_layernorm_2(moe_input)
        hidden_states_2 = self.moe(hidden_states_2, router_logits)
        hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2)

        # Combine branches
        hidden_states = hidden_states_1 + hidden_states_2
    else:
        # Fuse: hidden_states + residual -> residual; pre_ff_norm(residual) -> hidden_states
        hidden_states, residual = self.pre_feedforward_layernorm(
            hidden_states, residual
        )
        hidden_states = self.mlp(hidden_states)

    if not self.has_ple and hidden_states.is_cuda and hidden_states.dim() == 2:
        # Fused: (post_ff_norm(h) + residual) * layer_scalar in one kernel.
        # Only taken for 2-D CUDA tensors on layers without PLE; everything
        # else goes through the step-by-step path below.
        norm = self.post_feedforward_layernorm
        hidden_states = gemma_rmsnorm_residual_scalar(
            hidden_states,
            norm.weight.data,
            residual,
            self.layer_scalar,
            norm.variance_epsilon,
        )
    else:
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = hidden_states + residual

        if self.has_ple and per_layer_input is not None:
            # Per-layer embedding: gate the per-layer input with a
            # tanh-GELU of a projection of the hidden state, project back
            # to hidden size, normalise, and add to the hidden state.
            gate, _ = self.per_layer_input_gate(hidden_states)
            gate = torch.nn.functional.gelu(gate, approximate="tanh")
            gated_per_layer = gate * per_layer_input
            per_layer_contribution, _ = self.per_layer_projection(gated_per_layer)
            per_layer_contribution = self.post_per_layer_input_norm(
                per_layer_contribution
            )
            hidden_states = hidden_states + per_layer_contribution

        # The fused kernel above already applies layer_scalar, so the
        # multiplication happens only on this path.
        hidden_states = hidden_states * self.layer_scalar
    return hidden_states, None
def get_per_layer_inputs(
    self, input_ids: torch.LongTensor
) -> Optional[torch.Tensor]:
    """Look up the packed per-layer embeddings for ``input_ids``.

    Args:
        input_ids: Token ids. Ids outside
            ``[0, vocab_size_per_layer_input)`` are replaced by id 0 before
            the lookup.

    Returns:
        ``None`` when the model has no per-layer embedding table; otherwise
        the embeddings reshaped to
        ``(*input_ids.shape, num_hidden_layers, hidden_size_per_layer_input)``.
    """
    if self.embed_tokens_per_layer is None:
        return None

    # Handle out-of-vocab tokens for PLE (vocab_size_per_layer_input may
    # be smaller than the main vocab_size). Following Gemma3n pattern.
    in_vocab = torch.logical_and(
        input_ids >= 0,
        input_ids < self.vocab_size_per_layer_input,
    )
    safe_ids = torch.where(in_vocab, input_ids, torch.zeros_like(input_ids))

    # Packed per-layer embeddings; the last axis holds all layers back to
    # back. No extra scaling here: the embedding module is constructed with
    # its own embed_scale.
    packed = self.embed_tokens_per_layer(safe_ids)

    return packed.reshape(
        *input_ids.shape,
        self.config.num_hidden_layers,
        self.hidden_size_per_layer_input,
    )
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    input_embeds: Optional[torch.Tensor] = None,
    per_layer_inputs: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Run the text decoder stack and the final norm.

    Exactly one of ``input_ids`` / ``input_embeds`` must be given. When
    ``input_ids`` is given, both the token embeddings and the per-layer
    inputs are derived from it (any ``per_layer_inputs`` argument is
    overwritten). The per-layer inputs are then passed through
    ``project_per_layer_inputs`` and sliced per layer along axis 1.

    Raises:
        ValueError: If both or neither of ``input_ids`` and
            ``input_embeds`` are provided.
    """
    if (input_ids is None) ^ (input_embeds is not None):
        raise ValueError(
            "You must specify exactly one of input_ids or input_embeds"
        )

    if input_ids is not None:
        input_embeds = self.embed_tokens(input_ids)
        per_layer_inputs = self.get_per_layer_inputs(input_ids)
    per_layer_inputs = self.project_per_layer_inputs(input_embeds, per_layer_inputs)

    hidden_states = input_embeds
    # Bound up front so the final norm below also works when ``self.layers``
    # is empty (previously an UnboundLocalError).
    residual = None

    for layer_idx, layer in enumerate(self.layers):
        per_layer_input = (
            per_layer_inputs[:, layer_idx, :]
            if per_layer_inputs is not None
            else None
        )
        layer_outputs = layer(
            positions=positions,
            hidden_states=hidden_states,
            per_layer_input=per_layer_input,
            forward_batch=forward_batch,
            **kwargs,
        )
        hidden_states = layer_outputs[0]
        residual = layer_outputs[1] if len(layer_outputs) > 1 else None

    if residual is None:
        hidden_states = self.norm(hidden_states)
    else:
        hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states
@torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + per_layer_inputs: Optional[torch.Tensor] = None, + **kwargs, + ) -> LogitsProcessor: + hidden_states = self.model( + input_ids, + positions, + forward_batch, + input_embeds, + per_layer_inputs, + **kwargs, + ) + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + def _get_k_eq_v_layers(self) -> set: + """Return set of layer indices where attention_k_eq_v applies (full-attention layers).""" + if not getattr(self.config, "attention_k_eq_v", False): + return set() + return { + i for i, lt in enumerate(self.config.layer_types) if lt == "full_attention" + } + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = [ + # (param_name, ckpt_weight_name, shard_ids) + # gate_up_proj is fused [E, 2*I, H] — chunk into w1 (gate) + w3 (up) + ("experts.w13_weight", "experts.gate_up_proj", ("w1", "w3")), + ("experts.w2_weight", "experts.down_proj", ("w2",)), + ] + num_experts = self.config.num_experts + + k_eq_v_layers = self._get_k_eq_v_layers() + + params_dict = dict(self.named_parameters()) + params_dict.update(dict(self.named_buffers())) + non_persistent_buffers: Set[str] = set() + for mod_name, mod in self.named_modules(): + for buf_name in getattr(mod, "_non_persistent_buffers_set", set()): + full = f"{mod_name}.{buf_name}" if mod_name else buf_name + non_persistent_buffers.add(full) + + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + name = name.replace("model.language_model.", "model.") + + # HF has router.per_expert_scale and experts.* on the decoder layer; + # remap into 
our moe.* subtree since Gemma4MoE owns both. + name = name.replace(".router.per_expert_scale", ".moe.per_expert_scale") + if ".experts." in name and ".moe.experts." not in name: + name = name.replace(".experts.", ".moe.experts.") + + # attention_k_eq_v: full-attention layers have no v_proj in the + # checkpoint (K and V share weights). When we see a k_proj weight + # for one of these layers, load it into both the "k" and "v" shards + # of the fused QKV so the forward produces v_raw == k_raw. + should_dup_k_to_v = ( + ".k_proj." in name + and k_eq_v_layers + and (m := re.search(r"layers\.(\d+)\.", name)) is not None + and int(m.group(1)) in k_eq_v_layers + ) + + # MoE expert weights checked first (gate_up_proj contains "up_proj" + # which would false-match the stacked dense MLP mapping). + orig_name = name + for param_name, weight_name, shard_ids in expert_params_mapping: + name = orig_name + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + for i in range(num_experts): + chunks = loaded_weight[i].chunk(len(shard_ids), dim=0) + for chunk, sid in zip(chunks, shard_ids): + weight_loader(param, chunk, name, sid, i) + break + else: + for param_name, weight_name, shard_id in stacked_params_mapping: + name = orig_name + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + if should_dup_k_to_v: + weight_loader(param, loaded_weight, "v") + break + else: + name = orig_name + if name.endswith(".bias") and name not in params_dict: + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", 
default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + param_names = set(dict(self.named_parameters()).keys()) + buckets = { + logging.WARNING: ( + "Some weights are not initialized from checkpoints", + lambda p: p in param_names, + ), + logging.INFO: ( + "Persistent buffers not in checkpoint (using default init)", + lambda p: p not in param_names and p not in non_persistent_buffers, + ), + logging.DEBUG: ( + "Non-persistent buffers not in checkpoint (expected)", + lambda p: p in non_persistent_buffers, + ), + } + for level, (msg, pred) in buckets.items(): + names = sorted(p for p in unloaded_params if pred(p)) + if names: + logger.log(level, "%s: %s", msg, names) + return loaded_params + + +EntryClass = Gemma4ForCausalLM diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py new file mode 100644 index 000000000000..4618129fab7a --- /dev/null +++ b/python/sglang/srt/models/gemma4_mm.py @@ -0,0 +1,878 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class Gemma4MultimodalEmbedder(nn.Module):
    """Maps soft tokens produced by a vision or audio tower into the text model's embedding space.

    The input is normalised with a scale-free RMS norm and then passed
    through a bias-free replicated linear projection.
    """

    def __init__(
        self,
        multimodal_config: Union[Gemma4AudioConfig, Gemma4VisionConfig],
        text_config: Gemma4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.eps = multimodal_config.rms_norm_eps
        self.text_hidden_size = text_config.hidden_size

        # Width of the incoming soft tokens: prefer ``output_proj_dims`` when
        # the tower config provides a truthy value (audio), otherwise fall
        # back to the tower's ``hidden_size`` (vision).
        tower_out_dims = getattr(multimodal_config, "output_proj_dims", None)
        if tower_out_dims:
            source_width = tower_out_dims
        else:
            source_width = multimodal_config.hidden_size

        # Registration order (projection first, then norm) is kept as-is.
        self.embedding_projection = ReplicatedLinear(
            source_width,
            self.text_hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("embedding_projection", prefix),
        )
        self.embedding_pre_projection_norm = Gemma4RMSNorm(
            source_width,
            eps=self.eps,
            with_scale=False,
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        """Normalise the tower output, then project it to the text hidden size."""
        normalised = self.embedding_pre_projection_norm(inputs_embeds)
        # The linear layer returns a pair; only its first element is used.
        projected, _unused = self.embedding_projection(normalised)
        return projected
"down_proj", + ] + # Gemma does not apply LoRA to the embedding layer + embedding_modules = {} + embedding_padding_modules = [] + supports_lora = True + + def __init__( + self, + config: Gemma4Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config=config) + self.config = config + self.quant_config = quant_config + + prefix = add_prefix("model", prefix) + + self.vision_tower = Gemma4VisionEncoder( + config=config.vision_config, + quant_config=quant_config, + prefix=add_prefix("vision_tower", prefix), + ) + + self.embed_vision = Gemma4MultimodalEmbedder( + config.vision_config, + config.text_config, + quant_config=quant_config, + prefix=add_prefix("embed_vision", prefix), + ) + + # Audio components + if getattr(config, "audio_config", None) is not None: + self.audio_tower = Gemma4AudioEncoder( + config=config.audio_config, + quant_config=quant_config, + prefix=add_prefix("audio_tower", prefix), + ) + self.embed_audio = Gemma4MultimodalEmbedder( + config.audio_config, + config.text_config, + quant_config=quant_config, + prefix=add_prefix("embed_audio", prefix), + ) + else: + self.audio_tower = None + self.embed_audio = None + + self.vocab_size = config.text_config.vocab_size + self.vocab_size_per_layer_input = getattr( + config.text_config, + "vocab_size_per_layer_input", + config.text_config.vocab_size, + ) + + # Text model + self.language_model = Gemma4TextModel( + config.text_config, + quant_config, + prefix=add_prefix("language_model", prefix), + ) + + # Create logits processor for the multimodal model + self.logits_processor = LogitsProcessor(config.text_config) + + self.post_init() + + def pad_input_ids( + self, + input_ids: List[int], + mm_inputs: MultimodalInputs, + ) -> List[int]: + """Pad input IDs with image and audio tokens.""" + pattern = MultiModalityDataPaddingPatternMultimodalTokens() + return pattern.pad_input_tokens(input_ids, mm_inputs) + + def get_input_embeddings(self) -> 
def prepare_attn_masks(
    self,
    forward_batch: ForwardBatch,
    input_ids: torch.Tensor,
    mask_dtype: torch.dtype,
):
    """Install per-request attention masks that are bidirectional inside image spans.

    For every request in the extend batch a causal mask of shape
    ``(extend_len, extend_len + prefix_len)`` is built; each image span that
    lies completely inside the extend window is then opened up so its tokens
    attend to one another in both directions. Spans that only partially
    overlap the window keep causal attention and are reported via warnings.
    The flattened masks and their offsets are stored on the attention
    backend's ``forward_metadata``.

    Only ``TritonAttnBackend`` is handled; for any other backend a warning is
    emitted and nothing is installed.

    TODO(kpham-sgl): Guard appropriately for gemma3_mm.py:prepare_attn_masks()
    """
    backend = forward_batch.attn_backend
    if not isinstance(backend, TritonAttnBackend):
        logger.warning_once(
            "Bidirectional attention for image tokens requires TritonAttnBackend. "
            "Falling back to causal attention, which may degrade image quality."
        )
        return
    assert forward_batch.forward_mode == ForwardMode.EXTEND

    device = input_ids.device
    batch_size = forward_batch.batch_size

    flat_masks = []
    mask_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    split_images = []

    for req_idx in range(batch_size):
        extend_len = forward_batch.extend_seq_lens[req_idx]
        prefix_len = forward_batch.extend_prefix_lens[req_idx]
        window_end = prefix_len + extend_len

        # Causal baseline: query row r may see keys up to column prefix_len + r.
        mask = torch.ones(
            extend_len,
            extend_len + prefix_len,
            dtype=mask_dtype,
            device=device,
        ).tril(diagonal=prefix_len)

        # HF only enables bidirectional attention for image tokens,
        # not video or audio (see create_causal_mask_mapping).
        mm_inputs = forward_batch.mm_inputs[req_idx]
        mm_items = mm_inputs.mm_items if mm_inputs is not None else ()
        for mm_item in mm_items:
            if not mm_item.is_image():
                continue
            for span_begin, span_end in mm_item.offsets:
                # FIXME(kpham-sgl): a span cut by a chunked-prefill boundary
                # cannot be made bidirectional here; it silently stays causal
                # and is only recorded for the warning below.
                if span_begin >= prefix_len and span_end < window_end:
                    mask[
                        span_begin - prefix_len : span_end + 1 - prefix_len,
                        span_begin : span_end + 1,
                    ] = 1
                elif span_end >= prefix_len and span_begin < window_end:
                    split_images.append((req_idx, span_begin, span_end))

        flat_masks.append(mask.flatten())
        mask_indptr[req_idx + 1] = mask_indptr[req_idx] + mask.nelement()

    if split_images:
        num_split_images = len(split_images)
        logger.warning_once(
            f"{num_split_images} images are split across chunk boundaries. "
            "Below are the first 5 images that are split across chunk boundaries: "
        )
        for i, im_begin, im_end in split_images[:5]:
            logger.warning_once(
                f"Image {i}:{im_begin}-{im_end} is split across chunk boundaries.\n",
            )
        logger.warning_once(
            "Those images will receive causal attention. Disable chunked prefill (--chunked-prefill-size=-1) for full bidirectional attention.",
        )

    if flat_masks:
        metadata = backend.forward_metadata
        packed_masks = torch.cat(flat_masks, dim=0)
        metadata.mask_indptr = mask_indptr
        metadata.custom_mask = packed_masks
def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
    """Turn video items into language-model-space embeddings.

    Every pixel tensor is paired with the ``video_position_ids`` entry at the
    same index. 4-D inputs are collapsed so that each frame becomes one batch
    element for the vision tower; the tower's pooled outputs are filtered by
    the pooler mask and projected with ``embed_vision``. Tensors that already
    look like embeddings (2-D/3-D with the text hidden size as last dim) are
    passed through untouched apart from a device move.

    Returns:
        All embeddings concatenated along dim 0, or an empty
        ``(0, hidden_size)`` tensor when there is nothing to encode.

    Raises:
        ValueError: if a pixel tensor has no matching position ids.
    """
    tower = self.vision_tower
    text_hidden = self.config.text_config.hidden_size

    collected = []
    for item in items:
        pixel_tensors = flatten_nested_list([item.feature])
        position_tensors = flatten_nested_list(
            [getattr(item, "video_position_ids", None)]
        )

        for pv_idx, pixels in enumerate(pixel_tensors):
            looks_precomputed = (
                pixels.dim() in (2, 3) and pixels.shape[-1] == text_hidden
            )
            if looks_precomputed:
                collected.append(pixels.to(self.language_model.device))
                continue

            missing_positions = (
                pv_idx >= len(position_tensors) or position_tensors[pv_idx] is None
            )
            if missing_positions:
                raise ValueError(
                    f"pixel_values_videos[{pv_idx}] has no matching video_position_ids."
                )
            positions = position_tensors[pv_idx]

            # HF processor returns 4-D tensors
            # (num_videos, num_frames, num_patches, ...); fold the two leading
            # dims together so every frame is its own batch element.
            if pixels.dim() == 4:
                pixels = pixels.reshape(-1, pixels.shape[-2], pixels.shape[-1])
            if positions.dim() == 4:
                positions = positions.reshape(
                    -1, positions.shape[-2], positions.shape[-1]
                )

            pixels = pixels.to(device=tower.device, dtype=self.language_model.dtype())
            positions = positions.to(device=tower.device)

            pooled, pooler_mask = tower(pixels, positions)

            # Keep only the positions selected by the pooler mask, per frame.
            collected.extend(
                self.embed_vision(
                    inputs_embeds=frame_states[frame_mask].unsqueeze(0)
                ).squeeze(0)
                for frame_states, frame_mask in zip(pooled, pooler_mask)
            )

    if not collected:
        return torch.empty(
            0,
            self.language_model.config.hidden_size,
            device=next(self.parameters()).device,
            dtype=self.language_model.dtype(),
        )
    return torch.cat(collected, dim=0)
@torch.no_grad()
def forward(
    self,
    input_ids: torch.LongTensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    input_embeds: torch.Tensor = None,
    **kwargs: object,
) -> LogitsProcessor:
    """Run the multimodal Gemma4 model and hand the hidden states to the logits processor.

    Raises:
        ValueError: unless exactly one of ``input_ids`` / ``input_embeds`` is given.
    """
    has_ids = input_ids is not None
    has_embeds = input_embeds is not None
    if has_ids == has_embeds:
        raise ValueError(
            "You must specify exactly one of input_ids or inputs_embeds"
        )

    # NOTE: shifts the caller's tensor in place (no copy is made).
    positions += 1

    per_layer_inputs = None
    if has_ids:
        # Per-layer inputs are looked up from a copy of the ids in which all
        # multimodal placeholder tokens are replaced by the pad token.
        ple_ids = input_ids.clone()
        pad_id = self.config.text_config.pad_token_id
        for placeholder_attr in ("image_token_id", "video_token_id", "audio_token_id"):
            ple_ids[input_ids == getattr(self.config, placeholder_attr)] = pad_id
        per_layer_inputs = self.get_per_layer_inputs(ple_ids)

    # Prepare bidirectional attention masks for image tokens during prefill.
    # Gemma 4 uses bidirectional attention for image soft tokens.
    # Only TritonAttnBackend supports this; incompatible with CUDA Graph and
    # chunked prefill.
    is_extend = forward_batch.forward_mode == ForwardMode.EXTEND
    if is_extend and forward_batch.contains_image_inputs():
        self.prepare_attn_masks(
            forward_batch,
            input_ids,
            mask_dtype=torch.bool,
        )

    # One embedding function per modality; the shared routine merges their
    # outputs with the text embeddings and runs the language model.
    embedders = {
        Modality.IMAGE: self.get_image_feature,
        Modality.VIDEO: self.get_video_feature,
        Modality.AUDIO: self.get_audio_feature,
    }
    hidden_states = general_mm_embed_routine(
        input_ids=input_ids,
        forward_batch=forward_batch,
        language_model=self.language_model,
        data_embedding_funcs=embedders,
        positions=positions,
        per_layer_inputs=per_layer_inputs,
        **kwargs,
    )

    # The language model's input embedding is passed as the LM head.
    return self.logits_processor(
        input_ids, hidden_states, self.language_model.embed_tokens, forward_batch
    )
@staticmethod
def _remap_tower_name(name: str, params_dict: dict) -> str:
    """Translate a vision/audio tower checkpoint name into this module tree.

    The rules are tried in this order and the first one that applies wins:

    1. **Fused QKV** (``*.self_attn|attn.{q,k,v}_proj.<attr>``):
       weight/bias (optionally behind ``linear.``) go to
       ``<pfx>.qkv.<proj>.<weight|bias>``; ``output_*`` attributes become
       per-projection ``<pfx>.qkv.{q,k,v}_output_*``; ``input_*`` attributes
       become the shared ``<pfx>.qkv.input_*``.

    2. **Fused GateUp** (``*.mlp.{gate,up}_proj.<attr>``): same scheme under
       ``<pfx>.gate_up``, with ``gate_``/``up_`` prefixes for ``output_*``.

    3. **Clippable wrapper**: a trailing ``.weight``/``.bias`` is redirected
       to ``.linear.weight``/``.linear.bias`` when that name exists in
       ``params_dict``.

    A name that matches none of the rules is returned unchanged.
    """
    owner = Gemma4ForConditionalGeneration
    leaf_attrs = ("weight", "bias", "linear.weight", "linear.bias")

    # Rule 1: attention projections fused into a single QKV module.
    qkv_match = owner._RE_TOWER_QKV.match(name)
    if qkv_match is not None:
        stem, proj, attr = qkv_match.groups()
        if attr in leaf_attrs:
            leaf = attr.rsplit(".", 1)[-1]
            return f"{stem}.qkv.{proj}.{leaf}"
        if attr.startswith("output_"):
            # proj[0] is the single letter "q", "k" or "v".
            return f"{stem}.qkv.{proj[0]}_{attr}"
        if attr.startswith("input_"):
            return f"{stem}.qkv.{attr}"

    # Rule 2: MLP gate/up projections fused into a single GateUp module.
    gate_up_match = owner._RE_TOWER_GATE_UP.match(name)
    if gate_up_match is not None:
        stem, proj, attr = gate_up_match.groups()
        branch = proj.split("_")[0]  # "gate" or "up"
        if attr in leaf_attrs:
            leaf = attr.rsplit(".", 1)[-1]
            return f"{stem}.gate_up.{proj}.{leaf}"
        if attr.startswith("output_"):
            return f"{stem}.gate_up.{branch}_{attr}"
        if attr.startswith("input_"):
            return f"{stem}.gate_up.{attr}"

    # Rule 3: remaining clippable linears keep their tensors under ``.linear``.
    if name.endswith((".weight", ".bias")):
        parent, leaf = name.rsplit(".", 1)
        wrapped = f"{parent}.linear.{leaf}"
        if wrapped in params_dict:
            return wrapped

    return name
expert_params_mapping = [ + # (param_name, ckpt_weight_name, shard_ids) + # gate_up_proj is fused [E, 2*I, H] — chunk into w1 (gate) + w3 (up) + ("experts.w13_weight", "experts.gate_up_proj", ("w1", "w3")), + ("experts.w2_weight", "experts.down_proj", ("w2",)), + ] + + params_dict = dict(self.named_parameters()) + params_dict.update(dict(self.named_buffers())) + non_persistent_buffers: Set[str] = set() + for mod_name, mod in self.named_modules(): + for buf_name in getattr(mod, "_non_persistent_buffers_set", set()): + full = f"{mod_name}.{buf_name}" if mod_name else buf_name + non_persistent_buffers.add(full) + + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + if "embed_vision.embedding." in name or "embed_audio.embedding." in name: + continue + if self.audio_tower is None and ( + "audio_tower." in name or "embed_audio." in name + ): + continue + + name = re.sub(r"^model\.", "", name) + + # HF has router.per_expert_scale and experts.* on the decoder layer; + # remap into our moe.* subtree since Gemma4MoE owns both. + name = name.replace(".router.per_expert_scale", ".moe.per_expert_scale") + if ".experts." in name and ".moe.experts." not in name: + name = name.replace(".experts.", ".moe.experts.") + + # Remap audio tower checkpoint names to our module tree + if "audio_tower." in name: + name = self._remap_audio_tower_name(name) + + # Remap vision / audio tower names (fused QKV/GateUp, clippable wrappers) + if "vision_tower." in name or "audio_tower." in name: + name = self._remap_tower_name(name, params_dict) + + # attention_k_eq_v: full-attention layers have no v_proj in the + # checkpoint (K and V share weights). When we see a k_proj weight + # for one of these layers, load it into both the "k" and "v" shards + # of the fused QKV so the forward produces v_raw == k_raw. + should_dup_k_to_v = ( + ".k_proj." in name + and k_eq_v_layers + and "language_model." 
in name + and (m := re.search(r"layers\.(\d+)\.", name)) is not None + and int(m.group(1)) in k_eq_v_layers + ) + + # MoE expert weights checked first (gate_up_proj contains "up_proj" + # which would false-match the stacked dense MLP mapping). + orig_name = name + for param_name, weight_name, shard_ids in expert_params_mapping: + name = orig_name + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + for i in range(num_experts): + chunks = loaded_weight[i].chunk(len(shard_ids), dim=0) + for chunk, sid in zip(chunks, shard_ids): + weight_loader(param, chunk, name, sid, i) + break + else: + for param_name, weight_name, shard_id in self.stacked_params_mapping: + name = orig_name + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + if should_dup_k_to_v: + weight_loader(param, loaded_weight, "v") + break + else: + name = orig_name + if name.endswith(".bias") and name not in params_dict: + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + param_names = set(dict(self.named_parameters()).keys()) + buckets = { + logging.WARNING: ( + "Some weights are not initialized from checkpoints", + lambda p: p in param_names, + ), + logging.INFO: ( + "Persistent buffers not in checkpoint (using default init)", + lambda p: p not in param_names and p not in non_persistent_buffers, + ), + logging.DEBUG: ( + "Non-persistent buffers 
def get_hidden_dim(self, module_name, layer_idx):
    """Return ``(input_dim, output_dim)`` for a LoRA-targeted decoder module.

    Args:
        module_name: One of ``"qkv_proj"``, ``"o_proj"``, ``"gate_up_proj"``
            or ``"down_proj"``.
        layer_idx: Accepted for interface compatibility; not used.

    Returns:
        A 2-tuple ``(input_dim, output_dim)``.

    Raises:
        NotImplementedError: for any other ``module_name``.
        AssertionError: if ``intermediate_size`` is a sequence whose entries
            are not all equal.
    """
    # The decoder dimensions belong to the text sub-config; every other method
    # of this class reads them through ``config.text_config``. Fall back to the
    # config itself so a flat text config keeps working.
    text_config = getattr(self.config, "text_config", self.config)

    if module_name == "qkv_proj":
        return (
            text_config.hidden_size,
            text_config.head_dim
            * (
                text_config.num_attention_heads
                + text_config.num_key_value_heads * 2
            ),
        )
    if module_name == "o_proj":
        return (
            text_config.head_dim * text_config.num_attention_heads,
            text_config.hidden_size,
        )
    if module_name in ("gate_up_proj", "down_proj"):
        sizes = text_config.intermediate_size
        if isinstance(sizes, int):
            # A scalar is trivially uniform across layers.
            intermediate_size = sizes
        else:
            assert len(set(sizes)) == 1, (
                "Currently SGLang requires uniform intermediate size for all layers. "
                "Please file an issue if you need support for non-uniform intermediate sizes."
            )
            intermediate_size = sizes[0]
        if module_name == "gate_up_proj":
            # Gate and up projections are fused, hence twice the width.
            return text_config.hidden_size, intermediate_size * 2
        return intermediate_size, text_config.hidden_size
    raise NotImplementedError()
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second half."""
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def _apply_rotary(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply the rotary transform ``x * cos + rotate_half(x) * sin``."""
    return (x * cos) + (_rotate_half(x) * sin)


class Gemma4VisionRotaryEmbedding(nn.Module):
    """Compute multidimensional RoPE cos/sin tables for patch positions."""

    def __init__(self, config: "Gemma4VisionConfig"):
        super().__init__()
        self.head_dim = config.head_dim
        self.rope_theta: float = config.rope_parameters["rope_theta"]

    @torch.no_grad()
    def forward(
        self, x: torch.Tensor, patch_positions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the cos/sin tables for every patch.

        Args:
            x: Tensor used only for its device and dtype.
            patch_positions: [batch, num_patches, ndim] patch coordinates.

        Returns:
            ``(cos, sin)``, each of shape [batch, num_patches, head_dim] when
            ``head_dim`` is divisible by ``2 * ndim``, cast to ``x.dtype``.
        """
        batch, _, ndim = patch_positions.shape
        head_dim_per_dim = self.head_dim // ndim

        # The inverse frequencies depend only on head_dim and rope_theta, so
        # they are built once and shared by every coordinate axis.
        exponents = (
            torch.arange(0, head_dim_per_dim, 2, device=x.device, dtype=torch.float)
            / head_dim_per_dim
        )
        inv_freq = (1.0 / (self.rope_theta**exponents))[None, :, None].expand(
            batch, -1, 1
        )

        per_axis = []
        for axis in range(ndim):
            positions = patch_positions[:, :, axis].float()[:, None, :]
            # [batch, freqs, 1] @ [batch, 1, patches] -> outer product per batch.
            freqs = (inv_freq @ positions).transpose(1, 2)
            per_axis.append(torch.cat((freqs, freqs), dim=-1))

        emb = torch.cat(per_axis, dim=-1)
        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)


def _apply_multidimensional_rope(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    ndim: int = 2,
) -> torch.Tensor:
    """Apply multidimensional RoPE along the last dimension of ``x``.

    The last dimension is split into ``ndim`` equal chunks and the standard
    rotary transform is applied to each chunk independently.

    Args:
        x: Tensor whose last dimension is ``head_dim``.
        cos: Cosine table, broadcastable to ``x``, last dimension ``head_dim``.
        sin: Sine table, broadcastable to ``x``, last dimension ``head_dim``.
        ndim: Number of positional axes (defaults to 2 for image patches).

    Raises:
        ValueError: If ``ndim`` is not positive or does not divide ``head_dim``.
    """
    head_dim = x.shape[-1]
    if ndim <= 0 or head_dim % ndim != 0:
        raise ValueError(
            f"head_dim ({head_dim}) must be divisible by ndim ({ndim})"
        )
    chunk_size = head_dim // ndim
    rotated = [
        _apply_rotary(x_part, cos_part, sin_part)
        for x_part, cos_part, sin_part in zip(
            x.split(chunk_size, dim=-1),
            cos.split(chunk_size, dim=-1),
            sin.split(chunk_size, dim=-1),
        )
    ]
    return torch.cat(rotated, dim=-1)
with_scale=False + ) + + backend = self._select_backend() + self.qkv_backend = QKV_BACKEND_IMPL[backend]( + head_dim=config.head_dim, + num_heads=self.num_heads_per_partition, + num_kv_heads=self.num_kv_heads_per_partition, + dropout=0.0, + flatten_batch=True, + softmax_in_single_precision=False, + softmax_scale=1.0, + ) + + @staticmethod + def _select_backend() -> str: + """Mirror VisionAttention._determine_attention_backend for consistency.""" + from sglang.srt.server_args import get_global_server_args + + override = get_global_server_args().mm_attention_backend + if override is not None: + return override + if is_cuda(): + major, _ = get_device_capability() + if major == 9: + from sglang.srt.utils import is_blackwell_supported + + if is_blackwell_supported(): + return "triton_attn" + return "fa3" + return "triton_attn" + if is_hip(): + # ROCm: use triton_attn to avoid SDPA flatten_batch issues + # with multi-image/video inputs + return "triton_attn" + return "sdpa" + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + bsz, seq_len, _ = hidden_states.shape + + q, k, v = self.qkv(hidden_states) + + q = q.reshape(bsz * seq_len, self.num_heads_per_partition, self.head_dim) + k = k.reshape(bsz * seq_len, self.num_kv_heads_per_partition, self.head_dim) + v = v.reshape(bsz * seq_len, self.num_kv_heads_per_partition, self.head_dim) + + q = self.q_norm(q.reshape(-1, self.head_dim)).reshape(q.shape) + k = self.k_norm(k.reshape(-1, self.head_dim)).reshape(k.shape) + v = self.v_norm(v.reshape(-1, self.head_dim)).reshape(v.shape) + + cos_flat = cos.reshape(bsz * seq_len, 1, self.head_dim) + sin_flat = sin.reshape(bsz * seq_len, 1, self.head_dim) + q = _apply_multidimensional_rope(q, cos_flat, sin_flat) + k = _apply_multidimensional_rope(k, cos_flat, sin_flat) + + if attention_mask is not None: + attn_mask_4d = ( + attention_mask.unsqueeze(-1) * 
# ---------------------------------------------------------------------------
# Vision MLP (GatedGELU, TP-sharded)
# ---------------------------------------------------------------------------


class Gemma4VisionMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(gelu_tanh(gate) * up)``."""

    def __init__(
        self,
        config: Gemma4VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # Only the tanh-approximated GELU is implemented in forward().
        if config.hidden_activation != "gelu_pytorch_tanh":
            raise ValueError(
                f"Gemma4VisionMLP expects hidden_activation='gelu_pytorch_tanh', "
                f"got {config.hidden_activation!r}"
            )
        hidden = config.hidden_size
        intermediate = config.intermediate_size
        self.gate_up = ClippableGateUpParallelLinear(
            input_size=hidden,
            intermediate_size=intermediate,
            bias=False,
            quant_config=quant_config,
            prefix=prefix,
        )
        self.down_proj = ClippableRowParallelLinear(
            input_size=intermediate,
            output_size=hidden,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the gated MLP on ``x`` and return the projected result."""
        gate, up = self.gate_up(x)
        gated = F.gelu(gate, approximate="tanh") * up
        return self.down_proj(gated)


# ---------------------------------------------------------------------------
# Encoder Layer
# ---------------------------------------------------------------------------


class Gemma4VisionEncoderLayer(nn.Module):
    """One encoder layer: attention and MLP sub-blocks, each normed before and after."""

    def __init__(
        self,
        config: Gemma4VisionConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.self_attn = Gemma4VisionAttention(
            config,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.mlp = Gemma4VisionMLP(
            config,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

        norm_kwargs = {"eps": config.rms_norm_eps}
        width = config.hidden_size
        self.input_layernorm = Gemma4RMSNorm(width, **norm_kwargs)
        self.post_attention_layernorm = Gemma4RMSNorm(width, **norm_kwargs)
        self.pre_feedforward_layernorm = Gemma4RMSNorm(width, **norm_kwargs)
        self.post_feedforward_layernorm = Gemma4RMSNorm(width, **norm_kwargs)

        # Scalar multiplier applied to the layer output; initialised to 1.
        self.register_buffer("layer_scalar", torch.ones(()))

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the attention and MLP residual branches, then the layer scalar."""
        attn_out = self.self_attn(
            self.input_layernorm(hidden_states), cos, sin, attention_mask
        )
        hidden_states = hidden_states + self.post_attention_layernorm(attn_out)

        mlp_out = self.mlp(self.pre_feedforward_layernorm(hidden_states))
        hidden_states = hidden_states + self.post_feedforward_layernorm(mlp_out)

        return hidden_states * self.layer_scalar


# ---------------------------------------------------------------------------
# Vision Transformer (stack of encoder layers + RoPE)
# ---------------------------------------------------------------------------


class Gemma4VisionTransformer(nn.Module):
    """Stack of ``Gemma4VisionEncoderLayer`` sharing one set of RoPE tables."""

    def __init__(
        self,
        config: Gemma4VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.rotary_emb = Gemma4VisionRotaryEmbedding(config)
        self.layers = nn.ModuleList(
            Gemma4VisionEncoderLayer(
                config,
                layer_idx=idx,
                quant_config=quant_config,
                prefix=add_prefix(f"layers.{idx}", prefix),
            )
            for idx in range(config.num_hidden_layers)
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor,
        patch_positions: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            inputs_embeds: [batch, seq, hidden_size]
            attention_mask: [batch, seq] — True = valid token
            patch_positions: [batch, seq, 2]
        Returns:
            last_hidden_state: [batch, seq, hidden_size]
        """
        # cos/sin are computed once and reused by every layer.
        cos, sin = self.rotary_emb(inputs_embeds, patch_positions)
        out = inputs_embeds
        for encoder_layer in self.layers:
            out = encoder_layer(out, cos, sin, attention_mask)
        return out
class Gemma4VisionPooler(nn.Module):
    """Average-pool patch embeddings into a fixed number of soft tokens."""

    def __init__(self, config: "Gemma4VisionConfig"):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Pooled outputs are multiplied by sqrt(hidden_size) in forward().
        self.root_hidden_size = self.hidden_size**0.5

    def _avg_pool_by_positions(
        self, x: torch.Tensor, patch_positions: torch.Tensor, length: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Average patches that fall into the same ``k x k`` position window.

        Args:
            x: [batch, seq, hidden] patch embeddings.
            patch_positions: [batch, seq, 2] patch coordinates; negative
                entries are clamped to 0 before use.
            length: Number of pooled tokens; ``seq`` must equal ``k**2 * length``.

        Returns:
            ``(output, mask)`` where ``output`` is [batch, length, hidden] and
            ``mask`` is True for pooled slots that received at least one patch.

        Raises:
            ValueError: If ``seq`` is not ``k**2 * length`` for an integer ``k``.
        """
        input_seq_len = x.shape[1]
        k = int((input_seq_len // length) ** 0.5)
        k_squared = k**2
        if k_squared * length != input_seq_len:
            raise ValueError(
                f"Cannot pool {x.shape} to {length}: {k=}^2 times {length=} must be {input_seq_len}."
            )
        clamped_positions = patch_positions.clamp(min=0)
        # Width of the patch grid, used to flatten 2-D window indices.
        max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1
        kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor")
        kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1]

        # One-hot assignment matrix scaled so the matmul below is a mean.
        weights = F.one_hot(kernel_idxs.long(), length).float() / k_squared
        output = weights.transpose(1, 2).to(x.dtype) @ x
        mask = torch.logical_not((weights == 0).all(dim=1))
        return output, mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        patch_positions: torch.Tensor,
        padding_positions: torch.Tensor,
        output_length: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pool ``hidden_states`` down to ``output_length`` tokens.

        Args:
            hidden_states: [batch, seq, hidden] patch embeddings.
            patch_positions: [batch, seq, 2] patch coordinates.
            padding_positions: [batch, seq] mask, returned unchanged as the
                mask when no pooling is needed.
            output_length: Target token count; a list/tuple is accepted and
                its first element is used.

        Returns:
            (pooled_hidden_states, mask) where mask is True for valid tokens.

        Raises:
            ValueError: If ``output_length`` is missing or exceeds ``seq``.
        """
        if output_length is None:
            raise ValueError("output_length is required for Gemma4VisionPooler")

        # Unwrap a sequence before comparing, otherwise `list > int` raises
        # TypeError and the sequence form could never be used.
        length = output_length
        if isinstance(length, (list, tuple)):
            length = length[0]

        num_patches = hidden_states.shape[1]
        if length > num_patches:
            raise ValueError(
                f"Cannot output more soft tokens (requested {length}) than there are patches"
                f" ({num_patches}). Change the value of `num_soft_tokens` when processing."
            )

        if num_patches == length:
            # NOTE(review): this returns the padding mask as-is (True looks
            # like "padding"), while the pooled branch returns True = valid.
            # Confirm against callers before relying on the polarity here.
            mask = padding_positions
        else:
            hidden_states, mask = self._avg_pool_by_positions(
                hidden_states, patch_positions, length
            )
        hidden_states = hidden_states * self.root_hidden_size
        return hidden_states, mask
+ + Returns: + (hidden_states, pooler_mask) — hidden_states [batch, output_len, hidden], + pooler_mask [batch, output_len] True = valid. + """ + k2 = self.pooling_kernel_size * self.pooling_kernel_size + output_length = pixel_values.shape[-2] // k2 + + padding_positions = (pixel_position_ids == -1).all(dim=-1) + + inputs_embeds = self.patch_embedder( + pixel_values, pixel_position_ids, padding_positions + ) + + last_hidden = self.encoder( + inputs_embeds=inputs_embeds, + attention_mask=~padding_positions, + patch_positions=pixel_position_ids, + ) + + pooled, pooler_mask = self.pooler( + last_hidden, + pixel_position_ids, + padding_positions, + output_length=output_length, + ) + + if self.standardize: + pooled = (pooled - self.std_bias) * self.std_scale + + return pooled, pooler_mask diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 84cfa46f1c77..55659586d4bb 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -577,7 +577,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue if name == "model.embed_tokens.weight": - if self.pp_group.is_last_rank and self.config.tie_word_embeddings: + if ( + not hasattr(self, "pp_group") or self.pp_group.is_last_rank + ) and self.config.tie_word_embeddings: if "lm_head.weight" in params_dict: param = params_dict["lm_head.weight"] weight_loader = getattr( diff --git a/python/sglang/srt/models/qwen3_asr.py b/python/sglang/srt/models/qwen3_asr.py new file mode 100644 index 000000000000..9c86818b6256 --- /dev/null +++ b/python/sglang/srt/models/qwen3_asr.py @@ -0,0 +1,199 @@ +"""Qwen3-ASR model compatible with HuggingFace weights""" + +import logging +from typing import Any, Iterable, List, Optional, Tuple + +import torch +import torch.nn as nn + +from sglang.srt.configs.qwen3_asr import Qwen3ASRConfig +from sglang.srt.configs.qwen3_omni import Qwen3OmniMoeAudioEncoderConfig +from sglang.srt.layers.quantization.base_config import 
class Qwen3ASRForConditionalGeneration(nn.Module):
    """Qwen3-ASR: an audio encoder whose features feed a Qwen3 causal LM."""

    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: "Qwen3ASRConfig",
        quant_config: Optional["QuantizationConfig"] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        thinker_config = config.thinker_config

        # Fall back to a default audio encoder config when none is provided.
        if getattr(thinker_config, "audio_config", None) is None:
            thinker_config.audio_config = Qwen3OmniMoeAudioEncoderConfig()

        self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)
        self.language_model = Qwen3ForCausalLM(
            thinker_config.text_config,
            quant_config,
            prefix=add_prefix("language_model", prefix),
        )
        self.pattern = MultiModalityDataPaddingPatternMultimodalTokens()

    def pad_input_ids(self, input_ids: List[int], mm_inputs: "MultimodalInputs"):
        """Delegate multimodal token padding to the shared padding pattern."""
        return self.pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_audio_feature(self, items: List["MultimodalDataItem"]) -> torch.Tensor:
        """Encode the audio features of ``items`` with the audio tower.

        Features are concatenated along the batch dimension, packed into a
        single (features, frames) matrix and passed to the audio tower
        together with the per-item frame counts.

        Returns:
            ``last_hidden_state`` of the audio tower output.
        """
        device = next(self.audio_tower.parameters()).device

        input_features = (
            torch.cat([item.feature for item in items])
            .type(self.audio_tower.dtype)
            .to(device)
        )
        # Move the frame axis next to the batch axis so frames can be
        # selected (masked) or flattened row by row.
        frames_first = input_features.permute(0, 2, 1)

        has_mask = all(
            getattr(item, "feature_attention_mask", None) is not None for item in items
        )

        if has_mask:
            feature_attention_mask = (
                torch.cat([item.feature_attention_mask for item in items], dim=0)
                .type(torch.long)
                .to(device)
            )
            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
            packed = frames_first[feature_attention_mask.bool()]
        else:
            # No mask: every frame of every item is treated as valid.
            audio_feature_lengths = torch.tensor(
                [input_features.shape[-1]] * input_features.shape[0],
                dtype=torch.long,
                device=device,
            )
            packed = frames_first.reshape(-1, input_features.shape[1])

        # Both branches must hand the tower the same (features, frames)
        # layout; the unmasked path previously skipped this transpose.
        input_features = packed.permute(1, 0)

        audio_outputs = self.audio_tower(
            input_features,
            feature_lens=audio_feature_lengths,
        )
        return audio_outputs.last_hidden_state

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: "ForwardBatch",
        **kwargs: Any,
    ) -> torch.Tensor:
        """Embed text and audio, then run the language model."""
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
            data_embedding_funcs={
                Modality.AUDIO: self.get_audio_feature,
            },
            positions=positions,
        )
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Checkpoint names under ``thinker.*`` are remapped onto
        ``audio_tower`` / ``language_model``; separate q/k/v (and, for the
        language model, gate/up) tensors are loaded into the fused
        parameters. Tensors with no matching parameter are skipped.
        """
        llm_stacked_params = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Audio tower has separate q/k/v in checkpoint → stack into qkv_proj
        audio_stacked_params = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        tie_word_embeddings = getattr(
            self.config.thinker_config.text_config, "tie_word_embeddings", False
        )

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                continue

            # With tied embeddings the checkpoint's lm_head copy is not loaded.
            if tie_word_embeddings and "lm_head.weight" in name:
                continue

            if "talker" in name or "code2wav" in name:
                continue

            if name.startswith("thinker.audio_tower."):
                name = name.replace("thinker.audio_tower.", "audio_tower.", 1)
            elif name.startswith("thinker.lm_head."):
                name = name.replace("thinker.lm_head.", "language_model.lm_head.", 1)
            elif name.startswith("thinker.model."):
                name = name.replace("thinker.model.", "language_model.model.", 1)

            is_audio = "audio_tower" in name

            # Audio tower: remap out_proj → proj for VisionAttention
            if is_audio and "out_proj" in name:
                name = name.replace("out_proj", "proj")

            stacked_params = audio_stacked_params if is_audio else llm_stacked_params

            for param_name, weight_name, shard_id in stacked_params:
                if weight_name not in name:
                    continue
                name_tmp = name.replace(weight_name, param_name)
                if name_tmp not in params_dict:
                    continue
                param = params_dict[name_tmp]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked weight: load it directly if the model has it.
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = Qwen3ASRForConditionalGeneration
"Qwen3OmniMoeProcessor", }: # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107 diff --git a/python/sglang/srt/multimodal/processors/gemma4.py b/python/sglang/srt/multimodal/processors/gemma4.py new file mode 100644 index 000000000000..80bb37061358 --- /dev/null +++ b/python/sglang/srt/multimodal/processors/gemma4.py @@ -0,0 +1,145 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from typing import Dict, List, Optional, Union + +import numpy as np +import torch + +from sglang.srt.managers.multimodal_processor import ( + BaseMultimodalProcessor as SGLangBaseProcessor, +) +from sglang.srt.managers.schedule_batch import Modality, MultimodalProcessorOutput +from sglang.srt.models.gemma4_audio import _SSCP_CONV_STRIDE_SIZES +from sglang.srt.models.gemma4_mm import Gemma4ForConditionalGeneration +from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens +from sglang.srt.utils.video_decoder import VideoDecoderWrapper + + +class Gemma4SGLangProcessor(SGLangBaseProcessor): + """Multimodal processor for Gemma4 supporting image, video, and audio inputs.""" + + models = [Gemma4ForConditionalGeneration] + + def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + super().__init__(hf_config, server_args, _processor, *args, **kwargs) + + self.IM_START_TOKEN_ID = hf_config.boi_token_id + self.IM_END_TOKEN_ID = hf_config.eoi_token_id + + self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id + self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id + self.mm_tokens = MultimodalSpecialTokens( + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + audio_token_id=hf_config.audio_token_id, + ).build(_processor) + + # Register image-processor and video-processor outputs so they are stored on + # MultimodalDataItem via collect_mm_items_from_processor_output. + self.ATTR_NAME_TO_MODALITY["image_position_ids"] = Modality.IMAGE + self.ATTR_NAME_TO_MODALITY["video_position_ids"] = Modality.VIDEO + + def _get_audio_pad_multiple(self) -> int: + """Derive the waveform padding alignment from processor config. + + The HF processor's ceil(duration_ms / audio_ms_per_token) formula can + overshoot by 1 token relative to what the SSCP convolutions produce. 
+ Padding waveforms to a multiple of (hop_length * first_conv_stride) + aligns the two calculations. + See: gemma-4-eap-extras/examples/gemma-4-audio-examples.ipynb + """ + fe = getattr(self._processor, "feature_extractor", None) + hop = getattr(fe, "hop_length", 160) + first_stride = _SSCP_CONV_STRIDE_SIZES[0][0] + return hop * first_stride + + def _video_decoder_to_tensor(self, vdw: VideoDecoderWrapper) -> torch.Tensor: + """Convert a VideoDecoderWrapper to a (sampled_frames, C, H, W) uint8 tensor. + + SGLang's load_video returns VideoDecoderWrapper which the HF + Gemma4VideoProcessor does not recognise (expects torch.Tensor or + np.ndarray). We replicate HF's uniform frame sampling here to + avoid materialising the entire video in memory, then delegate the + rest (resize, patchify, position IDs) to the HF video processor. + """ + total = len(vdw) + num_frames = getattr( + getattr(self._processor, "video_processor", None), + "num_frames", + 32, + ) + if total <= num_frames: + indices = list(range(total)) + else: + indices = torch.arange(0, total, total / num_frames).int().tolist() + frames_np = vdw.get_frames_at(indices) # (N, H, W, C) + return torch.from_numpy(frames_np).permute(0, 3, 1, 2).contiguous() + + def process_mm_data( + self, input_text, images=None, videos=None, audios=None, **kwargs + ): + if audios: + pad_multiple = self._get_audio_pad_multiple() + padded = [] + for a in audios: + a = np.asarray(a) + remainder = len(a) % pad_multiple + if remainder != 0: + a = np.pad(a, (0, pad_multiple - remainder), mode="constant") + padded.append(a) + audios = padded + if videos: + videos = [ + ( + self._video_decoder_to_tensor(v) + if isinstance(v, VideoDecoderWrapper) + else v + ) + for v in videos + ] + kwargs.setdefault("do_sample_frames", False) + return super().process_mm_data( + input_text, images=images, videos=videos, audios=audios, **kwargs + ) + + async def process_mm_data_async( + self, + image_data: Optional[List[Union[str, bytes, Dict]]] = None, + 
# Prompt used when a request carries audio but no usable text.
_DEFAULT_ASR_PROMPT = (
    "<|im_start|>user\n"
    "<|audio_start|><|audio_pad|><|audio_end|>"
    "<|im_end|>\n"
    "<|im_start|>assistant\n"
)


class Qwen3ASRMultimodalProcessor(BaseMultimodalProcessor):
    """Multimodal processor for Qwen3-ASR (audio-only inputs)."""

    models = [Qwen3ASRForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
        self.AUDIO_TOKEN = "<|audio_start|><|audio_pad|><|audio_end|>"
        self.AUDIO_TOKEN_REGEX = re.compile(
            r"<\|audio_start\|>(?:<\|audio_pad\|>)+<\|audio_end\|>"
        )
        tokenizer = self._processor.tokenizer
        self.audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_start|>")
        self.audio_token_id = tokenizer.convert_tokens_to_ids("<|audio_pad|>")
        self.audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_end|>")

        self.mm_tokens = MultimodalSpecialTokens(
            audio_token=self.AUDIO_TOKEN,
            audio_token_regex=self.AUDIO_TOKEN_REGEX,
            audio_token_id=self.audio_token_id,
        ).build(_processor)

        # Route the processor's attention-mask output to audio items.
        self.ATTR_NAME_TO_MODALITY.update({"feature_attention_mask": Modality.AUDIO})

    def _build_transcription_prompt(self, input_text: Union[str, list]) -> str:
        """Return the prompt text, or the default ASR prompt when it is blank.

        A list is treated as token ids and decoded to text first.
        """
        if isinstance(input_text, list):
            # Use the same tokenizer handle as __init__; this class never
            # sets `self._tokenizer` itself.
            input_text = self._processor.tokenizer.decode(input_text)
        if not input_text or not input_text.strip():
            return _DEFAULT_ASR_PROMPT
        return input_text

    def compute_mrope_positions(self, input_ids, mm_items):
        """Build sequential positions replicated across the 3 mrope rows.

        Returns:
            ``(mrope_positions, delta)`` where ``mrope_positions`` has shape
            [3, seq_len] and ``delta`` is a one-element zero tensor.
        """
        if isinstance(input_ids, list):
            seq_len = len(input_ids)
        else:
            seq_len = input_ids.shape[-1] if input_ids.dim() > 1 else input_ids.shape[0]
        positions = torch.arange(seq_len, dtype=torch.long)
        mrope_positions = positions.unsqueeze(0).expand(3, -1).clone()
        return mrope_positions, torch.tensor([0], dtype=torch.long)

    async def process_mm_data_async(
        self,
        audio_data=None,
        input_text=None,
        request_obj=None,
        **kwargs,
    ):
        """Load and tokenize an audio request.

        Returns:
            A ``MultimodalProcessorOutput``, or ``None`` when there is no
            audio or loading yields nothing.
        """
        if not audio_data:
            return None

        prompt = self._build_transcription_prompt(input_text)

        base_output = self.load_mm_data(
            prompt=prompt,
            audio_data=audio_data,
            multimodal_tokens=self.mm_tokens,
        )
        if base_output is None:
            return None

        mm_items, input_ids, _ = self.process_and_combine_mm_data(
            base_output, self.mm_tokens
        )

        mrope_positions, mrope_position_delta = self.compute_mrope_positions(
            input_ids, mm_items
        )

        return MultimodalProcessorOutput(
            mm_items=mm_items,
            input_ids=input_ids.tolist(),
            audio_start_id=self.audio_start_id,
            audio_token_id=self.audio_token_id,
            audio_end_id=self.audio_end_id,
            mrope_positions=mrope_positions,
            mrope_position_delta=mrope_position_delta,
        )
a/python/sglang/srt/parser/reasoning_parser.py b/python/sglang/srt/parser/reasoning_parser.py index c3dbb3116464..8811c90b2ddc 100644 --- a/python/sglang/srt/parser/reasoning_parser.py +++ b/python/sglang/srt/parser/reasoning_parser.py @@ -37,6 +37,7 @@ def __init__( self._buffer = "" self.stripped_think_start = False + self.think_start_self_label = "" self.continue_final_message = continue_final_message if self.continue_final_message: @@ -62,7 +63,9 @@ def detect_and_parse(self, text: str) -> StreamingParseResult: return StreamingParseResult(normal_text=text) # The text is considered to be in a reasoning block. - processed_text = text.replace(self.think_start_token, "").strip() + processed_text = text.replace( + self.think_start_token + self.think_start_self_label, "" + ).strip() if ( self.think_end_token not in processed_text @@ -111,8 +114,10 @@ def parse_streaming_increment(self, new_text: str) -> StreamingParseResult: self._buffer += new_text current_text = self._buffer + think_start_text = self.think_start_token + self.think_start_self_label + # If the current text is a prefix of the think token, keep buffering - tokens_to_check = [self.think_start_token, self.think_end_token] + tokens_to_check = [think_start_text, self.think_end_token] if self.tool_start_token: tokens_to_check.append(self.tool_start_token) if any( @@ -122,8 +127,8 @@ def parse_streaming_increment(self, new_text: str) -> StreamingParseResult: return StreamingParseResult() # Strip `` token if present - if not self.stripped_think_start and self.think_start_token in current_text: - current_text = current_text.replace(self.think_start_token, "") + if not self.stripped_think_start and think_start_text in current_text: + current_text = current_text.replace(think_start_text, "", 1) self.stripped_think_start = True self._in_reasoning = True @@ -477,6 +482,27 @@ def __init__( ) +class Gemma4Detector(BaseReasoningFormatDetector): + """Gemma4 reasoning detector.""" + + def __init__( + self, + 
stream_reasoning: bool = True, + force_reasoning: bool = False, + continue_final_message: bool = False, + previous_content: str = "", + ): + super().__init__( + "<|channel>", + "", + force_reasoning=force_reasoning, + stream_reasoning=stream_reasoning, + continue_final_message=continue_final_message, + previous_content=previous_content, + ) + self.think_start_self_label = "thought\n" + + class ReasoningParser: """ Parser that handles both streaming and non-streaming scenarios for extracting @@ -505,6 +531,7 @@ class ReasoningParser: "mistral": MistralDetector, "nemotron_3": Nemotron3Detector, "interns1": Qwen3Detector, + "gemma4": Gemma4Detector, } def __init__( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f7e0d8ee99f5..74445c9cd1f2 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -26,6 +26,7 @@ import tempfile from typing import Any, Callable, Dict, List, Literal, Optional, Union +from sglang.srt.configs.linear_attn_model_registry import get_linear_attn_spec_by_arch from sglang.srt.connector import ConnectorType from sglang.srt.environ import envs from sglang.srt.function_call.function_call_parser import FunctionCallParser @@ -1501,6 +1502,14 @@ def _handle_model_specific_adjustments(self): hf_config = self.get_model_config().hf_config model_arch = hf_config.architectures[0] + _hybrid_spec = get_linear_attn_spec_by_arch(model_arch) + if _hybrid_spec is not None: + self._handle_mamba_radix_cache( + model_arch=model_arch, + support_mamba_cache=_hybrid_spec.support_mamba_cache, + support_mamba_cache_extra_buffer=_hybrid_spec.support_mamba_cache_extra_buffer, + ) + if model_arch in [ "MistralLarge3ForCausalLM", "PixtralForConditionalGeneration", @@ -1878,6 +1887,10 @@ def _handle_model_specific_adjustments(self): f"Disable hybrid SWA memory for {model_arch} as it is not yet supported." 
) self.disable_hybrid_swa_memory = True + elif model_arch == "Gemma4ForConditionalGeneration": + if self.is_attention_backend_not_set(): + self.attention_backend = "triton" + logger.info("Use triton as default attention backend for Gemma4") elif model_arch in ["Exaone4ForCausalLM", "ExaoneMoEForCausalLM"]: if hf_config.sliding_window_pattern is not None: logger.warning( diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 22a530c7c7a8..b928b08d4d79 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -622,6 +622,32 @@ def get_config( if config.model_type == "multi_modality": config.update({"architectures": ["MultiModalityCausalLM"]}) + if config.model_type == "gemma4": + # Gemma4 configs use base attributes for SWA layers and `global_*` + # variants for full-attention layers. SGLang expects the opposite: + # base = full-attention, `swa_*` = sliding-window overrides. + # Remap here so the rest of the stack sees a uniform convention. 
+ text_config = config.text_config + global_head_dim = getattr(text_config, "global_head_dim", None) + global_kv_heads = getattr(text_config, "num_global_key_value_heads", None) + + swa_head_dim = text_config.head_dim + swa_kv_heads = text_config.num_key_value_heads + + text_config.swa_head_dim = swa_head_dim + text_config.swa_v_head_dim = swa_head_dim + text_config.swa_num_key_value_heads = swa_kv_heads + + if global_head_dim is not None: + text_config.head_dim = global_head_dim + if global_kv_heads is not None: + text_config.num_key_value_heads = global_kv_heads + + if not hasattr(text_config, "v_head_dim"): + text_config.v_head_dim = text_config.head_dim + if not hasattr(text_config, "swa_v_head_dim"): + text_config.swa_v_head_dim = text_config.swa_head_dim + if config.model_type == "longcat_flash": config.update({"architectures": ["LongcatFlashForCausalLM"]}) diff --git a/test/registered/8-gpu-models/test_ring_2_5_1t.py b/test/registered/8-gpu-models/test_ring_2_5_1t.py index 71b2a4f2609e..29c64160cf96 100644 --- a/test/registered/8-gpu-models/test_ring_2_5_1t.py +++ b/test/registered/8-gpu-models/test_ring_2_5_1t.py @@ -5,8 +5,7 @@ from sglang.test.run_combined_tests import run_combined_tests from sglang.test.test_utils import ModelLaunchSettings -# register_cuda_ci(est_time=1000, suite="nightly-8-gpu-common", nightly=True) -register_cuda_ci(est_time=1000, suite="stage-c-test-8-gpu-h200") +register_cuda_ci(est_time=1800, suite="nightly-8-gpu-common", nightly=True) RING_2_5_1T_MODEL_PATH = "inclusionAI/Ring-2.5-1T" @@ -25,6 +24,8 @@ def test_ring_2_5_1t(self): '{"enable_multithread_load": true, "num_threads": 64}', "--watchdog-timeout", "1800", + "--soft-watchdog-timeout", + "1800", ] variants = [ diff --git a/test/registered/amd/accuracy/mi30x/test_qwen35_eval_amd.py b/test/registered/amd/accuracy/mi30x/test_qwen35_eval_amd.py index dae0e31c10f7..112630ed474c 100644 --- a/test/registered/amd/accuracy/mi30x/test_qwen35_eval_amd.py +++ 
b/test/registered/amd/accuracy/mi30x/test_qwen35_eval_amd.py @@ -8,6 +8,10 @@ import os import unittest +from pathlib import Path + +import numpy as np +import yaml from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_amd_ci @@ -15,7 +19,9 @@ from sglang.test.test_utils import ( DEFAULT_URL_FOR_TEST, CustomTestCase, + is_in_ci, popen_launch_server, + write_github_step_summary, ) register_amd_ci(est_time=3600, suite="nightly-amd-accuracy-8-gpu-qwen35", nightly=True) @@ -38,7 +44,7 @@ def setUpClass(cls): "--tp", str(TP_SIZE), "--attention-backend", - "triton", + "aiter", "--trust-remote-code", "--model-loader-extra-config", '{"enable_multithread_load": true}', @@ -59,6 +65,41 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) + def test_lm_eval(self): + """Override to write accuracy results to GitHub step summary.""" + import requests + + requests.get(self.base_url + "/flush_cache") + + eval_config = yaml.safe_load( + Path(self.model_config_name).read_text(encoding="utf-8") + ) + results = self.launch_lm_eval(eval_config) + rtol = eval_config.get("rtol", self.default_rtol) + model_name = eval_config.get("model_name", self.model) + + success = True + summary = f"### lm-eval accuracy ({model_name})\n" + summary += "| task | metric | expected | measured | status |\n" + summary += "| ---- | ------ | -------- | -------- | ------ |\n" + for task in eval_config["tasks"]: + for metric in task["metrics"]: + expected = metric["value"] + measured = results["results"][task["name"]][metric["name"]] + passed = bool(np.isclose(expected, measured, rtol=rtol)) + status = "✅" if passed else "❌" + summary += f"| {task['name']} | {metric['name']} | {expected:.4f} | {measured:.4f} | {status} |\n" + print( + f"{task['name']} | {metric['name']}: " + f"expected={expected:.3f} | measured={measured:.3f} | rtol={rtol}" + ) + success = success and passed + + if is_in_ci(): + write_github_step_summary(summary) + + 
self.assertTrue(success, "lm-eval validation failed") + if __name__ == "__main__": unittest.main() diff --git a/test/registered/amd/accuracy/mi35x/test_qwen35_eval_mi35x.py b/test/registered/amd/accuracy/mi35x/test_qwen35_eval_mi35x.py index 2c6b8059bfa8..4b35a28d4405 100644 --- a/test/registered/amd/accuracy/mi35x/test_qwen35_eval_mi35x.py +++ b/test/registered/amd/accuracy/mi35x/test_qwen35_eval_mi35x.py @@ -8,8 +8,11 @@ import os import unittest +from pathlib import Path +import numpy as np import requests +import yaml from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_amd_ci @@ -17,7 +20,9 @@ from sglang.test.test_utils import ( DEFAULT_URL_FOR_TEST, CustomTestCase, + is_in_ci, popen_launch_server, + write_github_step_summary, ) register_amd_ci( @@ -40,12 +45,12 @@ def setUpClass(cls): cls.base_url = DEFAULT_URL_FOR_TEST def test_lm_eval(self): - """Override to handle server lifecycle within test method (MI35x pattern).""" + """Override to handle server lifecycle and write results to summary.""" other_args = [ "--tp", str(TP_SIZE), "--attention-backend", - "triton", + "aiter", "--trust-remote-code", "--model-loader-extra-config", '{"enable_multithread_load": true}', @@ -65,7 +70,35 @@ def test_lm_eval(self): try: requests.get(self.base_url + "/flush_cache") - super().test_lm_eval() + + eval_config = yaml.safe_load( + Path(self.model_config_name).read_text(encoding="utf-8") + ) + results = self.launch_lm_eval(eval_config) + rtol = eval_config.get("rtol", self.default_rtol) + model_name = eval_config.get("model_name", self.model) + + success = True + summary = f"### lm-eval accuracy ({model_name})\n" + summary += "| task | metric | expected | measured | status |\n" + summary += "| ---- | ------ | -------- | -------- | ------ |\n" + for task in eval_config["tasks"]: + for metric in task["metrics"]: + expected = metric["value"] + measured = results["results"][task["name"]][metric["name"]] + passed = 
bool(np.isclose(expected, measured, rtol=rtol)) + status = "✅" if passed else "❌" + summary += f"| {task['name']} | {metric['name']} | {expected:.4f} | {measured:.4f} | {status} |\n" + print( + f"{task['name']} | {metric['name']}: " + f"expected={expected:.3f} | measured={measured:.3f} | rtol={rtol}" + ) + success = success and passed + + if is_in_ci(): + write_github_step_summary(summary) + + self.assertTrue(success, "lm-eval validation failed") finally: kill_process_tree(process.pid) diff --git a/test/registered/amd/perf/mi30x/test_qwen35_fp8_perf_amd.py b/test/registered/amd/perf/mi30x/test_qwen35_fp8_perf_amd.py new file mode 100644 index 000000000000..be5314a6438a --- /dev/null +++ b/test/registered/amd/perf/mi30x/test_qwen35_fp8_perf_amd.py @@ -0,0 +1,139 @@ +"""Nightly performance benchmark for Qwen3.5-397B-A17B FP8. + +Tests Qwen3.5-397B-A17B-FP8 (MoE, Hybrid Attention with Gated Delta Networks) +on 8 GPUs with triton attention backend. + +Model path can be configured via environment variable: +- QWEN35_FP8_MODEL_PATH: Path to Qwen3.5-FP8 model + (default: Qwen/Qwen3.5-397B-A17B-FP8) + +Example usage: + python -m pytest test_qwen35_fp8_perf_amd.py -v +""" + +import os +import unittest +from typing import List + +from sglang.test.ci.ci_register import register_amd_ci +from sglang.test.nightly_bench_utils import BenchmarkResult +from sglang.test.nightly_utils import NightlyBenchmarkRunner +from sglang.test.test_utils import DEFAULT_URL_FOR_TEST, _parse_int_list_env + +register_amd_ci(est_time=5400, suite="nightly-perf-8-gpu-qwen35-fp8", nightly=True) + + +def generate_simple_markdown_report(results: List[BenchmarkResult]) -> str: + """Generate a simplified markdown report without traces and cost columns. + + Skips the first result if it's a warmup run (duplicate batch_size). 
+ """ + model_header = results[0].model_path + if results[0].run_name and results[0].run_name != "default": + model_header += f" ({results[0].run_name})" + + gpu_config = os.getenv("GPU_CONFIG", "MI325") + if gpu_config: + model_header += f" [{gpu_config}]" + + summary = f"### {model_header}\n" + summary += "| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | ITL (ms) |\n" + summary += "| ---------- | --------- | ----------- | ------------------------ | ------------------------- | -------- |\n" + + report_results = ( + results[1:] + if len(results) > 1 and results[0].batch_size == results[1].batch_size + else results + ) + + for result in report_results: + itl = 1 / (result.output_throughput / result.batch_size) * 1000 + summary += f"| {result.batch_size} | {result.input_len} | {result.latency:.2f} | {result.input_throughput:.2f} | {result.output_throughput:.2f} | {itl:.2f} |\n" + + return summary + + +QWEN35_FP8_MODEL_PATH = os.environ.get( + "QWEN35_FP8_MODEL_PATH", "Qwen/Qwen3.5-397B-A17B-FP8" +) +PROFILE_DIR = "performance_profiles_qwen35_fp8" + + +class TestNightlyQwen35Fp8Performance(unittest.TestCase): + """Nightly performance benchmark for Qwen3.5-397B-A17B FP8. + + Tests Qwen3.5 FP8 with triton attention backend on TP=8. 
+ Runtime: ~90 minutes + """ + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.batch_sizes = [1, 8, 16, 64] + cls.input_lens = tuple(_parse_int_list_env("NIGHTLY_INPUT_LENS", "4096")) + cls.output_lens = tuple(_parse_int_list_env("NIGHTLY_OUTPUT_LENS", "512")) + + cls.model_config = { + "name": "qwen35-fp8", + "model_path": QWEN35_FP8_MODEL_PATH, + "other_args": [ + "--trust-remote-code", + "--tp", + "8", + "--attention-backend", + "aiter", + "--mem-fraction-static", + "0.8", + "--model-loader-extra-config", + '{"enable_multithread_load": true}', + "--watchdog-timeout", + "1200", + ], + "env_vars": { + "SGLANG_USE_AITER": "1", + }, + } + + cls.runner = NightlyBenchmarkRunner(PROFILE_DIR, cls.__name__, cls.base_url) + cls.runner.setup_profile_directory() + cls.runner.full_report = f"## {cls.__name__}\n" + + def test_bench_qwen35_fp8(self): + """Run benchmark for Qwen3.5-397B-A17B FP8.""" + old_env = {} + for key, value in self.model_config.get("env_vars", {}).items(): + old_env[key] = os.environ.get(key) + os.environ[key] = value + + try: + result_tuple = self.runner.run_benchmark_for_model( + model_path=self.model_config["model_path"], + batch_sizes=self.batch_sizes, + input_lens=self.input_lens, + output_lens=self.output_lens, + other_args=self.model_config["other_args"], + variant=self.model_config["name"], + extra_bench_args=["--trust-remote-code"], + enable_profile=False, + timeout=5400, + ) + results = result_tuple[0] + success = result_tuple[1] + + if results: + self.runner.full_report += ( + generate_simple_markdown_report(results) + "\n" + ) + + self.assertTrue(success, f"Benchmark failed for {QWEN35_FP8_MODEL_PATH}") + finally: + for key, value in old_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + self.runner.write_final_report() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/amd/perf/mi35x/test_qwen35_fp8_perf_mi35x.py 
b/test/registered/amd/perf/mi35x/test_qwen35_fp8_perf_mi35x.py new file mode 100644 index 000000000000..6446eb601e84 --- /dev/null +++ b/test/registered/amd/perf/mi35x/test_qwen35_fp8_perf_mi35x.py @@ -0,0 +1,139 @@ +"""MI35x Nightly performance benchmark for Qwen3.5-397B-A17B FP8. + +Tests Qwen3.5-397B-A17B-FP8 (MoE, Hybrid Attention with Gated Delta Networks) +on 8 GPUs with triton attention backend. + +Registry: nightly-perf-8-gpu-mi35x-qwen35-fp8 suite +""" + +import os + +os.environ.setdefault("HF_HOME", "/data2/models/huggingface") +os.environ.setdefault("HF_HUB_CACHE", "/data2/models/huggingface/hub") + +import unittest +from typing import List + +from sglang.test.ci.ci_register import register_amd_ci +from sglang.test.nightly_bench_utils import BenchmarkResult +from sglang.test.nightly_utils import NightlyBenchmarkRunner +from sglang.test.test_utils import DEFAULT_URL_FOR_TEST, _parse_int_list_env + +register_amd_ci( + est_time=5400, suite="nightly-perf-8-gpu-mi35x-qwen35-fp8", nightly=True +) + + +def generate_simple_markdown_report(results: List[BenchmarkResult]) -> str: + """Generate a simplified markdown report without traces and cost columns. + + Skips the first result if it's a warmup run (duplicate batch_size). 
+ """ + model_header = results[0].model_path + if results[0].run_name and results[0].run_name != "default": + model_header += f" ({results[0].run_name})" + + gpu_config = os.getenv("GPU_CONFIG", "MI35x") + if gpu_config: + model_header += f" [{gpu_config}]" + + summary = f"### {model_header}\n" + summary += "| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | ITL (ms) |\n" + summary += "| ---------- | --------- | ----------- | ------------------------ | ------------------------- | -------- |\n" + + report_results = ( + results[1:] + if len(results) > 1 and results[0].batch_size == results[1].batch_size + else results + ) + + for result in report_results: + itl = 1 / (result.output_throughput / result.batch_size) * 1000 + summary += f"| {result.batch_size} | {result.input_len} | {result.latency:.2f} | {result.input_throughput:.2f} | {result.output_throughput:.2f} | {itl:.2f} |\n" + + return summary + + +QWEN35_FP8_MODEL_PATH = os.environ.get( + "QWEN35_FP8_MODEL_PATH", "Qwen/Qwen3.5-397B-A17B-FP8" +) +PROFILE_DIR = "performance_profiles_qwen35_fp8_mi35x" + + +class TestQwen35Fp8PerfMI35x(unittest.TestCase): + """Test suite for Qwen3.5-397B-A17B FP8 performance benchmarks on MI35x.""" + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.batch_sizes = [1, 8, 16, 64] + cls.input_lens = tuple(_parse_int_list_env("NIGHTLY_INPUT_LENS", "4096")) + cls.output_lens = tuple(_parse_int_list_env("NIGHTLY_OUTPUT_LENS", "512")) + + cls.model_config = { + "name": "qwen35-fp8-mi35x", + "model_path": QWEN35_FP8_MODEL_PATH, + "other_args": [ + "--trust-remote-code", + "--tp", + "8", + "--attention-backend", + "aiter", + "--mem-fraction-static", + "0.8", + "--model-loader-extra-config", + '{"enable_multithread_load": true}', + "--watchdog-timeout", + "1200", + ], + "env_vars": { + "SGLANG_USE_AITER": "1", + }, + } + + cls.runner = NightlyBenchmarkRunner(PROFILE_DIR, cls.__name__, cls.base_url) + 
cls.runner.setup_profile_directory() + cls.runner.full_report = f"## {cls.__name__}\n" + + def test_qwen35_fp8_perf(self): + """Run Qwen3.5-397B-A17B FP8 performance benchmark on MI35x.""" + old_env = {} + for key, value in self.model_config.get("env_vars", {}).items(): + old_env[key] = os.environ.get(key) + os.environ[key] = value + + try: + result_tuple = self.runner.run_benchmark_for_model( + model_path=self.model_config["model_path"], + batch_sizes=self.batch_sizes, + input_lens=self.input_lens, + output_lens=self.output_lens, + other_args=self.model_config["other_args"], + variant=self.model_config["name"], + extra_bench_args=["--trust-remote-code"], + enable_profile=False, + timeout=5400, + ) + results = result_tuple[0] + success = result_tuple[1] + + if results: + self.runner.full_report += ( + generate_simple_markdown_report(results) + "\n" + ) + + self.assertTrue( + success, + f"Benchmark failed for {QWEN35_FP8_MODEL_PATH} on MI35x", + ) + finally: + for key, value in old_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + self.runner.write_final_report() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/ascend/basic_function/parameter/test_npu_retract_decode.py b/test/registered/ascend/basic_function/parameter/test_npu_retract_decode.py index 60c7231ff897..d6f7b56ee0da 100644 --- a/test/registered/ascend/basic_function/parameter/test_npu_retract_decode.py +++ b/test/registered/ascend/basic_function/parameter/test_npu_retract_decode.py @@ -54,7 +54,7 @@ def test_mmlu(self): base_url=self.base_url, model=self.model, eval_name="mmlu", - num_examples=64, + num_examples=256, num_threads=32, ) diff --git a/test/registered/ascend/basic_function/parameter/test_npu_start_profile.py b/test/registered/ascend/basic_function/parameter/test_npu_start_profile.py index dc6addba3945..db3885f5c0cb 100644 --- a/test/registered/ascend/basic_function/parameter/test_npu_start_profile.py +++ 
b/test/registered/ascend/basic_function/parameter/test_npu_start_profile.py @@ -34,6 +34,7 @@ class TestStartProfile(CustomTestCase): @classmethod def setUpClass(cls): + os.makedirs(OUTPUT_DIR, exist_ok=True) envs.SGLANG_TORCH_PROFILER_DIR.set(OUTPUT_DIR) cls.model = LLAMA_3_2_1B_INSTRUCT_WEIGHTS_PATH cls.base_url = DEFAULT_URL_FOR_TEST diff --git a/test/registered/unit/configs/test_linear_attn_model_registry.py b/test/registered/unit/configs/test_linear_attn_model_registry.py new file mode 100644 index 000000000000..6fbd27d0ae7e --- /dev/null +++ b/test/registered/unit/configs/test_linear_attn_model_registry.py @@ -0,0 +1,161 @@ +"""Unit tests for srt/configs/linear_attn_model_registry.py""" + +import unittest + +from sglang.srt.configs.linear_attn_model_registry import ( + _LINEAR_ATTN_MODEL_REGISTRY, + LinearAttnModelSpec, + get_linear_attn_config, + get_linear_attn_spec_by_arch, + import_backend_class, + register_linear_attn_model, +) +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=5, suite="stage-a-test-cpu") + + +# Dummy config classes for testing +class FakeLinearAttnConfig: + full_attention_layer_ids = [0, 2, 4] + + +class FakeVLMWrapperConfig: + """Simulates a VLM wrapper that has get_text_config().""" + + def __init__(self): + self._text_config = FakeLinearAttnConfig() + + def get_text_config(self): + return self._text_config + + +class AnotherConfig: + pass + + +class TestLinearAttnModelRegistry(CustomTestCase): + def setUp(self): + # Save and clear the global registry between tests + self._saved_registry = list(_LINEAR_ATTN_MODEL_REGISTRY) + _LINEAR_ATTN_MODEL_REGISTRY.clear() + + def tearDown(self): + _LINEAR_ATTN_MODEL_REGISTRY.clear() + _LINEAR_ATTN_MODEL_REGISTRY.extend(self._saved_registry) + + def _make_spec(self, **overrides): + defaults = dict( + config_class=FakeLinearAttnConfig, + 
backend_class_name="sglang.srt.layers.attention.triton_backend.TritonAttnBackend", + arch_names=["FakeModelForCausalLM"], + ) + defaults.update(overrides) + return LinearAttnModelSpec(**defaults) + + def test_register_and_lookup_by_config(self): + spec = self._make_spec() + register_linear_attn_model(spec) + + hf_config = FakeLinearAttnConfig() + result = get_linear_attn_config(hf_config) + self.assertIsNotNone(result) + self.assertIs(result[0], spec) + self.assertIs(result[1], hf_config) + + def test_lookup_no_match(self): + spec = self._make_spec() + register_linear_attn_model(spec) + + result = get_linear_attn_config(AnotherConfig()) + self.assertIsNone(result) + + def test_lookup_empty_registry(self): + result = get_linear_attn_config(FakeLinearAttnConfig()) + self.assertIsNone(result) + + def test_unwrap_text_config(self): + spec = self._make_spec(unwrap_text_config=True) + register_linear_attn_model(spec) + + vlm_config = FakeVLMWrapperConfig() + result = get_linear_attn_config(vlm_config) + self.assertIsNotNone(result) + self.assertIs(result[0], spec) + # The resolved config should be the inner text config + self.assertIsInstance(result[1], FakeLinearAttnConfig) + self.assertIs(result[1], vlm_config._text_config) + + def test_unwrap_text_config_no_match(self): + """unwrap_text_config=False should not call get_text_config().""" + spec = self._make_spec(unwrap_text_config=False) + register_linear_attn_model(spec) + + vlm_config = FakeVLMWrapperConfig() + # VLM wrapper itself is not a FakeLinearAttnConfig, so no match + result = get_linear_attn_config(vlm_config) + self.assertIsNone(result) + + def test_lookup_by_arch(self): + spec = self._make_spec(arch_names=["AlphaForCausalLM", "BetaForCausalLM"]) + register_linear_attn_model(spec) + + self.assertIs(get_linear_attn_spec_by_arch("AlphaForCausalLM"), spec) + self.assertIs(get_linear_attn_spec_by_arch("BetaForCausalLM"), spec) + self.assertIsNone(get_linear_attn_spec_by_arch("GammaForCausalLM")) + + def 
test_lookup_by_arch_empty_registry(self): + self.assertIsNone(get_linear_attn_spec_by_arch("AnyArch")) + + def test_multiple_registrations(self): + spec_a = self._make_spec( + config_class=FakeLinearAttnConfig, + arch_names=["AlphaForCausalLM"], + ) + spec_b = self._make_spec( + config_class=AnotherConfig, + arch_names=["BetaForCausalLM"], + ) + register_linear_attn_model(spec_a) + register_linear_attn_model(spec_b) + + # Config-based lookup + self.assertIs(get_linear_attn_config(FakeLinearAttnConfig())[0], spec_a) + self.assertIs(get_linear_attn_config(AnotherConfig())[0], spec_b) + + # Arch-based lookup + self.assertIs(get_linear_attn_spec_by_arch("AlphaForCausalLM"), spec_a) + self.assertIs(get_linear_attn_spec_by_arch("BetaForCausalLM"), spec_b) + + def test_first_match_wins(self): + """When two specs match the same config class, the first registered wins.""" + spec1 = self._make_spec(backend_class_name="pkg.Backend1") + spec2 = self._make_spec(backend_class_name="pkg.Backend2") + register_linear_attn_model(spec1) + register_linear_attn_model(spec2) + + result = get_linear_attn_config(FakeLinearAttnConfig()) + self.assertIs(result[0], spec1) + + def test_import_backend_class(self): + # Import a real stdlib class to verify the mechanism + cls = import_backend_class("collections.OrderedDict") + from collections import OrderedDict + + self.assertIs(cls, OrderedDict) + + def test_spec_defaults(self): + spec = LinearAttnModelSpec( + config_class=FakeLinearAttnConfig, + backend_class_name="pkg.mod.Cls", + ) + self.assertEqual(spec.arch_names, []) + self.assertTrue(spec.uses_mamba_radix_cache) + self.assertTrue(spec.support_mamba_cache) + self.assertFalse(spec.support_mamba_cache_extra_buffer) + self.assertFalse(spec.unwrap_text_config) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/unit/function_call/test_function_call_parser.py b/test/registered/unit/function_call/test_function_call_parser.py index c418b0866d0e..01aa99904072 100644 
--- a/test/registered/unit/function_call/test_function_call_parser.py +++ b/test/registered/unit/function_call/test_function_call_parser.py @@ -6,6 +6,12 @@ from sglang.srt.function_call.core_types import StreamingParseResult from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector +from sglang.srt.function_call.gemma4_detector import ( + Gemma4Detector, + _parse_gemma4_args, + _parse_gemma4_array, + _parse_gemma4_value, +) from sglang.srt.function_call.gigachat3_detector import GigaChat3Detector from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector @@ -4008,5 +4014,340 @@ def test_streaming_multiple_tool_calls_char_by_char_separator(self): self.assertEqual(cities, ["NYC", "LA"]) +class TestGemma4Detector(unittest.TestCase): + def setUp(self): + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + ), + ) + ] + self.detector = Gemma4Detector() + + def test_detect_and_parse(self): + text = 'Some text before <|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.normal_text, "Some text before ") + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "Tokyo") + + def test_parse_streaming_increment(self): + chunks = [ + "Some text ", + "before <|tool", + "_call>call:get_we", + "ather{location:<|", # codespell:ignore + '"|>Tokyo<|"|>} after", + ] + + all_results = [] + for chunk in chunks: + res = 
self.detector.parse_streaming_increment(chunk, self.tools) + all_results.append(res) + + combined_normal_text = "".join(r.normal_text for r in all_results) + self.assertEqual(combined_normal_text, "Some text before after") + + found_name = False + found_params = False + for res in all_results: + for call in res.calls: + if call.name == "get_weather": + found_name = True + if call.parameters: + params = json.loads(call.parameters) + if params == {"location": "Tokyo"}: + found_params = True + + self.assertTrue(found_name) + self.assertTrue(found_params) + + def test_nested_array_streaming(self): + # Additional coverage for complex structure + chunks = [ + '<|tool_call>call:get_weather{location:<|"', + '|>New York<|"|>,nested:[1, 2, {inner:<|"|>', + 'val<|"|>}]}', + ] + + all_results = [] + for chunk in chunks: + res = self.detector.parse_streaming_increment(chunk, self.tools) + all_results.append(res) + + found_params = False + for res in all_results: + for call in res.calls: + if call.parameters: + params = json.loads(call.parameters) + if "location" in params and params["location"] == "New York": + if "nested" in params and params["nested"] == [ + 1, + 2, + {"inner": "val"}, + ]: + found_params = True + + self.assertTrue(found_params) + + def test_has_tool_call(self): + self.assertTrue( + self.detector.has_tool_call( + '<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' + ) + ) + self.assertFalse(self.detector.has_tool_call("no tool call here")) + + def test_detect_and_parse_no_tool_call(self): + text = "This is plain text without any tool calls." 
+ result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(result.normal_text, text) + self.assertEqual(len(result.calls), 0) + + def test_detect_and_parse_tool_index(self): + text = '<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].tool_index, 0) + self.assertEqual(result.calls[0].name, "get_weather") + + def test_detect_and_parse_unknown_tool_index(self): + text = '<|tool_call>call:unknown_func{arg:<|"|>val<|"|>}' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].tool_index, -1) + + def test_detect_and_parse_nested_object(self): + text = '<|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>,details:{temp:25,unit:<|"|>celsius<|"|>}}' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + params = json.loads(result.calls[0].parameters) + self.assertEqual(params["location"], "Tokyo") + self.assertIsInstance(params["details"], dict) + self.assertEqual(params["details"]["temp"], 25) + self.assertEqual(params["details"]["unit"], "celsius") + + def test_detect_and_parse_multiple_calls(self): + extra_tools = self.tools + [ + Tool( + type="function", + function=Function( + name="get_time", + description="Get current time", + parameters={ + "type": "object", + "properties": {"timezone": {"type": "string"}}, + }, + ), + ) + ] + text = ( + 'Some text <|tool_call>call:get_weather{location:<|"|>Tokyo<|"|>}' + ' more text <|tool_call>call:get_time{timezone:<|"|>UTC<|"|>}' + ) + result = self.detector.detect_and_parse(text, extra_tools) + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[1].name, "get_time") + self.assertEqual(result.normal_text, "Some text ") + + def test_parse_gemma4_args_empty(self): + 
self.assertEqual(_parse_gemma4_args(""), {}) + self.assertEqual(_parse_gemma4_args(" "), {}) + + def test_parse_gemma4_args_booleans(self): + result = _parse_gemma4_args("flag:true,other:false") + self.assertIs(result["flag"], True) + self.assertIs(result["other"], False) + + def test_parse_gemma4_args_numbers(self): + result = _parse_gemma4_args("count:42,ratio:3.14") + self.assertEqual(result["count"], 42) + self.assertAlmostEqual(result["ratio"], 3.14) + + def test_parse_gemma4_args_string_with_colon(self): + result = _parse_gemma4_args('url:<|"|>http://example.com<|"|>') + self.assertEqual(result["url"], "http://example.com") + + def test_parse_gemma4_args_nested_object(self): + result = _parse_gemma4_args('outer:{inner:<|"|>val<|"|>,num:5}') + self.assertIsInstance(result["outer"], dict) + self.assertEqual(result["outer"]["inner"], "val") + self.assertEqual(result["outer"]["num"], 5) + + def test_parse_gemma4_array_mixed_types(self): + result = _parse_gemma4_array('<|"|>hello<|"|>, 42, true, {key:<|"|>val<|"|>}') + self.assertEqual(result[0], "hello") + self.assertEqual(result[1], 42) + self.assertIs(result[2], True) + self.assertIsInstance(result[3], dict) + self.assertEqual(result[3]["key"], "val") + + def test_parse_gemma4_value_types(self): + self.assertIs(_parse_gemma4_value("true"), True) + self.assertIs(_parse_gemma4_value("false"), False) + self.assertEqual(_parse_gemma4_value("42"), 42) + self.assertAlmostEqual(_parse_gemma4_value("3.14"), 3.14) + self.assertEqual(_parse_gemma4_value("hello"), "hello") + self.assertEqual(_parse_gemma4_value(""), "") + + def _collect_streaming(self, chunks): + """Helper: feed chunks and collect normal text + tool calls by index.""" + normal_text = "" + tool_calls_by_index = {} + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + normal_text += result.normal_text + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + 
tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + return normal_text, tool_calls_by_index + + def test_streaming_multiple_tool_calls(self): + """Test streaming with two consecutive tool calls.""" + extra_tools = self.tools + [ + Tool( + type="function", + function=Function( + name="get_time", + description="Get current time", + parameters={ + "type": "object", + "properties": {"timezone": {"type": "string"}}, + }, + ), + ) + ] + chunks = [ + '<|tool_call>call:get_weather{location:<|"|>', + 'Tokyo<|"|>}', + ' <|tool_call>call:get_time{timezone:<|"|>', + 'UTC<|"|>}', + ] + normal_text = "" + tool_calls_by_index = {} + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, extra_tools) + normal_text += result.normal_text + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertEqual(len(tool_calls_by_index), 2) + self.assertEqual(tool_calls_by_index[0]["name"], "get_weather") + self.assertEqual(tool_calls_by_index[1]["name"], "get_time") + params0 = json.loads(tool_calls_by_index[0]["parameters"]) + params1 = json.loads(tool_calls_by_index[1]["parameters"]) + self.assertEqual(params0["location"], "Tokyo") + self.assertEqual(params1["timezone"], "UTC") + + def test_streaming_very_small_chunks(self): + """Test streaming with character-by-character chunks.""" + full_text = '<|tool_call>call:get_weather{location:<|"|>Rome<|"|>}' + chunks = list(full_text) + + normal_text, tool_calls = self._collect_streaming(chunks) + + 
self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + params = json.loads(tool_calls[0]["parameters"]) + self.assertEqual(params["location"], "Rome") + + def test_streaming_empty_args(self): + """Test streaming a tool call with no arguments.""" + chunks = ["<|tool_call>call:get_weather{}", ""] + normal_text, tool_calls = self._collect_streaming(chunks) + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") + + def test_streaming_text_between_tool_calls(self): + """Test streaming with normal text interleaved between two different tool calls.""" + extra_tools = self.tools + [ + Tool( + type="function", + function=Function( + name="get_time", + description="Get current time", + parameters={ + "type": "object", + "properties": {"timezone": {"type": "string"}}, + }, + ), + ) + ] + chunks = [ + "Hello! ", + '<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}', + " Let me also check ", + '<|tool_call>call:get_time{timezone:<|"|>UTC<|"|>}', + ] + normal_text = "" + tool_calls_by_index = {} + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, extra_tools) + normal_text += result.normal_text + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + self.assertIn("Hello!", normal_text) + self.assertIn("Let me also check", normal_text) + self.assertEqual(len(tool_calls_by_index), 2) + self.assertEqual(tool_calls_by_index[0]["name"], "get_weather") + self.assertEqual(tool_calls_by_index[1]["name"], "get_time") + params0 = json.loads(tool_calls_by_index[0]["parameters"]) + params1 = json.loads(tool_calls_by_index[1]["parameters"]) + 
self.assertEqual(params0["location"], "Paris") + self.assertEqual(params1["timezone"], "UTC") + + if __name__ == "__main__": unittest.main() diff --git a/test/registered/unit/parser/test_reasoning_parser.py b/test/registered/unit/parser/test_reasoning_parser.py index 5b9d623d51b7..8f05d7903e9b 100644 --- a/test/registered/unit/parser/test_reasoning_parser.py +++ b/test/registered/unit/parser/test_reasoning_parser.py @@ -5,6 +5,7 @@ from sglang.srt.parser.reasoning_parser import ( BaseReasoningFormatDetector, DeepSeekR1Detector, + Gemma4Detector, Glm45Detector, KimiDetector, KimiK2Detector, @@ -586,6 +587,141 @@ def test_force_nonempty_content_no_thinking_tokens(self): self.assertEqual(result.reasoning_text, "") +class TestGemma4Detector(CustomTestCase): + def setUp(self): + self.detector = Gemma4Detector() + + def test_init(self): + """Test Gemma4Detector initialization.""" + self.assertEqual(self.detector.think_start_token, "<|channel>") + self.assertEqual(self.detector.think_end_token, "") + self.assertEqual(self.detector.think_start_self_label, "thought\n") + self.assertFalse(self.detector._in_reasoning) + self.assertTrue(self.detector.stream_reasoning) + + def test_detect_and_parse_complete_reasoning(self): + """Test parsing complete Gemma4 reasoning block (think_start_self_label is stripped).""" + text = "<|channel>thought\nLet me think about thisThe answer is 42." + result = self.detector.detect_and_parse(text) + self.assertEqual(result.reasoning_text, "Let me think about this") + self.assertEqual(result.normal_text, "The answer is 42.") + + def test_detect_and_parse_without_thinking(self): + """Test parsing without thinking (enable_thinking=False case).""" + text = "Direct answer without thinking." 
+ result = self.detector.detect_and_parse(text) + self.assertEqual(result.normal_text, text) + self.assertEqual(result.reasoning_text, "") + + def test_detect_and_parse_reasoning_only(self): + """Test parsing when output is all reasoning (no end token yet).""" + text = "<|channel>thought\nStill thinking..." + result = self.detector.detect_and_parse(text) + self.assertEqual(result.reasoning_text, "Still thinking...") + self.assertEqual(result.normal_text, "") + + def test_streaming_complete_flow(self): + """Test streaming parse of Gemma4 reasoning flow.""" + chunks = [ + "<|channel>", + "thought\nreasoning content", + "", + "final answer", + ] + all_reasoning = "" + all_normal = "" + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk) + all_reasoning += result.reasoning_text + all_normal += result.normal_text + self.assertIn("reasoning content", all_reasoning) + self.assertIn("final answer", all_normal) + + def test_streaming_full_start_sequence(self): + """Test streaming with the full start sequence (token + self_label).""" + # Gemma4 start sequence is "<|channel>thought\n", not just "<|channel>" + result = self.detector.parse_streaming_increment("<|channel>thought\n") + self.assertEqual(result.normal_text, "") + self.assertEqual(result.reasoning_text, "") + self.assertTrue(self.detector._in_reasoning) + + result = self.detector.parse_streaming_increment("reasoning content") + self.assertEqual(result.reasoning_text, "reasoning content") + self.assertEqual(result.normal_text, "") + + def test_streaming_partial_start_buffered(self): + """Test that partial start sequence is buffered.""" + # "<|channel>" alone is a prefix of "<|channel>thought\n", so it's buffered + result = self.detector.parse_streaming_increment("<|channel>") + self.assertEqual(result.normal_text, "") + self.assertEqual(result.reasoning_text, "") + + def test_streaming_end_token_mid_chunk(self): + """Test end token arriving in the same chunk as reasoning content.""" + 
self.detector.parse_streaming_increment("<|channel>thought\n") + result = self.detector.parse_streaming_increment( + "some reasoningthe answer" + ) + self.assertEqual(result.reasoning_text, "some reasoning") + self.assertEqual(result.normal_text, "the answer") + self.assertFalse(self.detector._in_reasoning) + + def test_streaming_split_end_token(self): + """Test that an end token split across two chunks still ends the reasoning block.""" + self.detector.parse_streaming_increment("<|channel>thought\n") + self.detector.parse_streaming_increment("reasoning content") + + # Split the detector's own end token in half so that neither chunk carries it whole. + end_token = self.detector.think_end_token + split_at = len(end_token) // 2 + self.detector.parse_streaming_increment(end_token[:split_at]) + + result2 = self.detector.parse_streaming_increment( + end_token[split_at:] + "final answer" + ) + self.assertFalse(self.detector._in_reasoning) + self.assertIn("final answer", result2.normal_text) + + def test_streaming_self_label_split_across_chunks(self): + """Test self_label ('thought\\n') arriving separately from start token.""" + result1 = self.detector.parse_streaming_increment("<|channel>") + self.assertEqual(result1.reasoning_text, "") + self.assertEqual(result1.normal_text, "") + + result2 = self.detector.parse_streaming_increment("thought\n") + self.assertTrue(self.detector._in_reasoning) + + result3 = self.detector.parse_streaming_increment("reasoning here") + self.assertEqual(result3.reasoning_text, "reasoning here") + + def test_streaming_force_reasoning(self): + """Test streaming with force_reasoning=True (no start token needed).""" + detector = Gemma4Detector(force_reasoning=True) + + result1 = detector.parse_streaming_increment("reasoning content") + self.assertEqual(result1.reasoning_text, "reasoning content") + self.assertEqual(result1.normal_text, "") + + result2 = detector.parse_streaming_increment("the answer") + self.assertFalse(detector._in_reasoning) + self.assertIn("the answer", result2.normal_text) + + def test_streaming_multiple_reasoning_chunks(self): + """Test reasoning content arriving in many small chunks.""" + self.detector.parse_streaming_increment("<|channel>thought\n") + + all_reasoning = "" + for chunk in ["Think", "ing ", "step ", 
"by ", "step."]: + result = self.detector.parse_streaming_increment(chunk) + all_reasoning += result.reasoning_text + self.assertEqual(result.normal_text, "") + self.assertEqual(all_reasoning, "Thinking step by step.") + + def test_force_reasoning(self): + """Test Gemma4Detector with force_reasoning=True.""" + detector = Gemma4Detector(force_reasoning=True) + text = "This should be reasoningThe answer." + result = detector.detect_and_parse(text) + self.assertEqual(result.reasoning_text, "This should be reasoning") + self.assertEqual(result.normal_text, "The answer.") + + class TestReasoningParser(CustomTestCase): def test_init_valid_model(self): """Test initialization with valid model types.""" @@ -604,6 +740,9 @@ def test_init_valid_model(self): parser = ReasoningParser("glm45") self.assertIsInstance(parser.detector, Glm45Detector) + parser = ReasoningParser("gemma4") + self.assertIsInstance(parser.detector, Gemma4Detector) + def test_init_invalid_model(self): """Test initialization with invalid model type.""" with self.assertRaises(ValueError) as context: @@ -782,6 +921,35 @@ def test_kimi_streaming_scenario(self): self.assertIn("multiple factors", all_reasoning) self.assertIn("42", all_normal) + def test_gemma4_complete_response(self): + """Test complete Gemma4 response parsing (think_start_self_label stripped).""" + parser = ReasoningParser("gemma4") + text = "<|channel>thought\nI need to solve x + 2 = 5. Subtracting 2: x = 3.The answer is x = 3." 
+ reasoning, normal = parser.parse_non_stream(text) + self.assertIn("x = 3", reasoning) + self.assertNotIn("thought\n", reasoning) + self.assertEqual(normal, "The answer is x = 3.") + + def test_gemma4_streaming_scenario(self): + """Test Gemma4 streaming scenario.""" + parser = ReasoningParser("gemma4") + chunks = [ + "<|channel>", + "thought\nLet me analyze.", + " Multiple factors.", + "", + "The solution is 42.", + ] + all_reasoning = "" + all_normal = "" + for chunk in chunks: + reasoning, normal = parser.parse_stream_chunk(chunk) + all_reasoning += reasoning + all_normal += normal + self.assertIn("analyze", all_reasoning) + self.assertIn("Multiple factors", all_reasoning) + self.assertIn("42", all_normal) + def test_empty_reasoning_blocks(self): """Test handling of empty reasoning blocks.""" parser = ReasoningParser("qwen3")