diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 3b9616e2798f..205ed65b4775 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -71,6 +71,13 @@ def _fp8_scaled_mm_abstract(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=No N = mat_b.shape[-1] return mat_a.new_empty((M, N), dtype=out_dtype) + @torch.library.register_fake("sgl_kernel::fp8_blockwise_scaled_mm") + def _fp8_blockwise_scaled_mm_abstract(mat_a, mat_b, scales_a, scales_b, out_dtype): + # mat_a: [M, K]; mat_b assumed [K, N] (N = mat_b.shape[-1] — confirm callsite layout); output is [M, N]. + M = mat_a.shape[-2] + N = mat_b.shape[-1] + return mat_a.new_empty((M, N), dtype=out_dtype) + use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL") use_triton_w8a8_fp8_kernel = get_bool_env_var("USE_TRITON_W8A8_FP8_KERNEL") diff --git a/python/sglang/srt/models/qwen3_5.py b/python/sglang/srt/models/qwen3_5.py index f01225487bf5..1c4f5ea3a215 100644 --- a/python/sglang/srt/models/qwen3_5.py +++ b/python/sglang/srt/models/qwen3_5.py @@ -22,9 +22,6 @@ import torch.nn as nn from einops import rearrange -# Model Executor -from sglang.srt.compilation.piecewise_context_manager import get_forward_context - # Configs from sglang.srt.configs.qwen3_5 import ( Qwen3_5Config, @@ -72,7 +69,6 @@ from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock # Models -from sglang.srt.models.qwen3_next import gdn_with_output from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration # Utils @@ -253,22 +249,6 @@ def forward( self, hidden_states: torch.Tensor, forward_batch: ForwardBatch, - ): - output = torch.empty_like(hidden_states) - if forward_batch.forward_mode.is_extend() and get_forward_context() is not None: - gdn_with_output( - hidden_states, - output, - self.layer_id, - ) - return output - else: - return self._forward(hidden_states, 
forward_batch) - - def _forward( - self, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, ): """ Forward pass with three parts: @@ -287,7 +267,7 @@ def _forward( b = b.contiguous() a = a.contiguous() - core_attn_out = self.attn.forward( + core_attn_out = self.attn( forward_batch=forward_batch, mixed_qkv=mixed_qkv, a=a, diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index 39a086c986b4..65bc8d35ccee 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -5,8 +5,6 @@ import torch from torch import nn -from sglang.srt.compilation.compilation_config import register_split_op -from sglang.srt.compilation.piecewise_context_manager import get_forward_context from sglang.srt.configs.qwen3_next import Qwen3NextConfig from sglang.srt.distributed import get_pp_group from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder @@ -53,7 +51,6 @@ make_layers, set_weight_attrs, ) -from sglang.srt.utils.custom_op import register_custom_op logger = logging.getLogger(__name__) _is_cuda = is_cuda() @@ -1149,25 +1146,3 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[list[int]] = None): EntryClass = Qwen3NextForCausalLM - - -@register_custom_op(mutates_args=["output"]) -@register_split_op() -def gdn_with_output( - hidden_states: torch.Tensor, - output: torch.Tensor, - layer_id: int, -) -> None: - context = get_forward_context() - forward_batch = context.forward_batch - attention_layers = context.attention_layers - attention_layer = attention_layers[layer_id] - - ret = attention_layer._forward(hidden_states, forward_batch) - - assert ( - output.numel() == ret.numel() - ), f"Output tensor element mismatch: {output.numel()} != {ret.numel()}" - - output.view(ret.shape).copy_(ret) - return diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index d641826e3394..cf3e44e1fa5b 100644 --- 
a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -987,6 +987,7 @@ def get_input_embeddings(self): def should_apply_lora(self, module_name: str) -> bool: return bool(self._lora_pattern.match(module_name)) + @torch.no_grad() def forward( self, input_ids: torch.Tensor,