Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/sglang/srt/layers/quantization/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ def _fp8_scaled_mm_abstract(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=No
N = mat_b.shape[-1]
return mat_a.new_empty((M, N), dtype=out_dtype)

@torch.library.register_fake("sgl_kernel::fp8_blockwise_scaled_mm")
def _fp8_blockwise_scaled_mm_abstract(mat_a, mat_b, scales_a, scales_b, out_dtype):
# mat_a: [M, K], mat_b: [K, N] or [N, K] depending on callsite layout; output is [M, N].
M = mat_a.shape[-2]
N = mat_b.shape[-1]
return mat_a.new_empty((M, N), dtype=out_dtype)


use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
use_triton_w8a8_fp8_kernel = get_bool_env_var("USE_TRITON_W8A8_FP8_KERNEL")
Expand Down
22 changes: 1 addition & 21 deletions python/sglang/srt/models/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
import torch.nn as nn
from einops import rearrange

# Model Executor
from sglang.srt.compilation.piecewise_context_manager import get_forward_context

# Configs
from sglang.srt.configs.qwen3_5 import (
Qwen3_5Config,
Expand Down Expand Up @@ -72,7 +69,6 @@
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock

# Models
from sglang.srt.models.qwen3_next import gdn_with_output
from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration

# Utils
Expand Down Expand Up @@ -253,22 +249,6 @@ def forward(
self,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
):
output = torch.empty_like(hidden_states)
if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
gdn_with_output(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why remove this branch?

Copy link
Collaborator Author

@zminglei zminglei Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This branch was for PCG purpose (most likely copied from previous qwen3_next.py). But now it's not needed anymore since we added split_op inside RadixLinearAttention. It's for the same purpose of this PR #17613

hidden_states,
output,
self.layer_id,
)
return output
else:
return self._forward(hidden_states, forward_batch)

def _forward(
self,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
):
"""
Forward pass with three parts:
Expand All @@ -287,7 +267,7 @@ def _forward(
b = b.contiguous()
a = a.contiguous()

core_attn_out = self.attn.forward(
core_attn_out = self.attn(
forward_batch=forward_batch,
mixed_qkv=mixed_qkv,
a=a,
Expand Down
25 changes: 0 additions & 25 deletions python/sglang/srt/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import torch
from torch import nn

from sglang.srt.compilation.compilation_config import register_split_op
from sglang.srt.compilation.piecewise_context_manager import get_forward_context
from sglang.srt.configs.qwen3_next import Qwen3NextConfig
from sglang.srt.distributed import get_pp_group
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
Expand Down Expand Up @@ -53,7 +51,6 @@
make_layers,
set_weight_attrs,
)
from sglang.srt.utils.custom_op import register_custom_op

logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
Expand Down Expand Up @@ -1149,25 +1146,3 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[list[int]] = None):


EntryClass = Qwen3NextForCausalLM


@register_custom_op(mutates_args=["output"])
@register_split_op()
def gdn_with_output(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_id: int,
) -> None:
context = get_forward_context()
forward_batch = context.forward_batch
attention_layers = context.attention_layers
attention_layer = attention_layers[layer_id]

ret = attention_layer._forward(hidden_states, forward_batch)

assert (
output.numel() == ret.numel()
), f"Output tensor element mismatch: {output.numel()} != {ret.numel()}"

output.view(ret.shape).copy_(ret)
return
1 change: 1 addition & 0 deletions python/sglang/srt/models/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,7 @@ def get_input_embeddings(self):
def should_apply_lora(self, module_name: str) -> bool:
return bool(self._lora_pattern.match(module_name))

@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
Expand Down
Loading