diff --git a/benchmark/kernels/fused_moe_triton/common_utils.py b/benchmark/kernels/fused_moe_triton/common_utils.py
index ca251cbde041..64189aa2a871 100644
--- a/benchmark/kernels/fused_moe_triton/common_utils.py
+++ b/benchmark/kernels/fused_moe_triton/common_utils.py
@@ -129,6 +129,10 @@ def get_model_config(
         E = config.num_experts // ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
+    elif architecture == "HYV3ForCausalLM":
+        E = config.num_experts // ep_size
+        topk = config.num_experts_per_tok
+        intermediate_size = config.expert_hidden_dim
     elif architecture == "NemotronHForCausalLM":
         E = config.n_routed_experts // ep_size
         topk = config.num_experts_per_tok
diff --git a/docs/basic_usage/hy3_preview.md b/docs/basic_usage/hy3_preview.md
new file mode 100644
index 000000000000..b7f23937ef72
--- /dev/null
+++ b/docs/basic_usage/hy3_preview.md
@@ -0,0 +1,191 @@
+# Hy3-preview Usage
+
+Hy3-preview is a large-scale language model (295B parameters, 21B active parameters) from the Tencent Hunyuan team. SGLang supports serving Hy3-preview. This guide describes how to run Hy3-preview with native BF16.
+
+## Installation
+
+### Docker
+
+```bash
+docker pull lmsysorg/sglang:hy3-preview
+```
+
+### Build from Source
+
+```bash
+# Install SGLang
+git clone https://github.com/sgl-project/sglang
+cd sglang
+pip3 install pip --upgrade
+pip3 install "transformers>=5.6.0"
+pip3 install -e "python"
+```
+
+## Launch Hy3-preview with SGLang
+
+Use the command below to serve the [Hy3-preview](https://huggingface.co/tencent/Hy3-preview) model on 8 GPUs. On 8x 96 GB H20 GPUs, SGLang can barely fit the BF16 model and can only run small batch sizes or short requests. Use larger-memory GPUs such as H20-3e when possible.
+
+```bash
+python3 -m sglang.launch_server \
+    --model tencent/Hy3-preview \
+    --tp 8 \
+    --tool-call-parser hunyuan \
+    --reasoning-parser hunyuan \
+    --served-model-name hy3-preview
+```
+
+### EAGLE Speculative Decoding
+
+**Description**: SGLang supports Hy3-preview models with [EAGLE speculative decoding](https://docs.sglang.io/advanced_features/speculative_decoding.html#eagle-decoding).
+
+**Usage**:
+Add `--speculative-algorithm`, `--speculative-num-steps`, `--speculative-eagle-topk`, and `--speculative-num-draft-tokens` to enable this feature. For example:
+
+```bash
+python3 -m sglang.launch_server \
+    --model tencent/Hy3-preview \
+    --tp 8 \
+    --tool-call-parser hunyuan \
+    --reasoning-parser hunyuan \
+    --speculative-num-steps 1 \
+    --speculative-eagle-topk 1 \
+    --speculative-num-draft-tokens 2 \
+    --speculative-algorithm EAGLE \
+    --served-model-name hy3-preview
+```
+
+## OpenAI Client Example
+
+First, install the OpenAI Python client:
+
+```bash
+uv pip install -U openai
+```
+
+You can use the OpenAI client as follows to verify thinking-mode responses.
+
+```python
+from openai import OpenAI
+
+# If running SGLang locally with its default OpenAI-compatible port:
+# http://localhost:30000/v1
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:30000/v1"
+
+client = OpenAI(
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "Hello."},
+]
+
+# Thinking mode is disabled by default (no need to pass chat_template_kwargs).
+resp = client.chat.completions.create(
+    model="hy3-preview",
+    messages=messages,
+    temperature=1,
+    max_tokens=4096,
+)
+print(resp.choices[0].message.content)
+
+# Thinking mode is enabled only if 'reasoning_effort' and 'interleaved_thinking' are set in 'chat_template_kwargs'.
+# 'reasoning_effort' supports: 'high', 'low', 'no_think'.
+resp_think = client.chat.completions.create(
+    model="hy3-preview",
+    messages=messages,
+    temperature=1,
+    max_tokens=4096,
+    extra_body={
+        "chat_template_kwargs": {
+            "reasoning_effort": "high",
+            "interleaved_thinking": True
+        },
+    },
+)
+output_msg = resp_think.choices[0].message
+# thinking content
+print(output_msg.reasoning_content)
+# response content
+print(output_msg.content)
+```
+
+### cURL Usage
+
+```bash
+curl http://localhost:30000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "hy3-preview",
+    "messages": [
+      {"role": "system", "content": "You are a helpful assistant."},
+      {"role": "user", "content": "Hello."}
+    ],
+    "temperature": 1,
+    "max_tokens": 4096
+  }'
+```
+
+## Benchmarking Results
+
+For benchmarking, disable prefix caching by adding `--disable-radix-cache` to the server command.
+
+The following example runs the benchmark on 8 H20 GPUs with 96 GB memory each.
+
+```bash
+python3 -m sglang.bench_serving \
+    --backend sglang \
+    --flush-cache \
+    --dataset-name random \
+    --random-range-ratio 1.0 \
+    --random-input-len 4096 \
+    --random-output-len 4096 \
+    --num-prompts 5 \
+    --max-concurrency 1 \
+    --output-file hy3_preview_h20.jsonl \
+    --model tencent/Hy3-preview \
+    --served-model-name hy3-preview
+```
+
+If successful, you will see output similar to the following.
+
+```shell
+============ Serving Benchmark Result ============
+Backend: sglang
+Traffic request rate: inf
+Max request concurrency: 1
+Successful requests: 5
+Benchmark duration (s): 176.41
+Total input tokens: 20480
+Total input text tokens: 20480
+Total generated tokens: 20480
+Total generated tokens (retokenized): 20480
+Request throughput (req/s): 0.03
+Input token throughput (tok/s): 116.09
+Output token throughput (tok/s): 116.09
+Peak output token throughput (tok/s): 118.00
+Peak concurrent requests: 2
+Total token throughput (tok/s): 232.19
+Concurrency: 1.00
+----------------End-to-End Latency----------------
+Mean E2E Latency (ms): 35279.06
+Median E2E Latency (ms): 35275.60
+P90 E2E Latency (ms): 35294.13
+P99 E2E Latency (ms): 35294.41
+---------------Time to First Token----------------
+Mean TTFT (ms): 355.93
+Median TTFT (ms): 309.28
+P99 TTFT (ms): 518.36
+-----Time per Output Token (excl. 1st token)------
+Mean TPOT (ms): 8.53
+Median TPOT (ms): 8.54
+P99 TPOT (ms): 8.54
+---------------Inter-Token Latency----------------
+Mean ITL (ms): 8.53
+Median ITL (ms): 8.54
+P95 ITL (ms): 8.62
+P99 ITL (ms): 8.74
+Max ITL (ms): 31.70
+==================================================
+```
diff --git a/python/sglang/jit_kernel/csrc/moe/grouped_topk.cuh b/python/sglang/jit_kernel/csrc/moe/grouped_topk.cuh
new file mode 100644
index 000000000000..19677e0e7ed6
--- /dev/null
+++ b/python/sglang/jit_kernel/csrc/moe/grouped_topk.cuh
@@ -0,0 +1,267 @@
+/*
+ * Fused grouped top-k kernel for MoE routing.
+ * Adapted from vLLM's grouped_topk_kernels.cu (Apache-2.0).
+ *
+ * Handles the single-group (num_expert_group=1, topk_group=1) case with
+ * sigmoid scoring, bias correction, renormalization and scaling factor.
+ * Supports up to 512 experts and topk up to 8.
+ */
+#include // For TensorMatcher, SymbolicSize, SymbolicDevice
+#include // For RuntimeCheck, div_ceil
+
+#include // For LaunchKernel, fp32_t
+
+#include <cuda_runtime.h>
+
+#include <cfloat>   // FLT_MAX
+#include <cstdint>  // int32_t, uint32_t, uint64_t
+
+namespace {
+
+static constexpr int WARP_SIZE = 32;
+static constexpr int MAX_TOPK = 8;
+
+// Pack (value, index) into a single uint64_t for warp-level max reduction.
+// Uses the IEEE 754 bit-trick: float bits are order-preserving for positive
+// values, and the flip below extends the ordering to negative values as well
+// (needed for the -FLT_MAX sentinel and negative biased scores).
+__device__ __forceinline__ uint64_t pack_val_idx(float val, int32_t idx) {
+  uint32_t val_bits = __float_as_uint(val);
+  // Flip all bits for negatives, only the sign bit for non-negatives, so
+  // unsigned integer comparison matches float ordering.
+  val_bits ^= (0u - (val_bits >> 31)) | 0x80000000u;
+  // Use (65535 - idx) so that smaller indices win ties
+  uint32_t idx_bits = static_cast<uint32_t>(65535 - idx);
+  return (static_cast<uint64_t>(val_bits) << 32) | idx_bits;
+}
+
+__device__ __forceinline__ void unpack_val_idx(uint64_t packed, float& val, int32_t& idx) {
+  uint32_t idx_bits = static_cast<uint32_t>(packed & 0xFFFFFFFF);
+  idx = static_cast<int32_t>(65535 - idx_bits);
+  uint32_t val_bits = static_cast<uint32_t>(packed >> 32);
+  // Undo the bit flip applied in pack_val_idx
+  val_bits ^= ~(0u - (val_bits >> 31)) | 0x80000000u;
+  val = __uint_as_float(val_bits);
+}
+
+__device__ __forceinline__ uint64_t warp_max_u64(uint64_t val) {
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    uint64_t other = __shfl_xor_sync(0xffffffff, val, mask);
+    val = max(val, other);
+  }
+  return val;
+}
+
+__device__ __forceinline__ float warp_sum_f32(float val) {
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    val += __shfl_xor_sync(0xffffffff, val, mask);
+  }
+  return val;
+}
+
+__device__ __forceinline__ float fast_sigmoid(float x) {
+  return 1.0f / (1.0f + __expf(-x));
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Kernel: one block per token, MaxExperts threads per block.
+// Each thread handles one expert (or is idle if threadIdx.x >= num_experts).
+//
+// Phase 1: All threads load score → sigmoid → +bias → shared memory.
+// Phase 2: Warp 0 iteratively selects top-k via packed warp-level max reduce.
+// Phase 3: Warp 0 renormalizes and writes output.
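+//
+// Worked example of the packed reduction (illustrative values): with biased
+// scores {e0: 0.70, e3: 0.90, e5: 0.90}, pack_val_idx(0.90, 3) compares
+// greater than pack_val_idx(0.90, 5) because the low word stores
+// (65535 - idx), so the smaller index wins the tie; both compare greater
+// than pack_val_idx(0.70, 0). Round k=0 therefore selects expert 3.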
+// ─────────────────────────────────────────────────────────────────────────────
+template <int MaxExperts>
+__global__ void grouped_topk_single_group_kernel(
+    const float* __restrict__ scores,
+    float* __restrict__ topk_values,
+    int32_t* __restrict__ topk_indices,
+    const float* __restrict__ bias,
+    int64_t num_tokens,
+    int64_t num_experts,
+    int64_t topk,
+    bool renormalize,
+    float scaling_factor) {
+  __shared__ float smem_sigmoid[MaxExperts];
+  __shared__ float smem_biased[MaxExperts];
+
+  int64_t token_id = blockIdx.x;
+  if (token_id >= num_tokens) return;
+
+  int tid = threadIdx.x;
+  const float* token_scores = scores + token_id * num_experts;
+
+  // Phase 1: load → sigmoid → bias → shared memory
+  float score_sig = -FLT_MAX;
+  float score_biased = -FLT_MAX;
+  if (tid < num_experts) {
+    float raw = token_scores[tid];
+    score_sig = fast_sigmoid(raw);
+    score_biased = score_sig + bias[tid];
+  }
+  smem_sigmoid[tid] = score_sig;
+  smem_biased[tid] = score_biased;
+  __syncthreads();
+
+  // Phase 2: warp 0 iteratively selects the top-k experts
+  int warp_id = tid / WARP_SIZE;
+  int lane_id = tid % WARP_SIZE;
+
+  if (warp_id != 0) return;
+
+  float* out_vals = topk_values + token_id * topk;
+  int32_t* out_ids = topk_indices + token_id * topk;
+
+  // Each lane scans ceil(num_experts/32) experts per iteration
+  float selected_weights[MAX_TOPK];
+  int32_t selected_ids[MAX_TOPK];
+
+  for (int k = 0; k < topk; k++) {
+    // Each lane finds its local max among its assigned experts
+    float my_max_val = -FLT_MAX;
+    int32_t my_max_idx = 0;
+    for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
+      float v = smem_biased[i];
+      if (v > my_max_val) {
+        my_max_val = v;
+        my_max_idx = i;
+      }
+    }
+
+    // Warp-level max reduction using packed value+index
+    uint64_t packed = pack_val_idx(my_max_val, my_max_idx);
+    uint64_t best = warp_max_u64(packed);
+
+    float best_val;
+    int32_t best_idx;
+    unpack_val_idx(best, best_val, best_idx);
+
+    selected_ids[k] = best_idx;
+    selected_weights[k] = smem_sigmoid[best_idx];
+
+    // Mark the selected expert so it won't be picked again. best_idx is
+    // warp-uniform after the reduction, so a single lane clears it.
+    if (lane_id == 0) {
+      smem_biased[best_idx] = -FLT_MAX;
+    }
+    __syncwarp();
+  }
+
+  // Phase 3: renormalize and write output. All lanes of warp 0 enter the
+  // warp-level sum so the full 0xffffffff shuffle mask stays valid; lanes
+  // >= topk contribute 0 (every lane holds the same selected_weights array).
+  float weight = (lane_id < topk) ? selected_weights[lane_id] : 0.0f;
+  float total = warp_sum_f32(weight);
+
+  if (lane_id < topk) {
+    float final_weight = weight * scaling_factor;
+    if (renormalize) {
+      final_weight = weight * scaling_factor / (total + 1e-20f);
+    }
+    out_ids[lane_id] = selected_ids[lane_id];
+    out_vals[lane_id] = final_weight;
+  }
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Launcher
+// ─────────────────────────────────────────────────────────────────────────────
+void grouped_topk(
+    tvm::ffi::TensorView scores,
+    tvm::ffi::TensorView bias,
+    tvm::ffi::TensorView topk_values,
+    tvm::ffi::TensorView topk_indices,
+    int64_t num_expert_group,
+    int64_t topk_group,
+    int64_t topk,
+    bool renormalize,
+    double scaling_factor) {
+  using namespace host;
+
+  SymbolicSize N{"num_tokens"};
+  SymbolicSize E{"num_experts"};
+  SymbolicDevice device_;
+  device_.set_options();
+
+  TensorMatcher({N, E}).with_dtype<fp32_t>().with_device(device_).verify(scores);
+
+  TensorMatcher({E}).with_dtype<fp32_t>().with_device(device_).verify(bias);
+
+  SymbolicSize K{"topk"};
+  TensorMatcher({N, K}).with_dtype<fp32_t>().with_device(device_).verify(topk_values);
+
+  TensorMatcher({N, K}).with_dtype<int32_t>().with_device(device_).verify(topk_indices);
+
+  int64_t num_tokens = N.unwrap();
+  int64_t num_experts = E.unwrap();
+  DLDevice device = device_.unwrap();
+
+  RuntimeCheck(num_expert_group == 1 && topk_group == 1, "This kernel only supports num_expert_group=1, topk_group=1");
+  RuntimeCheck(topk <= MAX_TOPK, "topk must be <= ", MAX_TOPK);
+  RuntimeCheck(num_experts <= 512, "num_experts must be <= 512");
+
+  if (num_tokens == 0) return;
+
+  float scale_f = static_cast<float>(scaling_factor);
+
+  auto* score_ptr = static_cast<const float*>(scores.data_ptr());
+  auto* bias_ptr = static_cast<const float*>(bias.data_ptr());
+  auto* val_ptr = static_cast<float*>(topk_values.data_ptr());
+  auto* idx_ptr = static_cast<int32_t*>(topk_indices.data_ptr());
+
+  // Select template based on expert count (round up to next tier)
+  int num_threads;
+  if (num_experts <= 128) {
+    num_threads = 128;
+    LaunchKernel(static_cast<uint32_t>(num_tokens), num_threads, device)(
+        grouped_topk_single_group_kernel<128>,
+        score_ptr,
+        val_ptr,
+        idx_ptr,
+        bias_ptr,
+        num_tokens,
+        num_experts,
+        topk,
+        renormalize,
+        scale_f);
+  } else if (num_experts <= 256) {
+    num_threads = 256;
+    LaunchKernel(static_cast<uint32_t>(num_tokens), num_threads, device)(
+        grouped_topk_single_group_kernel<256>,
+        score_ptr,
+        val_ptr,
+        idx_ptr,
+        bias_ptr,
+        num_tokens,
+        num_experts,
+        topk,
+        renormalize,
+        scale_f);
+  } else {
+    num_threads = 512;
+    LaunchKernel(static_cast<uint32_t>(num_tokens), num_threads, device)(
+        grouped_topk_single_group_kernel<512>,
+        score_ptr,
+        val_ptr,
+        idx_ptr,
+        bias_ptr,
+        num_tokens,
+        num_experts,
+        topk,
+        renormalize,
+        scale_f);
+  }
+}
+
+}  // namespace
diff --git a/python/sglang/jit_kernel/grouped_topk.py b/python/sglang/jit_kernel/grouped_topk.py
new file mode 100644
index 000000000000..dae4b5a7b00a
--- /dev/null
+++ b/python/sglang/jit_kernel/grouped_topk.py
@@ -0,0 +1,89 @@
+"""Fused grouped top-k kernel for MoE routing (single-group, sigmoid scoring)."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Tuple
+
+import torch
+
+from sglang.jit_kernel.utils import cache_once, load_jit
+from sglang.srt.utils.custom_op import register_custom_op
+
+if TYPE_CHECKING:
+    from tvm_ffi.module import Module
+
+
+@cache_once
+def _jit_grouped_topk_module() -> Module:
+    return load_jit(
+        "grouped_topk",
+        cuda_files=["moe/grouped_topk.cuh"],
+        cuda_wrappers=[("grouped_topk", "grouped_topk")],
+    )
+
+
+@register_custom_op(mutates_args=["topk_values", "topk_indices"])
+def _jit_grouped_topk_op(
+    scores: torch.Tensor,
+    bias: torch.Tensor,
+    topk_values: torch.Tensor,
+    topk_indices: torch.Tensor,
+    num_expert_group: int,
+    topk_group: int,
+    topk: int,
+    renormalize: bool,
+    scaling_factor: float,
+) -> None:
+    module = _jit_grouped_topk_module()
+    module.grouped_topk(
+        scores,
+        bias,
+        topk_values,
+        topk_indices,
+        num_expert_group,
+        topk_group,
+        topk,
+        renormalize,
+        scaling_factor,
+    )
+
+
+def grouped_topk(
+    scores: torch.Tensor,
+    bias: torch.Tensor,
+    num_expert_group: int,
+    topk_group: int,
+    topk: int,
+    renormalize: bool,
+    scaling_factor: float,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Fused sigmoid + bias + top-k + renormalize for MoE routing.
+
+    Replaces the naive PyTorch path that uses 3x torch.topk + scatter + masked_fill.
+    Currently supports num_expert_group=1, topk_group=1, num_experts<=512, topk<=8.
+    """
+    num_tokens = scores.shape[0]
+
+    topk_values = torch.empty(
+        (num_tokens, topk), dtype=torch.float32, device=scores.device
+    )
+    topk_indices = torch.empty(
+        (num_tokens, topk), dtype=torch.int32, device=scores.device
+    )
+
+    if num_tokens == 0:
+        return topk_values, topk_indices
+
+    _jit_grouped_topk_op(
+        scores.contiguous(),
+        bias.contiguous(),
+        topk_values,
+        topk_indices,
+        num_expert_group,
+        topk_group,
+        topk,
+        renormalize,
+        scaling_factor,
+    )
+    return topk_values, topk_indices
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index f81eecaf1469..7100f1276407 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -366,6 +366,10 @@ def _config_draft_model(self):
             self.hf_config.architectures[0] = "NemotronHForCausalLMMTP"
             self.hf_config.num_nextn_predict_layers = 1
 
+        if is_draft_model and self.hf_config.architectures[0] == "HYV3ForCausalLM":
+            self.hf_config.architectures[0] = "HYV3ForCausalLMNextN"
+            self.hf_config.num_nextn_predict_layers = 1
+
     def _derive_hybrid_model(self):
         # Use self.context_len after it has been initialized to prevent using context_len which may be None.
         self.is_hybrid_swa = (
diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py
index 406314925587..650b5dcb29b9 100644
--- a/python/sglang/srt/entrypoints/openai/serving_chat.py
+++ b/python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -1404,9 +1404,14 @@ def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool:
             request.chat_template_kwargs is not None
             and request.chat_template_kwargs.get("enable_thinking") is True
         )
-        if self.reasoning_parser in ["mistral"]:
-            # Mistral models only reason when reasoning_effort is explicitly
-            # set to a value other than None/"none" (typically "high").
+        if self.reasoning_parser == "hunyuan":
+            # The Hy3-preview template emits no <think> block when
+            # reasoning_effort is "no_think" / "none" / unset; forcing
+            # reasoning would route all output into reasoning_content.
+            return request.reasoning_effort not in (None, "none", "no_think")
+        if self.reasoning_parser == "mistral":
+            # Mistral only reasons when reasoning_effort is explicitly set
+            # to a non-"none" value (typically "high").
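+            # For example: reasoning_effort="high" enables reasoning, while
+            # None or "none" leaves it off.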
+            return (
+                request.reasoning_effort is not None
+                and request.reasoning_effort != "none"
+            )
diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py
index 4350d725fda4..08585e556c5c 100644
--- a/python/sglang/srt/function_call/function_call_parser.py
+++ b/python/sglang/srt/function_call/function_call_parser.py
@@ -21,6 +21,7 @@
 from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector
 from sglang.srt.function_call.gpt_oss_detector import GptOssDetector
 from sglang.srt.function_call.hermes_detector import HermesDetector
+from sglang.srt.function_call.hunyuan_detector import HunyuanDetector
 from sglang.srt.function_call.internlm_detector import InternlmDetector
 from sglang.srt.function_call.kimik2_detector import KimiK2Detector
 from sglang.srt.function_call.lfm2_detector import Lfm2Detector
@@ -73,6 +74,7 @@ class FunctionCallParser:
         "trinity": TrinityDetector,
         "interns1": InternlmDetector,
         "hermes": HermesDetector,
+        "hunyuan": HunyuanDetector,
         "gigachat3": GigaChat3Detector,
         "gemma4": Gemma4Detector,
     }
diff --git a/python/sglang/srt/function_call/hunyuan_detector.py b/python/sglang/srt/function_call/hunyuan_detector.py
new file mode 100644
index 000000000000..67b0187e5cb0
--- /dev/null
+++ b/python/sglang/srt/function_call/hunyuan_detector.py
@@ -0,0 +1,476 @@
+import json
+import logging
+import re
+from typing import Any, Dict, List, Optional, Set
+
+from sglang.srt.entrypoints.openai.protocol import Tool
+from sglang.srt.environ import envs
+from sglang.srt.function_call.base_format_detector import BaseFormatDetector
+from sglang.srt.function_call.core_types import (
+    StreamingParseResult,
+    StructureInfo,
+    ToolCallItem,
+    _GetInfoFunc,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class HunyuanDetector(BaseFormatDetector):
+    """
+    Detector for Hunyuan (HYV3) tool call format.
+
+    Format (the marker strings are held in the ``*_token`` attributes set in
+    ``__init__``):
+
+        <bot_token>
+        <tool_call_start_token>function_name<tool_sep_token>
+        <arg_key_start_token>key1<arg_key_end_token>
+        <arg_value_start_token>value1<arg_value_end_token>
+        <tool_call_end_token>
+        <eot_token>
+
+    Streaming behavior:
+      * Phase 1 emits the tool name once the tool separator is seen.
+      * Phase 2 streams argument JSON incrementally. Closed
+        arg_key/arg_value pairs are parsed with schema-aware type
+        coercion; pure-string args may be streamed char-by-char (with
+        JSON escaping). The closing "}" is withheld until the tool-call
+        end marker arrives.
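+
+    Illustrative streamed deltas (assumed tool ``get_weather(city="Paris")``;
+    actual chunk boundaries depend on how the model emits text)::
+
+        name="get_weather", parameters=""
+        parameters='{"city": "Pa'
+        parameters='ris'
+        parameters='"}'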
+ """ + + _TYPE_ALIASES: Dict[str, str] = { + "str": "string", + "text": "string", + "varchar": "string", + "char": "string", + "enum": "string", + "bool": "boolean", + "binary": "boolean", + "int": "integer", + "float": "number", + "double": "number", + "list": "array", + "dict": "object", + "map": "object", + } + + _INTEGER_PREFIXES = ("int", "uint", "long", "short", "unsigned") + _NUMBER_PREFIXES = ("num", "float") + + def __init__(self): + super().__init__() + + self.bot_token = "" + self.eot_token = "" + + self.tool_call_start_token = "" + self.tool_call_end_token = "" + self.tool_sep_token = "" + + self.arg_key_start_token = "" + self.arg_key_end_token = "" + self.arg_value_start_token = "" + self.arg_value_end_token = "" + + self.tool_call_regex = re.compile( + r"(.*?)(.*?)", re.DOTALL + ) + self.func_args_regex = re.compile( + r"(.*?)\s*(.*?)", re.DOTALL + ) + + # Streaming state + self._in_tool_calls: bool = False + self._streaming_tool_name: Optional[str] = None + self._completed_args: Dict[str, Any] = {} + self._streamed_json_len: int = 0 + + # ------------------------------------------------------------------ + # Type-normalization helpers + # ------------------------------------------------------------------ + + @staticmethod + def _normalize_type(raw_type: str) -> str: + exact = HunyuanDetector._TYPE_ALIASES.get(raw_type) + if exact is not None: + return exact + lower = raw_type.lower() + if any(lower.startswith(p) for p in HunyuanDetector._INTEGER_PREFIXES): + return "integer" + if any(lower.startswith(p) for p in HunyuanDetector._NUMBER_PREFIXES): + return "number" + return raw_type + + @staticmethod + def _get_arg_schema( + function_name: str, arg_key: str, tools: Optional[List[Tool]] + ) -> dict: + if not tools: + return {} + for tool in tools: + if tool.function.name == function_name: + if tool.function.parameters is None: + return {} + return tool.function.parameters.get("properties", {}).get(arg_key, {}) + return {} + + @staticmethod + def _get_schema_options(arg_schema: dict) -> List[dict]: + """Priority: single ``type`` > ``anyOf`` > ``oneOf``; else default string.""" + if "type" in arg_schema: + return [arg_schema] + if "anyOf" in arg_schema: + return arg_schema["anyOf"] + if "oneOf" in arg_schema: + return arg_schema["oneOf"] + return [{"type": "string"}] + + @staticmethod + def _get_types(arg_schema: dict) -> Set[str]: + schemas = HunyuanDetector._get_schema_options(arg_schema) + return { + HunyuanDetector._normalize_type(s.get("type", "string")) for s in schemas + } - {"null"} + + @staticmethod + def _is_only_string_type( + function_name: str, arg_key: str, tools: Optional[List[Tool]] + ) -> bool: + """Only pure-string args get char-by-char value streaming; compound + types like anyOf(string | array) might resolve to a JSON array or + object, so we can't safely stream them as open JSON strings.""" + arg_schema = HunyuanDetector._get_arg_schema(function_name, arg_key, tools) + return HunyuanDetector._get_types(arg_schema) == {"string"} + + @staticmethod + def _try_parse_bool(value: str) -> Optional[bool]: + lower = value.lower() + if lower == "true": + return True + if lower == "false": + return False + return None + + @staticmethod + def _try_parse_int(value: str) -> Optional[int]: + try: + return int(value) + except (ValueError, TypeError): + return None + + @staticmethod + def _try_parse_number(value: str): + """int if no '.'/'e'/'E', else float.""" + try: + if "." 
in value or "e" in value or "E" in value: + return float(value) + return int(value) + except (ValueError, TypeError): + return None + + @staticmethod + def _deserialize(value: str) -> Any: + try: + return json.loads(value) + except (json.JSONDecodeError, ValueError): + return value + + @staticmethod + def _parse_value( + value: str, + function_name: str, + arg_key: str, + tools: Optional[List[Tool]], + ) -> Any: + """Unified value parser: bool → int → number → json (array/obj) → string.""" + arg_schema = HunyuanDetector._get_arg_schema(function_name, arg_key, tools) + types = HunyuanDetector._get_types(arg_schema) + + if "boolean" in types: + r = HunyuanDetector._try_parse_bool(value) + if r is not None: + return r + + if "integer" in types: + r = HunyuanDetector._try_parse_int(value) + if r is not None: + return r + + if "number" in types: + r = HunyuanDetector._try_parse_number(value) + if r is not None: + return r + + if types - {"string", "boolean", "integer", "number"}: + try: + return json.loads(value) + except (json.JSONDecodeError, ValueError): + pass + + if "string" in types: + return value + + return HunyuanDetector._deserialize(value) + + # ------------------------------------------------------------------ + # Non-streaming + # ------------------------------------------------------------------ + + def has_tool_call(self, text: str) -> bool: + return self.bot_token in text + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + if self.bot_token not in text: + return StreamingParseResult(normal_text=text, calls=[]) + + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx > 0 else "" + + tool_indices = self._get_tool_indices(tools) + forward_unknown = envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get() + + calls: List[ToolCallItem] = [] + try: + for function_name, function_args in self.tool_call_regex.findall(text): + function_name = function_name.strip() + if function_name not in tool_indices and not forward_unknown: + logger.warning( + "Model attempted to call undefined function: %s", function_name + ) + continue + + arg_dict: Dict[str, Any] = {} + for key, value in self.func_args_regex.findall(function_args): + key = key.strip() + arg_dict[key] = self._parse_value(value, function_name, key, tools) + + calls.append( + ToolCallItem( + tool_index=tool_indices.get(function_name, -1), + name=function_name, + parameters=json.dumps(arg_dict, ensure_ascii=False), + ) + ) + return StreamingParseResult(normal_text=normal_text, calls=calls) + except Exception as e: + logger.error(f"Error in detect_and_parse: {e}", exc_info=True) + return StreamingParseResult(normal_text=text) + + # ------------------------------------------------------------------ + # Streaming + # ------------------------------------------------------------------ + + def _reset_streaming_tool_state(self): + self._streaming_tool_name = None + self._completed_args = {} + self._streamed_json_len = 0 + + def parse_streaming_increment( + self, new_text: str, tools: List[Tool] + ) -> StreamingParseResult: + try: + return self._parse_streaming_increment_impl(new_text, tools) + except Exception as e: + logger.error(f"Error in parse_streaming_increment: {e}", exc_info=True) + return StreamingParseResult() + + def _parse_streaming_increment_impl( + self, new_text: str, tools: List[Tool] + ) -> StreamingParseResult: + if not hasattr(self, "_tool_indices"): + self._tool_indices = self._get_tool_indices(tools) + + # Not yet inside : emit normal text or buffer partial bot_token. 
+        if not self._in_tool_calls:
+            combined = self._buffer + new_text
+            if self.bot_token in combined:
+                bot_pos = combined.find(self.bot_token)
+                normal_text = combined[:bot_pos]
+                self._buffer = combined[bot_pos + len(self.bot_token) :]
+                self._in_tool_calls = True
+                return self._continue_streaming(tools, leading_normal=normal_text)
+
+            partial_len = self._ends_with_partial_token(combined, self.bot_token)
+            if partial_len:
+                self._buffer = combined[-partial_len:]
+                return StreamingParseResult(normal_text=combined[:-partial_len])
+            self._buffer = ""
+            return StreamingParseResult(normal_text=combined)
+
+        self._buffer += new_text
+        return self._continue_streaming(tools)
+
+    def _continue_streaming(
+        self, tools: List[Tool], leading_normal: str = ""
+    ) -> StreamingParseResult:
+        """Drive the state machine after the tool-calls block is open."""
+        calls: List[ToolCallItem] = []
+
+        while True:
+            if self._streaming_tool_name is None:
+                # Phase 1: wait for the tool-call start marker, the tool name,
+                # and the separator.
+                tc_start = self._buffer.find(self.tool_call_start_token)
+                if tc_start == -1:
+                    if self.eot_token in self._buffer:
+                        eot_pos = self._buffer.find(self.eot_token)
+                        self._buffer = self._buffer[eot_pos + len(self.eot_token) :]
+                        self._in_tool_calls = False
+                    break
+
+                sep_pos = self._buffer.find(self.tool_sep_token, tc_start)
+                if sep_pos == -1:
+                    self._buffer = self._buffer[tc_start:]
+                    break
+
+                tool_name = self._buffer[
+                    tc_start + len(self.tool_call_start_token) : sep_pos
+                ].strip()
+
+                if (
+                    tool_name not in self._tool_indices
+                    and not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get()
+                ):
+                    logger.warning(
+                        "Model attempted to call undefined function: %s", tool_name
+                    )
+
+                self._streaming_tool_name = tool_name
+                self.current_tool_id += 1
+                while len(self.streamed_args_for_tool) <= self.current_tool_id:
+                    self.streamed_args_for_tool.append("")
+
+                calls.append(
+                    ToolCallItem(
+                        tool_index=self.current_tool_id,
+                        name=tool_name,
+                        parameters="",
+                    )
+                )
+
+                self._buffer = self._buffer[sep_pos + len(self.tool_sep_token) :]
+
+            # Phase 2: stream argument JSON of the current tool.
+            before_name = self._streaming_tool_name
+            calls.extend(self._stream_args(tools))
+            if self._streaming_tool_name is not None:
+                break  # current tool still open; need more data.
+            if self._streaming_tool_name == before_name:
+                break  # safety: avoid infinite loop if state didn't advance.
+
+        return StreamingParseResult(normal_text=leading_normal, calls=calls)
+
+    def _stream_args(self, tools: List[Tool]) -> List[ToolCallItem]:
+        """Emit argument-JSON deltas for the currently-open tool call."""
+        is_complete = self.tool_call_end_token in self._buffer
+
+        if is_complete:
+            end_idx = self._buffer.find(self.tool_call_end_token)
+            args_text = self._buffer[:end_idx]
+        else:
+            args_text = self._buffer
+
+        # 1. Absorb closed arg_key/arg_value pairs.
+        last_closed_end = 0
+        for m in self.func_args_regex.finditer(args_text):
+            key, value = m.groups()
+            key = key.strip()
+            if key not in self._completed_args:
+                self._completed_args[key] = self._parse_value(
+                    value, self._streaming_tool_name or "", key, tools
+                )
+            last_closed_end = m.end()
+
+        # 2. Detect a partial (unclosed) kv pair at the tail.
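+        # Example (with the marker literals defined in __init__): a tail of
+        #   '<arg_key>city</arg_key><arg_value>Par'
+        # yields partial_key="city"; partial_value="Par" is exposed for
+        # char-by-char streaming only when the schema types "city" as a
+        # pure string.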
+        tail = args_text[last_closed_end:]
+        partial_key: Optional[str] = None
+        partial_value: Optional[str] = None
+
+        ak_start = tail.find(self.arg_key_start_token)
+        if ak_start != -1:
+            ak_end = tail.find(
+                self.arg_key_end_token, ak_start + len(self.arg_key_start_token)
+            )
+            if ak_end != -1:
+                partial_key = tail[
+                    ak_start + len(self.arg_key_start_token) : ak_end
+                ].strip()
+                av_start = tail.find(self.arg_value_start_token, ak_end)
+                if av_start != -1 and self._is_only_string_type(
+                    self._streaming_tool_name or "", partial_key, tools
+                ):
+                    partial_value = tail[av_start + len(self.arg_value_start_token) :]
+
+        # Avoid emitting a lone "{" before any arg content is knowable.
+        if not is_complete and not self._completed_args and partial_value is None:
+            return []
+
+        # 3. Build the JSON snapshot manually to control streaming boundaries.
+        snapshot_parts: List[str] = []
+        for k, v in self._completed_args.items():
+            k_json = json.dumps(k, ensure_ascii=False)
+            v_json = json.dumps(v, ensure_ascii=False)
+            snapshot_parts.append(f"{k_json}: {v_json}")
+
+        if partial_key is not None and partial_value is not None:
+            # Hold back chars that could be a partial arg_value_end_token
+            # marker, so that a `<` starting the end tag doesn't leak into the
+            # streamed JSON string value.
+            hold = self._ends_with_partial_token(
+                partial_value, self.arg_value_end_token
+            )
+            safe_value = partial_value[:-hold] if hold else partial_value
+            k_json = json.dumps(partial_key, ensure_ascii=False)
+            escaped = (
+                safe_value.replace("\\", "\\\\")
+                .replace('"', '\\"')
+                .replace("\n", "\\n")
+                .replace("\r", "\\r")
+                .replace("\t", "\\t")
+            )
+            # No closing '"' here; it's appended when the value closes.
+            snapshot_parts.append(f'{k_json}: "{escaped}')
+
+        snapshot = "{" + ", ".join(snapshot_parts) + "}"
+
+        argument_diff: Optional[str] = None
+
+        if is_complete:
+            final_json = json.dumps(self._completed_args, ensure_ascii=False)
+            if self._streamed_json_len < len(final_json):
+                argument_diff = final_json[self._streamed_json_len :]
+                self._streamed_json_len = len(final_json)
+
+            while len(self.prev_tool_call_arr) <= self.current_tool_id:
+                self.prev_tool_call_arr.append({})
+            self.prev_tool_call_arr[self.current_tool_id] = {
+                "name": self._streaming_tool_name,
+                "arguments": dict(self._completed_args),
+            }
+
+            end_idx = self._buffer.find(self.tool_call_end_token)
+            self._buffer = self._buffer[end_idx + len(self.tool_call_end_token) :]
+            self._reset_streaming_tool_state()
+        else:
+            # Withhold the trailing "}" while the tool call is still open.
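+            # Example: for a snapshot of '{"city": "Paris"}' only the prefix
+            # '{"city": "Paris"' is eligible here; the final '}' is emitted by
+            # the is_complete branch once the end marker arrives.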
+            end = len(snapshot) - 1
+            if end > self._streamed_json_len:
+                argument_diff = snapshot[self._streamed_json_len : end]
+                self._streamed_json_len = end
+
+        if argument_diff:
+            self.streamed_args_for_tool[self.current_tool_id] += argument_diff
+            return [
+                ToolCallItem(
+                    tool_index=self.current_tool_id,
+                    parameters=argument_diff,
+                )
+            ]
+        return []
+
+    def structure_info(self) -> _GetInfoFunc:
+        return lambda name: StructureInfo(
+            begin=f"{self.bot_token}\n{self.tool_call_start_token}{name}{self.tool_sep_token}",
+            end=f"{self.tool_call_end_token}\n{self.eot_token}",
+            trigger=self.bot_token,
+        )
+
+    def supports_structural_tag(self) -> bool:
+        return False
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_B200,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_B200,dtype=fp8_w8a8.json
new file mode 100644
index 000000000000..f721d128e84a
--- /dev/null
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_B200,dtype=fp8_w8a8.json
@@ -0,0 +1,146 @@
+{
+  "1": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 32,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "2": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "4": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "8": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 5
+  },
+  "16": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "24": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "32": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "48": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "64": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "96": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "128": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "256": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "512": {
+    "BLOCK_SIZE_M": 32,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "1024": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "1536": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 8,
+    "num_stages": 4
+  },
+  "2048": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "3072": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "4096": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 4
+  }
+}
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..5db0a3a7e5e8 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20,dtype=fp8_w8a8_down.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20,dtype=fp8_w8a8_down.json new file mode 100644 index 000000000000..5db0a3a7e5e8 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20,dtype=fp8_w8a8_down.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..45db1ca16971 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 
128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8_down.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8_down.json new file mode 100644 index 000000000000..45db1ca16971 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8_down.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20-3e.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20-3e.json new file mode 100644 index 000000000000..5137a8b74d79 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20-3e.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 
128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20-3e_down.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20-3e_down.json new file mode 100644 index 000000000000..5137a8b74d79 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20-3e_down.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20.json new file mode 100644 index 000000000000..6b666f8fd727 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20_down.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20_down.json new file mode 100644 index 000000000000..6b666f8fd727 --- /dev/null +++ 
b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=192,N=192,device_name=NVIDIA_H20_down.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 77660ed5571a..7a0807c39ece 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -901,6 +901,30 @@ def biased_grouped_topk_gpu( routed_scaling_factor=routed_scaling_factor, apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, ) + elif ( + _is_cuda + and num_expert_group == 1 + and topk_group == 1 + and num_fused_shared_experts == 0 + and num_experts <= 512 + and topk <= 8 + ): + from sglang.jit_kernel.grouped_topk import grouped_topk as jit_grouped_topk + + scaling = ( + routed_scaling_factor if routed_scaling_factor is not None else 1.0 + ) + if not apply_routed_scaling_factor_on_output: + scaling = 1.0 + return jit_grouped_topk( + gating_output.to(dtype=torch.float32), + correction_bias.to(dtype=torch.float32), + num_expert_group, 
+ topk_group, + topk, + renormalize, + scaling, + ) else: return biased_grouped_topk_impl( hidden_states, diff --git a/python/sglang/srt/models/hunyuan_v3.py b/python/sglang/srt/models/hunyuan_v3.py new file mode 100644 index 000000000000..b0f87dc327ee --- /dev/null +++ b/python/sglang/srt/models/hunyuan_v3.py @@ -0,0 +1,587 @@ +# coding=utf-8 +# Copyright 2026 The HunYuan team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from sglang.srt.distributed import ( + get_moe_expert_parallel_world_size, + get_moe_tensor_parallel_world_size, + get_tensor_model_parallel_world_size, + moe_expert_parallel_all_reduce, + moe_tensor_model_parallel_all_reduce, +) +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.moe.topk import TopK +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.managers.schedule_batch import ForwardBatch +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import is_cuda +from sglang.srt.utils.hf_transformers_utils import get_rope_config + + +class HYV3FeedForward(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + out = self.act_fn(gate_up) + out, _ = self.down_proj(out) + return out + + +class HYV3MoEFused(nn.Module): + def __init__( + self, + config: PretrainedConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, + ): + super().__init__() + self.tp_size = get_moe_tensor_parallel_world_size() + self.ep_size = get_moe_expert_parallel_world_size() + self.layer_id = layer_id + self.alt_stream = alt_stream + self.n_routed_experts = config.num_experts + top_k = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + + self.expert_bias = nn.Parameter( + torch.empty(config.num_experts, dtype=torch.float32) + ) + self.expert_bias.weight_loader = HYV3MoEFused.ebias_weight_loader + scoring_func = "sigmoid" + self.e_score_correction_bias = self.expert_bias + self.router_scaling_factor = getattr(config, "router_scaling_factor", 1.0) + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + params_dtype=torch.float32, + prefix=f"{prefix}.gate", + ) + self.topk = TopK( + top_k=config.num_experts_per_tok, + use_grouped_topk=True, + num_expert_group=1, + topk_group=1, + renormalize=config.route_norm, + scoring_func=scoring_func, + correction_bias=self.e_score_correction_bias, + routed_scaling_factor=self.router_scaling_factor, + apply_routed_scaling_factor_on_output=True, + ) + + if getattr(config, "num_shared_experts", 0) > 0: + self.shared_mlp = HYV3FeedForward( + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size + * config.num_shared_experts, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.shared_mlp", + reduce_results=False, + ) + else: + self.shared_mlp = None + + self.experts = FusedMoE( + num_experts=self.n_routed_experts, + top_k=top_k, + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + reduce_results=False, + layer_id=layer_id, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) + + @staticmethod + def ebias_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None: + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight.to(torch.float32)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if ( + self.alt_stream is not None + and self.shared_mlp is not None + and hidden_states.shape[0] > 0 + and get_is_capture_mode() + ): + return self._forward_dual_stream(hidden_states) + return self._forward_single_stream(hidden_states) + + def _forward_single_stream(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32)) + topk_output = self.topk(hidden_states, router_logits) + if self.shared_mlp is not None: + shared_output = self.shared_mlp(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, topk_output=topk_output + ) + final_hidden_states = final_hidden_states + shared_output + else: + final_hidden_states = self.experts( + hidden_states=hidden_states, topk_output=topk_output + ) + + if self.ep_size > 1: + final_hidden_states = moe_expert_parallel_all_reduce(final_hidden_states) + + if self.tp_size > 1: + final_hidden_states = moe_tensor_model_parallel_all_reduce( + 
final_hidden_states + ) + + return final_hidden_states.view(orig_shape) + + def _forward_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Shared experts on main stream, routed experts on alt stream.""" + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + + shared_output = self.shared_mlp(hidden_states) + + with torch.cuda.stream(self.alt_stream): + router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32)) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, topk_output=topk_output + ) + + current_stream.wait_stream(self.alt_stream) + final_hidden_states = final_hidden_states + shared_output + + if self.ep_size > 1: + final_hidden_states = moe_expert_parallel_all_reduce(final_hidden_states) + + if self.tp_size > 1: + final_hidden_states = moe_tensor_model_parallel_all_reduce( + final_hidden_states + ) + + return final_hidden_states.view(orig_shape) + + +class HYV3Attention(nn.Module): + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[dict] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.head_dim = getattr(config, "head_dim", hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.use_qk_norm = getattr( + config, "use_qk_norm", getattr(config, "qk_norm", False) + ) + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + prefix=f"{prefix}.attn", + ) + if self.use_qk_norm: + rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5) + self.q_norm = RMSNorm(self.head_dim, rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + if self.use_qk_norm: + q = self.q_norm(q.reshape(-1, self.head_dim)) + q = q.view(-1, self.q_size) 
+ k = self.k_norm(k.reshape(-1, self.head_dim)) + k = k.view(-1, self.kv_size) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + +class HYV3DecoderLayer(nn.Module): + def __init__( + self, + config: PretrainedConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, + ) -> None: + super().__init__() + self.layer_id = layer_id + self.hidden_size = config.hidden_size + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + rope_theta, _ = get_rope_config(config) + self.self_attn = HYV3Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps) + + first_k_dense_replace = getattr(config, "first_k_dense_replace", 0) + if layer_id < first_k_dense_replace: + self.mlp = HYV3FeedForward( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.block_type = "feedforward" + else: + self.mlp = HYV3MoEFused( + config=config, + layer_id=layer_id, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + alt_stream=alt_stream, + ) + self.block_type = "moe" + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +class HYV3Model(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.quant_config = quant_config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens", + ) + + self.alt_stream = torch.cuda.Stream() if is_cuda() else None + + self.layers = nn.ModuleList( + [ + HYV3DecoderLayer( + config=config, + layer_id=i, + quant_config=quant_config, + prefix=f"{prefix}.layers.{i}", + alt_stream=self.alt_stream, + ) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + residual = None + for layer in self.layers: + hidden_states, residual = layer( + positions, hidden_states, forward_batch, residual + ) + + hidden_states, _ = 
self.norm(hidden_states, residual) + return hidden_states + + +class HYV3ForCausalLM(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.quant_config = quant_config + + self.model = HYV3Model(config, quant_config, prefix=f"{prefix}.model") + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.lm_head", + ) + if getattr(self.config, "tie_word_embeddings", False): + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + + params_dict = dict(self.named_parameters()) + num_nextn_layers = getattr(self.config, "num_nextn_predict_layers", 0) + + for name, loaded_weight in weights: + if "lm_head.weight" in name and getattr( + self.config, "tie_word_embeddings", False + ): + continue + + if "rotary_emb.inv_freq" in name: + continue + + if num_nextn_layers > 0 and name.startswith("model.layers."): + parts = name.split(".") + if len(parts) >= 3 and int(parts[2]) >= self.config.num_hidden_layers: + continue + + is_found = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + is_found = True + break + if is_found: + continue + + # Handle expert weights (including fp8 weight_scale, input_scale) + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + is_expert_weight = True + name_mapped = name.replace(weight_name, param_name) + if name_mapped not in params_dict: + continue + param = params_dict[name_mapped] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + ) + break + if is_expert_weight: + continue + + if "router.gate." 
in name: + name = name.replace("router.", "") + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = [HYV3ForCausalLM] diff --git a/python/sglang/srt/models/hunyuan_v3_nextn.py b/python/sglang/srt/models/hunyuan_v3_nextn.py new file mode 100644 index 000000000000..6ed3384287b3 --- /dev/null +++ b/python/sglang/srt/models/hunyuan_v3_nextn.py @@ -0,0 +1,253 @@ +# coding=utf-8 +# Copyright 2026 The HunYuan team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Inference-only HunyuanV3 NextN (MTP) Speculative Decoding.""" + +import logging +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.managers.schedule_batch import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.hunyuan_v3 import HYV3DecoderLayer +from sglang.srt.utils import is_cuda + +logger = logging.getLogger(__name__) + + +class HYV3ModelNextN(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens", + ) + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False) + + self.alt_stream = torch.cuda.Stream() if is_cuda() else None + + # Force MoE for the MTP layer: first_k_dense_replace=1 would make + # layer_id=0 pick a dense MLP instead of MoE, so override it. 
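+        # (The original value is restored right after the decoder layer is
+        # built, so any later reader of this config object sees it unchanged.)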
+ orig_first_k = getattr(config, "first_k_dense_replace", 0) + config.first_k_dense_replace = 0 + self.decoder = HYV3DecoderLayer( + config=config, + layer_id=0, + quant_config=quant_config, + prefix=f"{prefix}.decoder", + alt_stream=self.alt_stream, + ) + config.first_k_dense_replace = orig_first_k + + self.shared_head = nn.Module() + self.shared_head.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + if hidden_states.shape[0] > 0: + hidden_states = self.eh_proj( + torch.cat( + ( + self.enorm(hidden_states), + self.hnorm(forward_batch.spec_info.hidden_states), + ), + dim=-1, + ) + ) + + residual = None + hidden_states, residual = self.decoder( + positions, hidden_states, forward_batch, residual + ) + + if not forward_batch.forward_mode.is_idle(): + if residual is not None: + hidden_states, _ = self.shared_head.norm(hidden_states, residual) + else: + hidden_states = self.shared_head.norm(hidden_states) + + return hidden_states + + +class HYV3ForCausalLMNextN(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.config = config + self.quant_config = quant_config + + self.model = HYV3ModelNextN(config, quant_config, prefix="model") + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix="lm_head", + ) + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, forward_batch) + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + nextn_layer_id = self.config.num_hidden_layers + nextn_prefix = f"model.layers.{nextn_layer_id}." 
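+        # The checkpoint stores the single MTP layer under
+        # model.layers.<num_hidden_layers>.*; remap its speculative-only
+        # weights (enorm/hnorm/eh_proj) to model.* and everything else to
+        # model.decoder.* before running the generic loading paths below.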
+        spec_weight_names = ("enorm", "hnorm", "eh_proj")
+
+        stacked_params_mapping = [
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if name.startswith(nextn_prefix):
+                subname = name[len(nextn_prefix) :]
+                if any(subname.startswith(s) for s in spec_weight_names):
+                    name = f"model.{subname}"
+                else:
+                    name = f"model.decoder.{subname}"
+            elif name == "model.shared_head.norm.weight":
+                pass
+            elif (
+                "embed_tokens" in name
+                or "shared_head.head" in name
+                or "lm_head" in name
+            ):
+                continue
+            else:
+                continue
+
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            if "router.gate." in name:
+                name = name.replace("router.", "")
+
+            is_found = False
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                if "mlp.experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                is_found = True
+                break
+            if is_found:
+                continue
+
+            is_expert_weight = False
+            for mapping in expert_params_mapping:
+                param_name, weight_name, expert_id, shard_id = mapping
+                if weight_name not in name:
+                    continue
+                is_expert_weight = True
+                name_mapped = name.replace(weight_name, param_name)
+                if name_mapped not in params_dict:
+                    continue
+                param = params_dict[name_mapped]
+                weight_loader = param.weight_loader
+                weight_loader(
+                    param,
+                    loaded_weight,
+                    name_mapped,
+                    shard_id=shard_id,
+                    expert_id=expert_id,
+                )
+                break
+            if is_expert_weight:
+                continue
+
+            if name not in params_dict:
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+
+EntryClass = [HYV3ForCausalLMNextN]
diff --git a/python/sglang/srt/parser/reasoning_parser.py b/python/sglang/srt/parser/reasoning_parser.py
index 8811c90b2ddc..84ca445c7c5e 100644
--- a/python/sglang/srt/parser/reasoning_parser.py
+++ b/python/sglang/srt/parser/reasoning_parser.py
@@ -482,6 +482,31 @@ def __init__(
         )
+
+class HunyuanDetector(BaseReasoningFormatDetector):
+    """
+    Detector for Hunyuan models (e.g., tencent/Hunyuan-A13B-Instruct).
+
+    Like Glm45Detector but uses ``<tool_calls>`` (plural) as the tool start token.
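+
+    Reasoning is wrapped in ``<think>...</think>``; a ``<tool_calls>`` block
+    may open while the reasoning span is still unterminated, which ends the
+    reasoning segment (tool interruption).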
+ """ + + def __init__( + self, + stream_reasoning: bool = True, + force_reasoning: bool = False, + continue_final_message: bool = False, + previous_content: str = "", + ): + super().__init__( + "", + "", + force_reasoning=force_reasoning, + stream_reasoning=stream_reasoning, + tool_start_token="", + continue_final_message=continue_final_message, + previous_content=previous_content, + ) + + class Gemma4Detector(BaseReasoningFormatDetector): """Gemma4 reasoning detector.""" @@ -518,6 +543,7 @@ class ReasoningParser: "deepseek-r1": DeepSeekR1Detector, "deepseek-v3": Qwen3Detector, "glm45": Glm45Detector, + "hunyuan": HunyuanDetector, "gpt-oss": GptOssDetector, "kimi": KimiDetector, "kimi_k2": KimiK2Detector, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 56b9565db63c..61db4c83e801 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -3399,6 +3399,7 @@ def _handle_speculative_decoding(self): "BailingMoeV2_5ForCausalLM", "MistralLarge3ForCausalLM", "PixtralForConditionalGeneration", + "HYV3ForCausalLM", ]: if self.speculative_draft_model_path is None: self.speculative_draft_model_path = self.model_path diff --git a/test/registered/unit/entrypoints/openai/test_serving_chat.py b/test/registered/unit/entrypoints/openai/test_serving_chat.py index 884a221a8fab..1fbc8bad7cdc 100644 --- a/test/registered/unit/entrypoints/openai/test_serving_chat.py +++ b/test/registered/unit/entrypoints/openai/test_serving_chat.py @@ -900,6 +900,25 @@ def test_extract_routed_dp_rank_from_header_invalid(self): self.assertEqual(context.exception.status_code, 400) self.assertIn("must be an integer", context.exception.detail) + def test_hunyuan_reasoning_effort_dispatch(self): + tm = _MockTokenizerManager() + tm.server_args.reasoning_parser = "hunyuan" + chat = OpenAIServingChat(tm, _MockTemplateManager()) + req = ChatCompletionRequest( + model="x", messages=[{"role": "user", "content": "hi"}] + ) + cases = [ + ("no_think", False), + ("none", False), + (None, False), + ("high", True), + ("low", True), + ] + for effort, expected in cases: + with self.subTest(effort=effort): + req.reasoning_effort = effort + self.assertEqual(chat._get_reasoning_from_request(req), expected) + class TestProcessToolCallsWithRequiredToolChoice(unittest.TestCase): """Test _process_tool_calls with tool_choice='required' uses model-specific parser.""" diff --git a/test/registered/unit/function_call/test_hunyuan_detector.py b/test/registered/unit/function_call/test_hunyuan_detector.py new file mode 100644 index 000000000000..61bef8f016f0 --- /dev/null +++ b/test/registered/unit/function_call/test_hunyuan_detector.py @@ -0,0 +1,733 @@ +"""Unit tests for HunyuanDetector - no server, no model loading.""" + +import json +import unittest + +from sglang.srt.entrypoints.openai.protocol import Function, Tool +from sglang.srt.function_call.hunyuan_detector import HunyuanDetector +from sglang.test.ci.ci_register import register_cpu_ci +from sglang.test.test_utils import CustomTestCase + +register_cpu_ci(est_time=1, suite="stage-a-test-cpu") + + +def _make_tools(): + return [ + Tool( + type="function", + function=Function( + name="get_current_date", + description="Get the current date", + parameters={}, + ), + ), + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "date": {"type": "string", "description": "Date"}, + 
}, + "required": ["city"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="search", + description="Search the web", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query", + }, + "count": { + "type": "integer", + "description": "Number of results", + }, + }, + "required": ["query"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="calculate", + description="Calculate expression", + parameters={ + "type": "object", + "properties": { + "expression": {"type": "string"}, + "precision": {"type": "number"}, + "verbose": {"type": "boolean"}, + }, + }, + ), + ), + ] + + +class TestHunyuanDetectorHasToolCall(CustomTestCase): + def setUp(self): + self.detector = HunyuanDetector() + + def test_has_tool_call_true(self): + text = ( + "get_current_date" + ) + self.assertTrue(self.detector.has_tool_call(text)) + + def test_has_tool_call_false(self): + self.assertFalse( + self.detector.has_tool_call("The weather in Beijing is sunny today.") + ) + + def test_has_tool_call_partial_tag(self): + self.assertFalse(self.detector.has_tool_call("")) + self.assertFalse(self.detector.has_tool_call(" text after") + ) + + +class TestHunyuanDetectorDetectAndParse(CustomTestCase): + def setUp(self): + self.tools = _make_tools() + self.detector = HunyuanDetector() + + def test_no_tool_call(self): + text = "This is a plain response." + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 0) + self.assertEqual(result.normal_text, text) + + def test_zero_arg_inline(self): + text = ( + "get_current_date" + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_current_date") + self.assertEqual(json.loads(result.calls[0].parameters), {}) + + def test_zero_arg_newline(self): + text = ( + "\n" + "get_current_date\n" + "\n" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_current_date") + + def test_single_string_arg(self): + text = ( + "get_weather" + "cityBeijing" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + args = json.loads(result.calls[0].parameters) + self.assertEqual(args, {"city": "Beijing"}) + + def test_multiple_args_same_line(self): + text = ( + "get_weather" + "cityBeijing" + "date2026-03-30" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + args = json.loads(result.calls[0].parameters) + self.assertEqual(args["city"], "Beijing") + self.assertEqual(args["date"], "2026-03-30") + + def test_args_with_newlines(self): + text = ( + "\n" + "get_weather\n" + "city\n" + "Beijing\n" + "date\n" + "2026-03-30\n" + "\n" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + args = json.loads(result.calls[0].parameters) + self.assertEqual(args["city"], "Beijing") + self.assertEqual(args["date"], "2026-03-30") + + def test_content_before_tool_call(self): + text = ( + "Checking." 
+ "\n" + "get_current_date\n" + "\n" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.normal_text, "Checking.") + + def test_multiple_tool_calls(self): + text = ( + "\n" + "get_weather\n" + "city\nBeijing\n" + "\n" + "get_weather\n" + "city\nHangzhou\n" + "\n" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 2) + self.assertEqual(json.loads(result.calls[0].parameters)["city"], "Beijing") + self.assertEqual(json.loads(result.calls[1].parameters)["city"], "Hangzhou") + + def test_empty_content_returns_empty_normal_text(self): + text = "\nget_current_date\n\n" + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(result.normal_text, "") + + def test_unknown_tool_skipped(self): + text = ( + "nonexistent_func" + "x1" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 0) + + def test_mixed_known_and_unknown_tools(self): + """Known tools should be parsed, unknown ones skipped.""" + text = ( + "" + "get_current_date" + "nonexistent" + "search" + "querytest" + "" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_current_date") + self.assertEqual(result.calls[1].name, "search") + + def test_three_parallel_tool_calls(self): + text = ( + "" + "get_weather" + "cityBeijing" + "" + "get_weather" + "cityTokyo" + "" + "get_current_date" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 3) + self.assertEqual(result.calls[0].name, "get_weather") + self.assertEqual(result.calls[1].name, "get_weather") + self.assertEqual(result.calls[2].name, "get_current_date") + # tool_index maps to position in tools list + self.assertEqual(result.calls[0].tool_index, 1) # get_weather is index 1 + self.assertEqual(result.calls[2].tool_index, 0) # get_current_date is index 0 + + +class TestHunyuanDetectorArgDeserialization(CustomTestCase): + """Test type-aware argument deserialization.""" + + def setUp(self): + self.tools = _make_tools() + self.detector = HunyuanDetector() + + def test_integer_arg(self): + text = ( + "search" + "queryrestaurants" + "count5" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + args = json.loads(result.calls[0].parameters) + self.assertEqual(args["query"], "restaurants") + self.assertEqual(args["count"], 5) + self.assertIsInstance(args["count"], int) + + def test_float_arg(self): + text = ( + "calculate" + "expression1+1" + "precision0.01" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + args = json.loads(result.calls[0].parameters) + self.assertEqual(args["expression"], "1+1") + self.assertAlmostEqual(args["precision"], 0.01) + + def test_boolean_arg(self): + text = ( + "calculate" + "expression2+2" + "verbosetrue" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + args = json.loads(result.calls[0].parameters) + self.assertIs(args["verbose"], True) + + def test_string_arg_not_deserialized(self): + """String-typed args should stay as strings even if they look like JSON.""" + text = ( + "search" + 'query{"key": "value"}' + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + args = json.loads(result.calls[0].parameters) + self.assertEqual(args["query"], '{"key": "value"}') + self.assertIsInstance(args["query"], str) + + def 
test_non_json_value_stays_string(self): + """Non-JSON-parseable values for non-string types should fall back to string.""" + text = ( + "search" + "queryhello world" + "countnot a number" + "" + ) + result = self.detector.detect_and_parse(text, self.tools) + args = json.loads(result.calls[0].parameters) + self.assertEqual(args["count"], "not a number") + + +def _collect_streamed_tool_calls(all_calls): + """Accumulate streaming ToolCallItems (name + arg-JSON fragments) by tool_index.""" + tools = {} + for c in all_calls: + idx = c.tool_index + if idx not in tools: + tools[idx] = {"name": c.name or "", "parameters": c.parameters or ""} + else: + if c.name: + tools[idx]["name"] += c.name + if c.parameters: + tools[idx]["parameters"] += c.parameters + return [tools[i] for i in sorted(tools.keys())] + + +class TestHunyuanDetectorStreaming(CustomTestCase): + def setUp(self): + self.tools = _make_tools() + + def _new_detector(self): + return HunyuanDetector() + + def test_normal_text_only(self): + detector = self._new_detector() + result = detector.parse_streaming_increment( + "Hello, I can help you with that.", self.tools + ) + self.assertEqual(result.normal_text, "Hello, I can help you with that.") + self.assertEqual(len(result.calls), 0) + + def test_complete_tool_call_single_chunk(self): + detector = self._new_detector() + text = ( + "" + "get_current_date" + "" + ) + result = detector.parse_streaming_increment(text, self.tools) + collected = _collect_streamed_tool_calls(result.calls) + self.assertEqual(len(collected), 1) + self.assertEqual(collected[0]["name"], "get_current_date") + self.assertEqual(json.loads(collected[0]["parameters"]), {}) + + def test_chunked_tool_call(self): + detector = self._new_detector() + chunks = [ + "", + "get_weather", + "city", + "Tokyo", + "", + "", + ] + all_calls = [] + for chunk in chunks: + result = detector.parse_streaming_increment(chunk, self.tools) + all_calls.extend(result.calls) + + collected = _collect_streamed_tool_calls(all_calls) + self.assertEqual(len(collected), 1) + self.assertEqual(collected[0]["name"], "get_weather") + args = json.loads(collected[0]["parameters"]) + self.assertEqual(args["city"], "Tokyo") + + def test_normal_text_before_tool(self): + detector = self._new_detector() + r1 = detector.parse_streaming_increment("Let me check. 
", self.tools) + self.assertIn("Let me check.", r1.normal_text) + + r2 = detector.parse_streaming_increment( + "get_current_date", + self.tools, + ) + collected = _collect_streamed_tool_calls(r2.calls) + self.assertEqual([c["name"] for c in collected], ["get_current_date"]) + + def test_multiple_tool_calls_chunked(self): + detector = self._new_detector() + chunks = [ + "\n", + "get_weather\n", + "cityBeijing\n", + "\n", + "get_weather\n", + "cityTokyo\n", + "\n", + "", + ] + all_calls = [] + for chunk in chunks: + result = detector.parse_streaming_increment(chunk, self.tools) + all_calls.extend(result.calls) + + collected = _collect_streamed_tool_calls(all_calls) + self.assertEqual(len(collected), 2) + self.assertEqual(json.loads(collected[0]["parameters"])["city"], "Beijing") + self.assertEqual(json.loads(collected[1]["parameters"])["city"], "Tokyo") + + def test_partial_bot_token_buffered(self): + """Partial at end of chunk should be buffered, not emitted.""" + detector = self._new_detector() + r1 = detector.parse_streaming_increment("Hello get_current_date" + ) + all_calls = [] + for ch in full: + result = detector.parse_streaming_increment(ch, self.tools) + all_calls.extend(result.calls) + + collected = _collect_streamed_tool_calls(all_calls) + self.assertEqual(len(collected), 1) + self.assertEqual(collected[0]["name"], "get_current_date") + self.assertEqual(json.loads(collected[0]["parameters"]), {}) + + def test_streaming_with_args_char_by_char(self): + detector = self._new_detector() + full = ( + "get_weather" + "cityNYC" + "" + ) + all_calls = [] + for ch in full: + result = detector.parse_streaming_increment(ch, self.tools) + all_calls.extend(result.calls) + + collected = _collect_streamed_tool_calls(all_calls) + self.assertEqual(len(collected), 1) + args = json.loads(collected[0]["parameters"]) + self.assertEqual(args["city"], "NYC") + + def test_streaming_three_tools_sequential(self): + """Three different tool calls arriving sequentially.""" + detector = self._new_detector() + chunks = [ + "", + "get_current_date", + "get_weathercitySF", + "searchquerytest", + "", + ] + all_calls = [] + for chunk in chunks: + result = detector.parse_streaming_increment(chunk, self.tools) + all_calls.extend(result.calls) + + collected = _collect_streamed_tool_calls(all_calls) + self.assertEqual(len(collected), 3) + self.assertEqual(collected[0]["name"], "get_current_date") + self.assertEqual(collected[1]["name"], "get_weather") + self.assertEqual(collected[2]["name"], "search") + # Streaming uses sequential tool_index (0, 1, 2) + self.assertEqual(sorted({c.tool_index for c in all_calls}), [0, 1, 2]) + + def test_streaming_normal_text_not_lost(self): + """All normal text before tool_calls should be fully emitted.""" + detector = self._new_detector() + all_normal = "" + for chunk in ["I will ", "check the ", "date now. 
"]: + result = detector.parse_streaming_increment(chunk, self.tools) + all_normal += result.normal_text + + result = detector.parse_streaming_increment( + "get_current_date", + self.tools, + ) + all_normal += result.normal_text + self.assertIn("I will check the date now.", all_normal) + + def test_streaming_name_comes_before_args(self): + """The name delta must arrive before any arg deltas (two-phase contract).""" + detector = self._new_detector() + text = ( + "get_weather" + "cityParis" + "" + ) + all_calls = [] + for ch in text: + all_calls.extend(detector.parse_streaming_increment(ch, self.tools).calls) + + name_indices = [i for i, c in enumerate(all_calls) if c.name] + param_indices = [i for i, c in enumerate(all_calls) if c.parameters] + self.assertTrue(name_indices, "expected at least one name delta") + self.assertTrue(param_indices, "expected at least one arg delta") + self.assertLess(min(name_indices), min(param_indices)) + + def test_streaming_typed_args_coerced(self): + """Streaming must apply schema-aware type coercion (int/float/bool).""" + detector = self._new_detector() + chunks = [ + "", + "search", + "querypizza", + "count7", + "", + ] + all_calls = [] + for chunk in chunks: + all_calls.extend( + detector.parse_streaming_increment(chunk, self.tools).calls + ) + collected = _collect_streamed_tool_calls(all_calls) + args = json.loads(collected[0]["parameters"]) + self.assertEqual(args["query"], "pizza") + self.assertEqual(args["count"], 7) + self.assertIsInstance(args["count"], int) + + def test_streaming_string_arg_holds_back_partial_end_tag(self): + """Char-by-char string streaming must not leak `` into the value.""" + detector = self._new_detector() + full = ( + "get_weather" + "citySan Francisco" + "" + ) + all_calls = [] + for ch in full: + all_calls.extend(detector.parse_streaming_increment(ch, self.tools).calls) + + collected = _collect_streamed_tool_calls(all_calls) + args = json.loads(collected[0]["parameters"]) + self.assertEqual(args["city"], "San Francisco") + + def test_streaming_all_in_one_delta(self): + """Entire tool call arriving in a single delta.""" + detector = self._new_detector() + text = ( + "\nget_current_date\n" + "\n" + ) + result = detector.parse_streaming_increment(text, self.tools) + collected = _collect_streamed_tool_calls(result.calls) + self.assertEqual(len(collected), 1) + self.assertEqual(collected[0]["name"], "get_current_date") + self.assertEqual(json.loads(collected[0]["parameters"]), {}) + + def test_streaming_content_before(self): + """Normal text preceding a tool call must be surfaced.""" + detector = self._new_detector() + deltas = [ + "Checking.", + "", + "\n", + "get_current_date", + "", + "\n", + "\n", + ] + all_calls = [] + all_normal = "" + for d in deltas: + r = detector.parse_streaming_increment(d, self.tools) + all_calls.extend(r.calls) + all_normal += r.normal_text + self.assertIn("Checking.", all_normal) + collected = _collect_streamed_tool_calls(all_calls) + self.assertEqual(len(collected), 1) + self.assertEqual(collected[0]["name"], "get_current_date") + + +class TestHunyuanDetectorStructureInfo(CustomTestCase): + def setUp(self): + self.detector = HunyuanDetector() + + def test_structure_info_content(self): + info_fn = self.detector.structure_info() + info = info_fn("get_weather") + self.assertIn("get_weather", info.begin) + self.assertIn("", info.begin) + self.assertIn("", info.begin) + self.assertIn("", info.end) + self.assertEqual(info.trigger, "") + + def test_supports_structural_tag(self): + 
self.assertFalse(self.detector.supports_structural_tag()) + + +class TestHunyuanDetectorAccuracy(CustomTestCase): + """Accuracy tests for realistic HYV3 output patterns.""" + + def setUp(self): + self.tools = _make_tools() + self.detector = HunyuanDetector() + + def test_reference_zero_arg_inline(self): + out = ( + "get_current_date" + ) + r = self.detector.detect_and_parse(out, self.tools) + self.assertEqual(len(r.calls), 1) + self.assertEqual(r.calls[0].name, "get_current_date") + self.assertEqual(json.loads(r.calls[0].parameters), {}) + self.assertEqual(r.normal_text, "") + + def test_reference_zero_arg_newline(self): + out = "\nget_current_date\n\n" + r = self.detector.detect_and_parse(out, self.tools) + self.assertEqual(len(r.calls), 1) + self.assertEqual(r.calls[0].name, "get_current_date") + + def test_reference_args_same_line(self): + out = ( + "get_weathercityBeijing" + "date2026-03-30" + ) + r = self.detector.detect_and_parse(out, self.tools) + self.assertEqual(len(r.calls), 1) + args = json.loads(r.calls[0].parameters) + self.assertEqual(args, {"city": "Beijing", "date": "2026-03-30"}) + + def test_reference_args_with_newlines(self): + out = ( + "\nget_weather\ncity\nBeijing" + "\ndate\n2026-03-30\n\n" + ) + r = self.detector.detect_and_parse(out, self.tools) + self.assertEqual(len(r.calls), 1) + args = json.loads(r.calls[0].parameters) + self.assertEqual(args, {"city": "Beijing", "date": "2026-03-30"}) + + def test_reference_content_before(self): + out = "Checking.\nget_current_date\n\n" + r = self.detector.detect_and_parse(out, self.tools) + self.assertEqual(len(r.calls), 1) + self.assertEqual(r.normal_text, "Checking.") + + def test_reference_multiple(self): + out = ( + "\nget_weather\ncity\nBeijing" + "\ndate\n2026-03-30\n\n" + "get_weather\ncity\nHangzhou\n" + "date\n2026-03-30\n\n" + ) + r = self.detector.detect_and_parse(out, self.tools) + self.assertEqual(len(r.calls), 2) + + def test_reference_empty_content_none(self): + out = "\nget_current_date\n\n" + r = self.detector.detect_and_parse(out, self.tools) + self.assertEqual(r.normal_text, "") + + def test_reference_no_tool_call(self): + out = "This is a plain response." 
+        r = self.detector.detect_and_parse(out, self.tools)
+        self.assertEqual(len(r.calls), 0)
+        self.assertEqual(r.normal_text, out)
+
+
+class TestHunyuanDetectorFunctionCallParser(CustomTestCase):
+    """Test through the FunctionCallParser interface."""
+
+    def setUp(self):
+        self.tools = _make_tools()
+
+    def test_parser_registry(self):
+        from sglang.srt.function_call.function_call_parser import FunctionCallParser
+
+        parser = FunctionCallParser(self.tools, "hunyuan")
+        self.assertIsInstance(parser.detector, HunyuanDetector)
+
+    def test_parse_non_stream(self):
+        from sglang.srt.function_call.function_call_parser import FunctionCallParser
+
+        parser = FunctionCallParser(self.tools, "hunyuan")
+        text = (
+            "Checking.get_weather"
+            "cityTokyo"
+            ""
+        )
+        normal, calls = parser.parse_non_stream(text)
+        self.assertEqual(normal, "Checking.")
+        self.assertEqual(len(calls), 1)
+        self.assertEqual(calls[0].name, "get_weather")
+        self.assertEqual(json.loads(calls[0].parameters)["city"], "Tokyo")
+
+    def test_parse_stream_chunks(self):
+        from sglang.srt.function_call.function_call_parser import FunctionCallParser
+
+        parser = FunctionCallParser(self.tools, "hunyuan")
+        chunks = [
+            "",
+            "get_current_date",
+            "",
+        ]
+        all_calls = []
+        for chunk in chunks:
+            normal, calls = parser.parse_stream_chunk(chunk)
+            all_calls.extend(calls)
+
+        collected = _collect_streamed_tool_calls(all_calls)
+        self.assertEqual(len(collected), 1)
+        self.assertEqual(collected[0]["name"], "get_current_date")
+        self.assertEqual(json.loads(collected[0]["parameters"]), {})
+
+    def test_has_tool_call_through_parser(self):
+        from sglang.srt.function_call.function_call_parser import FunctionCallParser
+
+        parser = FunctionCallParser(self.tools, "hunyuan")
+        self.assertTrue(parser.has_tool_call("<tool_calls>foo"))
+        self.assertFalse(parser.has_tool_call("no tools here"))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/registered/unit/parser/test_reasoning_parser.py b/test/registered/unit/parser/test_reasoning_parser.py
index 92ad536a872c..f33c88414679 100644
--- a/test/registered/unit/parser/test_reasoning_parser.py
+++ b/test/registered/unit/parser/test_reasoning_parser.py
@@ -7,6 +7,7 @@
     DeepSeekR1Detector,
     Gemma4Detector,
     Glm45Detector,
+    HunyuanDetector,
     KimiDetector,
     KimiK2Detector,
     Nemotron3Detector,
@@ -519,6 +520,99 @@ def test_forced_reasoning_mode(self):
         self.assertEqual(result.normal_text, "tool call")
+
+class TestHunyuanDetector(CustomTestCase):
+    """Test cases for Hunyuan detector with tool interruption support."""
+
+    def setUp(self):
+        self.detector = HunyuanDetector()
+
+    def test_init(self):
+        """Test HunyuanDetector initialization."""
+        self.assertEqual(self.detector.think_start_token, "<think>")
+        self.assertEqual(self.detector.think_end_token, "</think>")
+        self.assertEqual(self.detector.tool_start_token, "<tool_calls>")
+        self.assertFalse(self.detector._in_reasoning)
+        self.assertTrue(self.detector.stream_reasoning)
+
+    def test_detect_and_parse_normal_reasoning(self):
+        """Test parsing normal reasoning block without tool interruption."""
+        text = "<think>Let me think about this</think>The answer is 42."
+        result = self.detector.detect_and_parse(text)
+        self.assertEqual(result.reasoning_text, "Let me think about this")
+        self.assertEqual(result.normal_text, "The answer is 42.")
+
+    def test_detect_and_parse_without_thinking(self):
+        """Test parsing without thinking tokens (no_think mode)."""
+        text = "Direct answer without thinking."
+        result = self.detector.detect_and_parse(text)
+        self.assertEqual(result.normal_text, text)
+        self.assertEqual(result.reasoning_text, "")
+
+    def test_detect_and_parse_tool_interrupt(self):
+        """Test parsing with tool call interruption during reasoning."""
+        text = "<think>I need to check<tool_calls>get_weather</tool_calls>"
+        result = self.detector.detect_and_parse(text)
+        self.assertEqual(result.reasoning_text, "I need to check")
+        self.assertIn("<tool_calls>", result.normal_text)
+
+    def test_streaming_normal_reasoning(self):
+        """Test streaming parse of normal reasoning block."""
+        self.detector.parse_streaming_increment("<think>")
+        result1 = self.detector.parse_streaming_increment("reasoning content")
+        self.assertEqual(result1.reasoning_text, "reasoning content")
+
+        result2 = self.detector.parse_streaming_increment("</think>answer")
+        self.assertEqual(result2.normal_text, "answer")
+        self.assertFalse(self.detector._in_reasoning)
+
+    def test_streaming_tool_interrupt(self):
+        """Test streaming parse interrupted by tool call section."""
+        self.detector.parse_streaming_increment("<think>")
+        result1 = self.detector.parse_streaming_increment("thinking")
+        self.assertEqual(result1.reasoning_text, "thinking")
+
+        result2 = self.detector.parse_streaming_increment("<tool_calls>")
+        self.assertEqual(result2.reasoning_text, "")
+        self.assertEqual(result2.normal_text, "<tool_calls>")
+        self.assertFalse(self.detector._in_reasoning)
+
+    def test_streaming_after_interrupt_is_normal(self):
+        """After tool interruption, subsequent chunks should be normal text."""
+        self.detector.parse_streaming_increment("<think>")
+        self.detector.parse_streaming_increment("reasoning<tool_calls>")
+        result = self.detector.parse_streaming_increment("data")
+        self.assertEqual(result.reasoning_text, "")
+        self.assertEqual(result.normal_text, "data")
+
+    def test_reasoning_parser_integration(self):
+        """Test Hunyuan through ReasoningParser API."""
+        parser = ReasoningParser("hunyuan")
+        self.assertIsInstance(parser.detector, HunyuanDetector)
+
+        # Non-streaming
+        reasoning, normal = parser.parse_non_stream(
+            "<think>thinking<tool_calls>func</tool_calls>"
+        )
+        self.assertEqual(reasoning, "thinking")
+        self.assertIn("<tool_calls>", normal)
+
+    def test_reasoning_parser_streaming(self):
+        """Test Hunyuan streaming through ReasoningParser API."""
+        parser = ReasoningParser("hunyuan")
+        chunks = ["<think>", "reasoning", "<tool_calls>", "func"]
+        all_reasoning = ""
+        all_normal = ""
+        for chunk in chunks:
+            reasoning, normal = parser.parse_stream_chunk(chunk)
+            if reasoning:
+                all_reasoning += reasoning
+            if normal:
+                all_normal += normal
+
+        self.assertEqual(all_reasoning, "reasoning")
+        self.assertIn("<tool_calls>", all_normal)
+
+
 class TestNemotron3Detector(CustomTestCase):
     def setUp(self):
         self.detector = Nemotron3Detector()
@@ -740,6 +834,9 @@ def test_init_valid_model(self):
         parser = ReasoningParser("glm45")
         self.assertIsInstance(parser.detector, Glm45Detector)

+        parser = ReasoningParser("hunyuan")
+        self.assertIsInstance(parser.detector, HunyuanDetector)
+
         parser = ReasoningParser("gemma4")
         self.assertIsInstance(parser.detector, Gemma4Detector)