From 5c8f679ac9c0ab270cbb0d98f274492f12830e8a Mon Sep 17 00:00:00 2001 From: yhyang201 Date: Sat, 9 May 2026 08:02:05 +0200 Subject: [PATCH 1/7] Optimize mhc_pre by folding stage-1 reduction into big_fuse Co-Authored-By: Cheng Wan --- python/sglang/srt/layers/mhc.py | 45 +++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/layers/mhc.py b/python/sglang/srt/layers/mhc.py index 1c27636efb5c..8eb7ac0910b8 100644 --- a/python/sglang/srt/layers/mhc.py +++ b/python/sglang/srt/layers/mhc.py @@ -138,12 +138,15 @@ def mhc_pre_big_fuse_tilelang( sinkhorn_repeat: int, n_splits: int = 16, hc_mult: int = 4, + gemm_last_dim: int = -1, ): num_tokens = T.dynamic("num_tokens") hc_mult3 = hc_mult * (2 + hc_mult) + if gemm_last_dim < 0: + gemm_last_dim = hc_mult3 hidden_block = math.gcd(512, hidden_size) - gemm_out_mul: T.Tensor[[n_splits, num_tokens, hc_mult3], T.float32] + gemm_out_mul: T.Tensor[[n_splits, num_tokens, gemm_last_dim], T.float32] gemm_out_sqrsum: T.Tensor[[n_splits, num_tokens], T.float32] hc_scale: T.Tensor[[3], T.float32] hc_base: T.Tensor[[hc_mult3], T.float32] @@ -484,13 +487,6 @@ def mhc_pre( num_tokens, hidden_size, dtype=torch.bfloat16, device=residual.device ) - gemm_out_mul = torch.empty( - n_splits, num_tokens, hc_mult3, dtype=torch.float32, device=residual.device - ) - gemm_out_sqrsum = torch.empty( - n_splits, num_tokens, dtype=torch.float32, device=residual.device - ) - if num_tokens <= 2048: assert n_splits == 1 if hc_hidden_size == 16384: @@ -502,29 +498,37 @@ def mhc_pre( f"mhc_pre splitk kernel only supports hc_hidden_size in {{16384, 28672}}, " f"got {hc_hidden_size}" ) - kernel_0, kernel_1 = mhc_pre_gemm_sqrsum_splitk_kernel( + kernel_0, _ = mhc_pre_gemm_sqrsum_splitk_kernel( hc_mult3, hc_hidden_size, split_k=n_splits_pre, token_block=32, hidden_block=hidden_block, ) - partial_out = gemm_out_mul.new_empty(n_splits_pre, num_tokens, 32) - partial_sqrsum = gemm_out_sqrsum.new_empty(n_splits_pre, num_tokens) + partial_out = torch.empty( + n_splits_pre, num_tokens, 32, dtype=torch.float32, device=residual.device + ) + partial_sqrsum = torch.empty( + n_splits_pre, num_tokens, dtype=torch.float32, device=residual.device + ) kernel_0( residual_flat.view(num_tokens, hc_hidden_size), fn_flat, partial_out, partial_sqrsum, ) - kernel_1( - partial_out, - partial_sqrsum, - gemm_out_mul.squeeze(0), - gemm_out_sqrsum.squeeze(0), - ) - del partial_out, partial_sqrsum + # Stage_1 reduction is folded into big_fuse below; skip launching it. 
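+        # (Safe because the big_fuse kernel reduces over its leading n_splits
+        # dimension, so the unreduced partials can be passed straight through
+        # with n_splits = n_splits_pre.)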
+ gemm_out_mul = partial_out + gemm_out_sqrsum = partial_sqrsum + gemm_last_dim = 32 + big_fuse_n_splits = n_splits_pre else: + gemm_out_mul = torch.empty( + n_splits, num_tokens, hc_mult3, dtype=torch.float32, device=residual.device + ) + gemm_out_sqrsum = torch.empty( + n_splits, num_tokens, dtype=torch.float32, device=residual.device + ) assert ( n_splits == 1 ), "The simple TileLang version gemm_sqrsum doesn't support split-k" @@ -536,6 +540,8 @@ def mhc_pre( hc_mult3, hc_mult * hidden_size, ) + gemm_last_dim = hc_mult3 + big_fuse_n_splits = n_splits mhc_pre_big_fuse_tilelang( gemm_out_mul, @@ -552,8 +558,9 @@ def mhc_pre( hc_sinkhorn_eps, hc_post_mult_value, sinkhorn_repeat, - n_splits, + big_fuse_n_splits, hc_mult, + gemm_last_dim, ) post_mix = post_mix.view(*outer_shape, hc_mult, 1) From 4c9517307af2c824baa43898f1d12e4236bbb140 Mon Sep 17 00:00:00 2001 From: Chunan Zeng Date: Sat, 9 May 2026 08:39:15 +0200 Subject: [PATCH 2/7] Use DeepGemm for mhc_pre GEMM when available Co-authored-by: Chunan Zeng --- python/sglang/srt/layers/mhc.py | 129 +++++++++++++++++++++----------- 1 file changed, 84 insertions(+), 45 deletions(-) diff --git a/python/sglang/srt/layers/mhc.py b/python/sglang/srt/layers/mhc.py index 8eb7ac0910b8..5fd8e73fd193 100644 --- a/python/sglang/srt/layers/mhc.py +++ b/python/sglang/srt/layers/mhc.py @@ -7,6 +7,7 @@ import torch from sglang.jit_kernel.utils import is_arch_support_pdl +from sglang.srt.environ import envs from sglang.srt.layers.attention.nsa.utils import is_nsa_prefill_cp_round_robin_split from sglang.srt.layers.utils.common import strict_contiguous @@ -441,6 +442,14 @@ def mhc_pre_gemm_sqrsum_splitk_stage_1( ) +def _compute_num_split_for_mhc_pre(num_tokens: int, hc_hidden_size: int) -> int: + block_m, block_k = 64, 64 + grid_size = (num_tokens + block_m - 1) // block_m + num_block_k = (hc_hidden_size + block_k - 1) // block_k + n_sms = torch.cuda.get_device_properties(0).multi_processor_count + return max(1, min(n_sms // max(grid_size, 1), num_block_k // 4)) + + def mhc_pre( residual: torch.Tensor, fn: torch.Tensor, @@ -487,61 +496,91 @@ def mhc_pre( num_tokens, hidden_size, dtype=torch.bfloat16, device=residual.device ) - if num_tokens <= 2048: - assert n_splits == 1 - if hc_hidden_size == 16384: - hidden_block = 256 - elif hc_hidden_size == 28672: - hidden_block = 128 - else: - raise NotImplementedError( - f"mhc_pre splitk kernel only supports hc_hidden_size in {{16384, 28672}}, " - f"got {hc_hidden_size}" - ) - kernel_0, _ = mhc_pre_gemm_sqrsum_splitk_kernel( - hc_mult3, - hc_hidden_size, - split_k=n_splits_pre, - token_block=32, - hidden_block=hidden_block, - ) - partial_out = torch.empty( - n_splits_pre, num_tokens, 32, dtype=torch.float32, device=residual.device - ) - partial_sqrsum = torch.empty( - n_splits_pre, num_tokens, dtype=torch.float32, device=residual.device - ) - kernel_0( - residual_flat.view(num_tokens, hc_hidden_size), - fn_flat, - partial_out, - partial_sqrsum, - ) - # Stage_1 reduction is folded into big_fuse below; skip launching it. 
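-        # (Safe because the big_fuse kernel reduces over its leading n_splits
-        # dimension, so the unreduced partials can be passed straight through
-        # with n_splits = n_splits_pre.)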
- gemm_out_mul = partial_out - gemm_out_sqrsum = partial_sqrsum - gemm_last_dim = 32 - big_fuse_n_splits = n_splits_pre - else: + if envs.SGLANG_OPT_DEEPGEMM_HC_PRENORM.get(): + import deep_gemm + + n_splits = _compute_num_split_for_mhc_pre(num_tokens, hc_hidden_size) + gemm_out_mul = torch.empty( n_splits, num_tokens, hc_mult3, dtype=torch.float32, device=residual.device ) gemm_out_sqrsum = torch.empty( n_splits, num_tokens, dtype=torch.float32, device=residual.device ) - assert ( - n_splits == 1 - ), "The simple TileLang version gemm_sqrsum doesn't support split-k" - mhc_pre_gemm_sqrsum_tilelang( - residual_flat.view(num_tokens, hc_mult * hidden_size), + + deep_gemm.tf32_hc_prenorm_gemm( + residual_flat.view(num_tokens, hc_hidden_size), fn_flat, - gemm_out_mul.squeeze(0), - gemm_out_sqrsum.squeeze(0), - hc_mult3, - hc_mult * hidden_size, + gemm_out_mul, + gemm_out_sqrsum, + num_splits=n_splits, ) gemm_last_dim = hc_mult3 big_fuse_n_splits = n_splits + else: + if num_tokens <= 2048: + assert n_splits == 1 + if hc_hidden_size == 16384: + hidden_block = 256 + elif hc_hidden_size == 28672: + hidden_block = 128 + else: + raise NotImplementedError( + f"mhc_pre splitk kernel only supports hc_hidden_size in {{16384, 28672}}, " + f"got {hc_hidden_size}" + ) + kernel_0, _ = mhc_pre_gemm_sqrsum_splitk_kernel( + hc_mult3, + hc_hidden_size, + split_k=n_splits_pre, + token_block=32, + hidden_block=hidden_block, + ) + partial_out = torch.empty( + n_splits_pre, + num_tokens, + 32, + dtype=torch.float32, + device=residual.device, + ) + partial_sqrsum = torch.empty( + n_splits_pre, num_tokens, dtype=torch.float32, device=residual.device + ) + kernel_0( + residual_flat.view(num_tokens, hc_hidden_size), + fn_flat, + partial_out, + partial_sqrsum, + ) + # Stage_1 reduction is folded into big_fuse below; skip launching it. 
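+            # (Safe because the big_fuse kernel reduces over its leading n_splits
+            # dimension, so the unreduced partials can be passed straight through
+            # with n_splits = n_splits_pre.)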
+ gemm_out_mul = partial_out + gemm_out_sqrsum = partial_sqrsum + gemm_last_dim = 32 + big_fuse_n_splits = n_splits_pre + else: + gemm_out_mul = torch.empty( + n_splits, + num_tokens, + hc_mult3, + dtype=torch.float32, + device=residual.device, + ) + gemm_out_sqrsum = torch.empty( + n_splits, num_tokens, dtype=torch.float32, device=residual.device + ) + assert ( + n_splits == 1 + ), "The simple TileLang version gemm_sqrsum doesn't support split-k" + mhc_pre_gemm_sqrsum_tilelang( + residual_flat.view(num_tokens, hc_mult * hidden_size), + fn_flat, + gemm_out_mul.squeeze(0), + gemm_out_sqrsum.squeeze(0), + hc_mult3, + hc_mult * hidden_size, + ) + gemm_last_dim = hc_mult3 + big_fuse_n_splits = n_splits mhc_pre_big_fuse_tilelang( gemm_out_mul, From cfa798ba3bc79f8b76c2db068a332dcfc6684d55 Mon Sep 17 00:00:00 2001 From: Cheng Wan Date: Sat, 9 May 2026 08:44:39 +0200 Subject: [PATCH 3/7] Fuse RMSNorm into mhc_pre big_fuse kernel Co-authored-by: Cheng Wan --- python/sglang/srt/layers/mhc.py | 246 ++++++++++++++++++++++-- python/sglang/srt/models/deepseek_v4.py | 38 +++- 2 files changed, 256 insertions(+), 28 deletions(-) diff --git a/python/sglang/srt/layers/mhc.py b/python/sglang/srt/layers/mhc.py index 5fd8e73fd193..e268d553bf08 100644 --- a/python/sglang/srt/layers/mhc.py +++ b/python/sglang/srt/layers/mhc.py @@ -450,6 +450,177 @@ def _compute_num_split_for_mhc_pre(num_tokens: int, hc_hidden_size: int) -> int: return max(1, min(n_sms // max(grid_size, 1), num_block_k // 4)) +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, + }, +) +def mhc_pre_big_fuse_with_norm_tilelang( + gemm_out_mul, + gemm_out_sqrsum, + hc_scale, + hc_base, + residual, + post_mix, + comb_mix, + layer_input, + norm_weight, + hidden_size: int, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + norm_eps: float, + n_splits: int = 16, + hc_mult: int = 4, + gemm_last_dim: int = -1, +): + """Fused mhc_pre big_fuse + RMSNorm of layer_input. + + Identical to mhc_pre_big_fuse_tilelang for the (post_mix, comb_mix) path. + For the layer_input path, the weighted-sum result is stashed in shared + memory while accumulating sum_sq, then a second pipelined sweep applies + rsqrt(sum_sq/D + norm_eps) * norm_weight before writing to HBM. 
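+
+    Per token, the fused epilogue computes
+    norm_weight * y * rsqrt(mean(y ** 2) + norm_eps); the bf16 stash adds
+    2 * hidden_size bytes of shared memory (14 KiB at hidden_size = 7168).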
+ """ + num_tokens = T.dynamic("num_tokens") + hc_mult3 = hc_mult * (2 + hc_mult) + if gemm_last_dim < 0: + gemm_last_dim = hc_mult3 + hidden_block = math.gcd(1024, hidden_size) + + gemm_out_mul: T.Tensor[[n_splits, num_tokens, gemm_last_dim], T.float32] + gemm_out_sqrsum: T.Tensor[[n_splits, num_tokens], T.float32] + hc_scale: T.Tensor[[3], T.float32] + hc_base: T.Tensor[[hc_mult3], T.float32] + residual: T.Tensor[[num_tokens, hc_mult, hidden_size], T.bfloat16] + post_mix: T.Tensor[[num_tokens, hc_mult], T.float32] + comb_mix: T.Tensor[[num_tokens, hc_mult * hc_mult], T.float32] + layer_input: T.Tensor[[num_tokens, hidden_size], T.bfloat16] + norm_weight: T.Tensor[[hidden_size], T.bfloat16] + + ENABLE_PDL = is_arch_support_pdl() + with T.Kernel(num_tokens, threads=96) as i: + rms = T.alloc_fragment(1, T.float32) + mixes = T.alloc_fragment(hc_mult3, T.float32) + T.clear(mixes) + rms[0] = 0 + + if ENABLE_PDL: + T.pdl_sync() + + for i_split in T.serial(n_splits): + rms[0] += gemm_out_sqrsum[i_split, i] + rms[0] = T.rsqrt(rms[0] / (hc_mult * hidden_size) + rms_eps) + for j in T.Parallel(hc_mult3): + mixes[j] = 0 + for i_split in T.serial(n_splits): + mixes[j] += gemm_out_mul[i_split, i, j] + mixes[j] *= rms[0] + mixes_shared = T.alloc_shared(hc_mult3, T.float32) + T.copy(mixes, mixes_shared) + + if T.get_thread_binding() < 32: + cm = T.alloc_fragment((hc_mult, hc_mult), T.float32) + for j in T.Parallel(hc_mult): + post_mix[i, j] = ( + T.sigmoid( + mixes_shared[j + hc_mult] * hc_scale[1] + hc_base[j + hc_mult] + ) + * hc_post_mult_value + ) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = ( + mixes_shared[j * hc_mult + k + hc_mult * 2] * hc_scale[2] + + hc_base[j * hc_mult + k + hc_mult * 2] + ) + + row_sum = T.alloc_fragment(hc_mult, T.float32) + col_sum = T.alloc_fragment(hc_mult, T.float32) + + row_max = T.alloc_fragment(hc_mult, T.float32) + T.reduce_max(cm, row_max, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = T.exp(cm[j, k] - row_max[j]) + T.reduce_sum(cm, row_sum, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / row_sum[j] + hc_sinkhorn_eps + + T.reduce_sum(cm, col_sum, dim=0) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) + + for _ in T.serial(sinkhorn_repeat - 1): + T.reduce_sum(cm, row_sum, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (row_sum[j] + hc_sinkhorn_eps) + + T.reduce_sum(cm, col_sum, dim=0) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) + + for j, k in T.Parallel(hc_mult, hc_mult): + comb_mix[i, j * hc_mult + k] = cm[j, k] + else: + pre_mix_shared = T.alloc_shared(hc_mult, T.float32) + for j in T.Parallel(hc_mult): + pre_mix_shared[j] = ( + T.sigmoid( + mixes_shared[j] * hc_scale[0] + hc_base[j], + ) + + hc_pre_eps + ) + + # Stash unnormalized weighted-sum output in shared memory as bf16 + # (matches the rounding the reference path does when RMSNorm reads bf16). 
+ output_shared = T.alloc_shared(hidden_size, T.bfloat16) + sumsq_per_pos = T.alloc_fragment(hidden_block, T.float32) + + for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=3): + xs = T.alloc_shared((hc_mult, hidden_block), T.bfloat16) + xl = T.alloc_fragment((hc_mult, hidden_block), T.float32) + T.copy(residual[i, 0, i0_h * hidden_block], xs) + T.copy(xs, xl) + + ol = T.alloc_fragment(hidden_block, T.float32) + T.clear(ol) + + for i_hc in T.serial(hc_mult): + pre = pre_mix_shared[i_hc] + for i1_h in T.Parallel(hidden_block): + ol[i1_h] += pre * xl[i_hc, i1_h] + + for i1_h in T.Parallel(hidden_block): + sumsq_per_pos[i1_h] += ol[i1_h] * ol[i1_h] + output_shared[i0_h * hidden_block + i1_h] = T.bfloat16(ol[i1_h]) + + sumsq = T.alloc_fragment(1, T.float32) + T.reduce_sum(sumsq_per_pos, sumsq, dim=0) + rsqrt_norm = T.alloc_fragment(1, T.float32) + rsqrt_norm[0] = T.rsqrt(sumsq[0] / hidden_size + norm_eps) + + for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=2): + w_shared = T.alloc_shared(hidden_block, T.bfloat16) + w_local = T.alloc_fragment(hidden_block, T.float32) + T.copy(norm_weight[i0_h * hidden_block], w_shared) + T.copy(w_shared, w_local) + + ol = T.alloc_fragment(hidden_block, T.float32) + for i1_h in T.Parallel(hidden_block): + ol[i1_h] = ( + output_shared[i0_h * hidden_block + i1_h] + * rsqrt_norm[0] + * w_local[i1_h] + ) + + T.copy(ol, layer_input[i, i0_h * hidden_block]) + + if ENABLE_PDL: + T.pdl_trigger() + + def mhc_pre( residual: torch.Tensor, fn: torch.Tensor, @@ -462,6 +633,9 @@ def mhc_pre( sinkhorn_repeat: int, n_splits: int = 1, n_splits_pre: int = 32, + *, + norm_weight: torch.Tensor | None = None, + norm_eps: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert residual.dtype == torch.bfloat16 @@ -582,25 +756,59 @@ def mhc_pre( gemm_last_dim = hc_mult3 big_fuse_n_splits = n_splits - mhc_pre_big_fuse_tilelang( - gemm_out_mul, - gemm_out_sqrsum, - hc_scale, - hc_base, - residual_flat, - post_mix, - comb_mix, - layer_input, - hidden_size, - rms_eps, - hc_pre_eps, - hc_sinkhorn_eps, - hc_post_mult_value, - sinkhorn_repeat, - big_fuse_n_splits, - hc_mult, - gemm_last_dim, - ) + if norm_weight is not None: + assert norm_eps is not None, "norm_eps required when norm_weight is provided" + assert norm_weight.shape == ( + hidden_size, + ), f"norm_weight shape {tuple(norm_weight.shape)} != (hidden_size={hidden_size},)" + norm_weight_bf = ( + norm_weight.bfloat16() + if norm_weight.dtype != torch.bfloat16 + else norm_weight + ) + if not norm_weight_bf.is_contiguous(): + norm_weight_bf = norm_weight_bf.contiguous() + mhc_pre_big_fuse_with_norm_tilelang( + gemm_out_mul, + gemm_out_sqrsum, + hc_scale, + hc_base, + residual_flat, + post_mix, + comb_mix, + layer_input, + norm_weight_bf, + hidden_size, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + norm_eps, + big_fuse_n_splits, + hc_mult, + gemm_last_dim, + ) + else: + mhc_pre_big_fuse_tilelang( + gemm_out_mul, + gemm_out_sqrsum, + hc_scale, + hc_base, + residual_flat, + post_mix, + comb_mix, + layer_input, + hidden_size, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + big_fuse_n_splits, + hc_mult, + gemm_last_dim, + ) post_mix = post_mix.view(*outer_shape, hc_mult, 1) comb_mix = comb_mix.view(*outer_shape, hc_mult, hc_mult) diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py index b1c225051967..1b8f023868cb 100644 --- a/python/sglang/srt/models/deepseek_v4.py 
+++ b/python/sglang/srt/models/deepseek_v4.py @@ -653,7 +653,11 @@ def hc_pre( hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, + norm=None, ): + """If *norm* is given and the TileLang path is active, the returned + hidden_states are already post-norm (the norm is fused into the kernel).""" + @compile_in_capture_mode def hc_pre_torch_impl(x, hc_fn): x_flat = x.flatten(1).float() @@ -671,11 +675,16 @@ def hc_pre_torch_impl(x, hc_fn): comb = torch.empty( (0, self.hc_mult, self.hc_mult), dtype=dtype, device=x.device ) - return y, post, comb + return y, post, comb, False if envs.SGLANG_OPT_USE_TILELANG_MHC_PRE.get(): from sglang.srt.layers.mhc import mhc_pre + norm_kwargs = {} + if norm is not None: + norm_kwargs["norm_weight"] = norm.weight.data + norm_kwargs["norm_eps"] = norm.variance_epsilon + post, comb, y = mhc_pre( residual=x, fn=hc_fn, @@ -686,8 +695,9 @@ def hc_pre_torch_impl(x, hc_fn): hc_sinkhorn_eps=self.hc_eps, hc_post_mult_value=2.0, sinkhorn_repeat=self.hc_sinkhorn_iters, + **norm_kwargs, ) - return y, post.squeeze(-1), comb + return y, post.squeeze(-1), comb, norm is not None if envs.SGLANG_OPT_DEEPGEMM_HC_PRENORM.get(): import deep_gemm @@ -717,7 +727,7 @@ def hc_pre_torch_impl(x, hc_fn): self.hc_eps, ) y = (pre.squeeze(1).unsqueeze(-1) * x_flat.view(shape)).sum(dim=1) - return y.to(dtype), post.squeeze(1), comb.squeeze(1) + return y.to(dtype), post.squeeze(1), comb.squeeze(1), False def hc_post( self, @@ -759,10 +769,15 @@ def forward( input_ids_global: torch.Tensor, ) -> torch.Tensor: residual = hidden_states - hidden_states, post, comb = self.hc_pre( - hidden_states, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base + hidden_states, post, comb, norm_fused = self.hc_pre( + hidden_states, + self.hc_attn_fn, + self.hc_attn_scale, + self.hc_attn_base, + norm=self.input_layernorm, ) - hidden_states = self.input_layernorm(hidden_states) + if not norm_fused: + hidden_states = self.input_layernorm(hidden_states) hidden_states = self.self_attn( x=hidden_states, @@ -772,10 +787,15 @@ def forward( hidden_states = self.hc_post(hidden_states, residual, post, comb) residual = hidden_states - hidden_states, post, comb = self.hc_pre( - hidden_states, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base + hidden_states, post, comb, norm_fused = self.hc_pre( + hidden_states, + self.hc_ffn_fn, + self.hc_ffn_scale, + self.hc_ffn_base, + norm=self.post_attention_layernorm, ) - hidden_states = self.post_attention_layernorm(hidden_states) + if not norm_fused: + hidden_states = self.post_attention_layernorm(hidden_states) _use_cp = self.nsa_enable_prefill_cp and nsa_use_prefill_cp(forward_batch) _use_tp_moe_gather = ( From abeb7f87c0d6f63d9b4fe2c0eaca2bd7df6fab23 Mon Sep 17 00:00:00 2001 From: Cheng Wan Date: Sat, 9 May 2026 08:50:45 +0200 Subject: [PATCH 4/7] Add fused hc_head Triton kernel Co-authored-by: Cheng Wan --- python/sglang/srt/layers/mhc_head.py | 151 ++++++++++++++++++++++++ python/sglang/srt/models/deepseek_v4.py | 11 ++ 2 files changed, 162 insertions(+) create mode 100644 python/sglang/srt/layers/mhc_head.py diff --git a/python/sglang/srt/layers/mhc_head.py b/python/sglang/srt/layers/mhc_head.py new file mode 100644 index 000000000000..0487e401742c --- /dev/null +++ b/python/sglang/srt/layers/mhc_head.py @@ -0,0 +1,151 @@ +"""Fused triton kernel for the DSV4 hc_head LM-head mixer. 
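+
+One kernel launch covers the RMS statistics, the hc_fn mixes, the sigmoid
+gate, and the weighted sum that the reference path runs as separate torch ops.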
+ +Reference torch implementation (deepseek_v4.py DeepseekV4Model.hc_head): + + shape, dtype = x.size(), x.dtype + x = x.flatten(1).float() + rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + norm_eps) + mixes = F.linear(x, hc_fn) * rsqrt + pre = torch.sigmoid(mixes * hc_scale + hc_base) + hc_eps + y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=1) + return y.to(dtype) + +Shapes (DSV4-Pro, hc_mult=4, hidden_size=7168 typical): + x : (T, hc_mult, hidden_size) bf16 + hc_fn : (hc_mult, hc_mult * hidden_size) fp32 + scale : (1,) fp32 + base : (hc_mult,) fp32 + out y : (T, hidden_size) bf16 + +This is a one-shot LM-head op (fires once per forward on the last PP rank), so +we use a 1-CTA-per-token design that does two passes over x without split-K. +""" + +from __future__ import annotations + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _hc_head_kernel( + x_ptr, + fn_ptr, + scale_ptr, + base_ptr, + y_ptr, + hidden_size: tl.constexpr, + HC_MULT: tl.constexpr, + K_TOTAL: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_D: tl.constexpr, + norm_eps: tl.constexpr, + hc_eps: tl.constexpr, +): + pid = tl.program_id(0).to(tl.int64) + + # ---------- Pass 1: sum_sq over flattened K dim, plus hc_mult inner products ---------- + sumsq = tl.zeros((), dtype=tl.float32) + mix = tl.zeros((HC_MULT,), dtype=tl.float32) + + x_row = x_ptr + pid * K_TOTAL + m_idx = tl.arange(0, HC_MULT) + + for k_off in tl.range(0, K_TOTAL, BLOCK_K): + k_offs = k_off + tl.arange(0, BLOCK_K) + k_mask = k_offs < K_TOTAL + x_tile = tl.load(x_row + k_offs, mask=k_mask, other=0.0).to(tl.float32) + + sumsq += tl.sum(x_tile * x_tile, axis=0) + + fn_offs = m_idx[:, None] * K_TOTAL + k_offs[None, :] + fn_mask = (m_idx[:, None] < HC_MULT) & k_mask[None, :] + fn_tile = tl.load(fn_ptr + fn_offs, mask=fn_mask, other=0.0) + mix += tl.sum(fn_tile * x_tile[None, :], axis=1) + + rsqrt = tl.rsqrt(sumsq / K_TOTAL + norm_eps) + scale_v = tl.load(scale_ptr).to(tl.float32) + base_v = tl.load(base_ptr + m_idx).to(tl.float32) + + # pre[m] = sigmoid(mix[m] * rsqrt * scale + base[m]) + hc_eps + pre = tl.sigmoid(mix * rsqrt * scale_v + base_v) + hc_eps + + # ---------- Pass 2: y[d] = sum_m pre[m] * x[m, d] for d in range(hidden_size) ---------- + y_row = y_ptr + pid * hidden_size + + for d_off in tl.range(0, hidden_size, BLOCK_D): + d_offs = d_off + tl.arange(0, BLOCK_D) + d_mask = d_offs < hidden_size + + x_offs = m_idx[:, None] * hidden_size + d_offs[None, :] + x_mask = (m_idx[:, None] < HC_MULT) & d_mask[None, :] + x_block = tl.load(x_row + x_offs, mask=x_mask, other=0.0).to(tl.float32) + + y_block = tl.sum(pre[:, None] * x_block, axis=0) + + tl.store(y_row + d_offs, y_block.to(y_ptr.dtype.element_ty), mask=d_mask) + + +def fused_hc_head( + x: torch.Tensor, + hc_fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + norm_eps: float, + hc_eps: float, +) -> torch.Tensor: + """Fused (RMSNorm + Linear + Sigmoid-gate + weighted-sum) for the DSV4 hc_head. 
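+
+    One CTA handles one token and reads x twice (sum-of-squares + mix pass,
+    then the weighted-sum pass), so no split-K reduction or workspace is
+    needed.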
+ + Args: + x : (T, hc_mult, hidden_size) bf16/fp16, must be contiguous + hc_fn : (hc_mult, hc_mult * hidden_size) fp32, contiguous + hc_scale : (1,) fp32 scalar + hc_base : (hc_mult,) fp32 + norm_eps : RMS epsilon + hc_eps : additive epsilon after sigmoid + + Returns: + y : (T, hidden_size) same dtype as x + """ + assert x.is_contiguous(), "x must be contiguous" + assert hc_fn.is_contiguous(), "hc_fn must be contiguous" + assert hc_scale.dtype == torch.float32 and hc_base.dtype == torch.float32 + assert hc_fn.dtype == torch.float32 + assert x.dim() == 3, f"x must be 3D (T, hc_mult, hidden_size), got {x.shape}" + + T_val, hc_mult, hidden_size = x.shape + assert hc_fn.shape == (hc_mult, hc_mult * hidden_size), ( + f"hc_fn shape {hc_fn.shape} does not match (hc_mult={hc_mult}, " + f"hc_mult*hidden_size={hc_mult * hidden_size})" + ) + assert hc_base.shape == (hc_mult,) + assert hc_scale.numel() == 1 + + y = torch.empty((T_val, hidden_size), dtype=x.dtype, device=x.device) + + if T_val == 0: + return y + + BLOCK_K = 512 + BLOCK_D = 512 + + hc_mult_pow2 = max(1, triton.next_power_of_2(hc_mult)) + + grid = (T_val,) + _hc_head_kernel[grid]( + x, + hc_fn, + hc_scale, + hc_base, + y, + hidden_size=hidden_size, + HC_MULT=hc_mult_pow2, + K_TOTAL=hc_mult * hidden_size, + BLOCK_K=BLOCK_K, + BLOCK_D=BLOCK_D, + norm_eps=norm_eps, + hc_eps=hc_eps, + num_warps=4, + ) + return y diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py index 1b8f023868cb..562347035a0e 100644 --- a/python/sglang/srt/models/deepseek_v4.py +++ b/python/sglang/srt/models/deepseek_v4.py @@ -904,6 +904,17 @@ def hc_head( hc_scale: torch.Tensor, hc_base: torch.Tensor, ): + if x.numel() > 0: + from sglang.srt.layers.mhc_head import fused_hc_head + + return fused_hc_head( + x.contiguous(), + hc_fn, + hc_scale, + hc_base, + norm_eps=self.norm_eps, + hc_eps=self.hc_eps, + ) shape, dtype = x.size(), x.dtype x = x.flatten(1).float() rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps) From 64c1510a9faf0ac667478f7846cb437460f34dfb Mon Sep 17 00:00:00 2001 From: yhyang201 Date: Sat, 9 May 2026 09:13:22 +0200 Subject: [PATCH 5/7] Fix unnecessary divergences from original commits --- python/sglang/srt/layers/mhc_head.py | 8 ++++---- python/sglang/srt/models/deepseek_v4.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/mhc_head.py b/python/sglang/srt/layers/mhc_head.py index 0487e401742c..43de7defbab4 100644 --- a/python/sglang/srt/layers/mhc_head.py +++ b/python/sglang/srt/layers/mhc_head.py @@ -114,7 +114,7 @@ def fused_hc_head( assert hc_fn.dtype == torch.float32 assert x.dim() == 3, f"x must be 3D (T, hc_mult, hidden_size), got {x.shape}" - T_val, hc_mult, hidden_size = x.shape + T, hc_mult, hidden_size = x.shape assert hc_fn.shape == (hc_mult, hc_mult * hidden_size), ( f"hc_fn shape {hc_fn.shape} does not match (hc_mult={hc_mult}, " f"hc_mult*hidden_size={hc_mult * hidden_size})" @@ -122,9 +122,9 @@ def fused_hc_head( assert hc_base.shape == (hc_mult,) assert hc_scale.numel() == 1 - y = torch.empty((T_val, hidden_size), dtype=x.dtype, device=x.device) + y = torch.empty((T, hidden_size), dtype=x.dtype, device=x.device) - if T_val == 0: + if T == 0: return y BLOCK_K = 512 @@ -132,7 +132,7 @@ def fused_hc_head( hc_mult_pow2 = max(1, triton.next_power_of_2(hc_mult)) - grid = (T_val,) + grid = (T,) _hc_head_kernel[grid]( x, hc_fn, diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py index 
562347035a0e..4461a0a279a9 100644 --- a/python/sglang/srt/models/deepseek_v4.py +++ b/python/sglang/srt/models/deepseek_v4.py @@ -653,7 +653,7 @@ def hc_pre( hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, - norm=None, + norm: Optional[nn.Module] = None, ): """If *norm* is given and the TileLang path is active, the returned hidden_states are already post-norm (the norm is fused into the kernel).""" From 4a90172e9e20ac12c0ca1c31929071e9625d1a17 Mon Sep 17 00:00:00 2001 From: yhyang201 Date: Sat, 9 May 2026 21:30:15 +0200 Subject: [PATCH 6/7] Add DSV4 stages to /rerun-stage whitelist --- scripts/ci/utils/slash_command_handler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/ci/utils/slash_command_handler.py b/scripts/ci/utils/slash_command_handler.py index 0f863597ec71..61da70174636 100644 --- a/scripts/ci/utils/slash_command_handler.py +++ b/scripts/ci/utils/slash_command_handler.py @@ -424,6 +424,8 @@ def handle_rerun_stage( "stage-c-test-8-gpu-h20", "stage-c-test-4-gpu-b200", "stage-c-test-4-gpu-gb200", + "stage-c-test-dsv4-4-gpu-b200", + "stage-c-test-dsv4-8-gpu-h200", "stage-c-test-deepep-4-gpu-h100", "stage-c-test-deepep-8-gpu-h200", "multimodal-gen-test-1-gpu", From 75716163c3589823d768229f5f27d067c3e0e928 Mon Sep 17 00:00:00 2001 From: yhyang201 Date: Sun, 10 May 2026 05:51:58 +0800 Subject: [PATCH 7/7] Fix uninitialized sumsq accumulator in fused norm kernel T.alloc_fragment does not guarantee zero initialization. The sumsq_per_pos accumulator must be explicitly cleared before the pipelined loop to avoid garbage values corrupting the RMSNorm computation, which caused all-zero model output. Co-authored-by: Cheng Wan --- python/sglang/srt/layers/mhc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/layers/mhc.py b/python/sglang/srt/layers/mhc.py index e268d553bf08..0be750f1a038 100644 --- a/python/sglang/srt/layers/mhc.py +++ b/python/sglang/srt/layers/mhc.py @@ -577,6 +577,7 @@ def mhc_pre_big_fuse_with_norm_tilelang( # (matches the rounding the reference path does when RMSNorm reads bf16). output_shared = T.alloc_shared(hidden_size, T.bfloat16) sumsq_per_pos = T.alloc_fragment(hidden_block, T.float32) + T.clear(sumsq_per_pos) for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=3): xs = T.alloc_shared((hc_mult, hidden_block), T.bfloat16)
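A minimal local check of fused_hc_head (PATCH 4) against the reference torch
math quoted in mhc_head.py's module docstring. This is a sketch only: it
assumes a CUDA device with Triton installed, and the shapes, epsilons, and
tolerances are illustrative rather than values taken from any model config.

    import torch
    import torch.nn.functional as F

    from sglang.srt.layers.mhc_head import fused_hc_head

    num_tokens, hc_mult, hidden_size = 8, 4, 7168
    norm_eps, hc_eps = 1e-6, 1e-3  # assumed epsilons, not model constants

    x = torch.randn(
        num_tokens, hc_mult, hidden_size, dtype=torch.bfloat16, device="cuda"
    )
    hc_fn = torch.randn(
        hc_mult, hc_mult * hidden_size, dtype=torch.float32, device="cuda"
    )
    hc_scale = torch.randn(1, dtype=torch.float32, device="cuda")
    hc_base = torch.randn(hc_mult, dtype=torch.float32, device="cuda")

    # Reference path, verbatim math from the docstring.
    x_flat = x.flatten(1).float()
    rsqrt = torch.rsqrt(x_flat.square().mean(-1, keepdim=True) + norm_eps)
    mixes = F.linear(x_flat, hc_fn) * rsqrt
    pre = torch.sigmoid(mixes * hc_scale + hc_base) + hc_eps
    y_ref = torch.sum(pre.unsqueeze(-1) * x_flat.view(x.shape), dim=1).to(x.dtype)

    # Fused Triton path; loose tolerances account for bf16 rounding.
    y = fused_hc_head(x, hc_fn, hc_scale, hc_base, norm_eps, hc_eps)
    torch.testing.assert_close(y.float(), y_ref.float(), rtol=2e-2, atol=2e-2)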