Skip to content
14 changes: 13 additions & 1 deletion python/sglang/srt/layers/attention/fla/layernorm_gated.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import triton.language as tl
from einops import rearrange

from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
cdiv,
cpu_has_amx_support,
Expand All @@ -26,6 +27,9 @@
_is_npu = is_npu()
_use_cpu = is_cpu() and cpu_has_amx_support()

# Maximum rows per Triton block for layernorm gated kernel
MAX_ROWS_PER_BLOCK = 4


def rms_norm_ref(
x,
Expand Down Expand Up @@ -166,9 +170,17 @@ def _get_sm_count(device: torch.device) -> int:


def calc_rows_per_block(M: int, device: torch.device) -> int:
    """Pick the rows-per-block value for the gated layernorm Triton kernel.

    The returned value is consumed by the Triton kernel as a ``tl.constexpr``,
    so a different value for a different ``M`` triggers a kernel (and
    torch.compile graph) recompilation.

    Args:
        M: Number of rows (tokens) in the input.
        device: Device the kernel will launch on; used to query the SM count.

    Returns:
        Rows per Triton block, capped at ``MAX_ROWS_PER_BLOCK``.
    """
    # When piecewise cuda graph is enabled, use a constant value to avoid
    # torch.compile creating guards on the dynamic batch dimension.
    try:
        if get_global_server_args().enable_piecewise_cuda_graph:
            return MAX_ROWS_PER_BLOCK
    except ValueError:
        # Global server args not initialized (e.g., in unit tests)
        pass
    sm_count = _get_sm_count(device)
    # Spread M rows over ~2 waves of SMs, rounded up to a power of two.
    rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count))
    rows_per_block = min(rows_per_block, MAX_ROWS_PER_BLOCK)
    return rows_per_block


Expand Down
68 changes: 61 additions & 7 deletions python/sglang/srt/layers/radix_linear_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
import torch
from torch import nn

from sglang.srt.compilation.compilation_config import register_split_op
from sglang.srt.compilation.piecewise_context_manager import get_forward_context
from sglang.srt.utils.custom_op import register_custom_op

if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

Expand Down Expand Up @@ -70,10 +74,60 @@ def forward(
a: torch.Tensor,
b: torch.Tensor,
) -> torch.Tensor:
return forward_batch.attn_backend.forward(
layer=self,
forward_batch=forward_batch,
mixed_qkv=mixed_qkv,
a=a,
b=b,
)
if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
# Output shape from linear attention: (1, seq_len, num_v_heads, head_v_dim)
seq_len = mixed_qkv.shape[0]
output = torch.empty(
(1, seq_len, self.num_v_heads, self.head_v_dim),
dtype=mixed_qkv.dtype,
device=mixed_qkv.device,
)
unified_linear_attention_with_output(
mixed_qkv,
a,
b,
output,
self.layer_id,
)
return output
else:
return forward_batch.attn_backend.forward(
layer=self,
forward_batch=forward_batch,
mixed_qkv=mixed_qkv,
a=a,
b=b,
)


@register_custom_op(mutates_args=["output"])
@register_split_op()
def unified_linear_attention_with_output(
    mixed_qkv: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    output: torch.Tensor,
    layer_id: int,
) -> None:
    """
    Custom op wrapper for linear attention computation only.
    """
    # Resolve the layer and batch from the active forward context rather than
    # taking them as arguments, so the op signature stays tensor-only.
    ctx = get_forward_context()
    batch = ctx.forward_batch
    layer = ctx.attention_layers[layer_id]

    result = batch.attn_backend.forward(
        layer=layer,
        forward_batch=batch,
        mixed_qkv=mixed_qkv,
        a=a,
        b=b,
    )

    assert (
        output.numel() == result.numel()
    ), f"Output tensor element mismatch: {output.numel()} != {result.numel()}"

    # Write the result into the caller-provided buffer in place.
    output.view(result.shape).copy_(result)
5 changes: 4 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2138,7 +2138,10 @@ def init_piecewise_cuda_graphs(self):
elif hasattr(layer, "attn"):
self.attention_layers.append(layer.attn)
elif hasattr(layer, "linear_attn"):
self.attention_layers.append(layer.linear_attn)
if hasattr(layer.linear_attn, "attn"):
self.attention_layers.append(layer.linear_attn.attn)
else:
self.attention_layers.append(layer.linear_attn)
# For InternVL model
elif hasattr(layer, "attention"):
if hasattr(layer.attention, "attn"):
Expand Down
21 changes: 2 additions & 19 deletions python/sglang/srt/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def __init__(
prefix=add_prefix("out_proj", prefix),
)

self.linear_attn = RadixLinearAttention(
self.attn = RadixLinearAttention(
layer_id=layer_id,
num_q_heads=self.num_k_heads // self.attn_tp_size,
num_k_heads=self.num_k_heads // self.attn_tp_size,
Expand Down Expand Up @@ -405,23 +405,6 @@ def forward(
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
):
if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
output = torch.empty_like(hidden_states)
gdn_with_output(
hidden_states,
output,
self.layer_id,
)
return output
else:
return self._forward(hidden_states, forward_batch)

def _forward(
self,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
):
seq_len, _ = hidden_states.shape
is_cuda_graph = forward_batch.forward_mode.is_cuda_graph()

projected_states_qkvz, projected_states_ba = self._forward_input_proj(
Expand Down Expand Up @@ -460,7 +443,7 @@ def _forward(
lambda x: x.reshape(x.shape[0], -1), (query, key, value)
)
mixed_qkv = torch.cat((query, key, value), dim=-1)
core_attn_out = self.linear_attn(
core_attn_out = self.attn(
forward_batch,
mixed_qkv=mixed_qkv,
a=a,
Expand Down
6 changes: 0 additions & 6 deletions test/registered/models/test_qwen3_next_models_pcg.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
"""
Qwen3 Next piecewise CUDA graph tests.

DISABLED: See https://github.com/sgl-project/sglang/issues/17039
PCG tests for Qwen3 Next have intermittent failures (5-10% probability).
Investigation ongoing by @YuweiAn.
"""

import unittest
Expand All @@ -22,7 +18,6 @@
register_cuda_ci(
est_time=400,
suite="stage-c-test-4-gpu-h100",
disabled="Intermittent failures, see #17039",
)

QWEN3_NEXT_MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct"
Expand All @@ -32,7 +27,6 @@
}


@unittest.skip("Disabled: intermittent failures, see #17039")
class TestQwen3NextPiecewiseCudaGraph(CustomTestCase):

@classmethod
Expand Down
Loading