diff --git a/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_pull.cuh b/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_pull.cuh index e8837af4cd34..a20f48c87c0a 100644 --- a/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_pull.cuh +++ b/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_pull.cuh @@ -161,7 +161,10 @@ struct CustomAllReducePull : public CustomAllReduceBase { RuntimeCheck(shot == 1 || shot == 2, "Invalid shot count: ", shot); RuntimeCheck(device.device_type == kDLCUDA, "Only CUDA device is supported"); RuntimeCheck(is_type(input.dtype()), "Input dtype mismatch"); - RuntimeCheck(std::bit_cast(input_ptr) % 16 == 0, "Input pointer is not properly aligned"); + // ``reinterpret_cast`` rather than ``std::bit_cast`` so the JIT + // builds on libstdc++ < 11 (gcc 10 ships in Debian 11). The cast + // is value-equivalent for pointer-to-integer. + RuntimeCheck(reinterpret_cast(input_ptr) % 16 == 0, "Input pointer is not properly aligned"); RuntimeCheck(m_pull_ctrl.has_value(), "Controller is not initialized"); RuntimeCheck(static_cast(num_items) == num_items_int64, "Number of items exceeds 4G limit"); diff --git a/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_push.cuh b/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_push.cuh index c4523c27eec3..8ca4f9927f3c 100644 --- a/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_push.cuh +++ b/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_push.cuh @@ -229,7 +229,10 @@ struct CustomAllReducePush : public CustomAllReduceBase { RuntimeCheck(m_num_gpu == kNumGPU, "Number of GPUs mismatch"); RuntimeCheck(device.device_type == kDLCUDA, "Only CUDA device is supported"); RuntimeCheck(is_type(input.dtype()), "Input dtype mismatch"); - RuntimeCheck(std::bit_cast(input_ptr) % 16 == 0, "Input pointer is not properly aligned"); + // ``reinterpret_cast`` rather than ``std::bit_cast`` so the JIT + // builds on libstdc++ < 11 (gcc 10 ships in Debian 11). The cast + // is value-equivalent for pointer-to-integer. + RuntimeCheck(reinterpret_cast(input_ptr) % 16 == 0, "Input pointer is not properly aligned"); RuntimeCheck(m_push_ctrl.has_value(), "Controller is not initialized"); RuntimeCheck(shot == 1, "Push all-reduce only supports 1-shot, got: ", shot); RuntimeCheck(static_cast(num_items) == num_items_int64, "Number of items exceeds 4G limit"); diff --git a/python/sglang/jit_kernel/csrc/distributed/tp_qknorm.cuh b/python/sglang/jit_kernel/csrc/distributed/tp_qknorm.cuh index ca80e1efcdf1..be59e2c738f4 100644 --- a/python/sglang/jit_kernel/csrc/distributed/tp_qknorm.cuh +++ b/python/sglang/jit_kernel/csrc/distributed/tp_qknorm.cuh @@ -296,10 +296,13 @@ struct FusedParallelQKNormAcrossHead : public CustomAllReduceBase { const auto needed_buffer_bytes = static_cast(num_tokens) * 2 * sizeof(float); RuntimeCheck(m_num_gpu == kNumGPU, "Number of GPUs mismatch"); RuntimeCheck(m_push_ctrl.has_value(), "Controller is not initialized"); - RuntimeCheck(std::bit_cast(params.q_ptr) % 16 == 0, "q pointer is not properly aligned"); - RuntimeCheck(std::bit_cast(params.k_ptr) % 16 == 0, "k pointer is not properly aligned"); - RuntimeCheck(std::bit_cast(params.q_weight) % 16 == 0, "q_weight pointer is not properly aligned"); - RuntimeCheck(std::bit_cast(params.k_weight) % 16 == 0, "k_weight pointer is not properly aligned"); + // ``reinterpret_cast`` rather than ``std::bit_cast`` so the JIT + // builds on libstdc++ < 11 (gcc 10 ships in Debian 11). The cast + // is value-equivalent for pointer-to-integer. + RuntimeCheck(reinterpret_cast(params.q_ptr) % 16 == 0, "q pointer is not properly aligned"); + RuntimeCheck(reinterpret_cast(params.k_ptr) % 16 == 0, "k pointer is not properly aligned"); + RuntimeCheck(reinterpret_cast(params.q_weight) % 16 == 0, "q_weight pointer is not properly aligned"); + RuntimeCheck(reinterpret_cast(params.k_weight) % 16 == 0, "k_weight pointer is not properly aligned"); RuntimeCheck(needed_buffer_bytes <= m_push_buffer_bytes, "Push buffer is too small"); LaunchKernel(num_blocks, num_threads, device) // diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py index e6a353e9bfd9..9d29487e6220 100644 --- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -32,16 +32,34 @@ _is_hip = is_hip() -def _get_block_sizes_for_extend_attention(Lq: int, Lv: int): +def _get_block_sizes_for_extend_attention( + Lq: int, + Lv: int, + *, + batch_size: int = 0, + max_len_extend: int = 0, +): """ Get block sizes and configuration for extend attention kernels. Args: Lq: Query head dimension Lv: Value head dimension + batch_size: Number of sequences in the batch (kw-only). Used by the + H100 (sm_90, Lq<=256) heuristic to pick a smaller tile for + high-bs spec-decode verify shapes where the default (128, 64, w8) + wastes work per program. ``0`` (default) is treated as "unknown" + and preserves the legacy tile. + max_len_extend: Maximum extend length per sequence in the batch + (kw-only). Used together with batch_size to distinguish + high-bs *verify* shapes (small max_len_extend, e.g. 4 for + num_draft_tokens=4) from high-bs *chunked prefill* shapes + (larger max_len_extend). ``0`` (default) is treated as + "unknown" and falls back to the long-extend tile. Returns: - tuple: (BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps) + tuple: (BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps, + num_stages) """ # Determine BLOCK_DMODEL and BLOCK_DPE based on head dimension if Lq == 576: @@ -59,6 +77,8 @@ def _get_block_sizes_for_extend_attention(Lq: int, Lv: int): BLOCK_DV = triton.next_power_of_2(Lv) + num_stages = 1 + # Determine BLOCK_M, BLOCK_N, and num_warps based on hardware if _is_hip: BLOCK_M, BLOCK_N = (64, 64) @@ -82,8 +102,48 @@ def _get_block_sizes_for_extend_attention(Lq: int, Lv: int): BLOCK_M, BLOCK_N = (16, 64) elif _is_cuda and CUDA_CAPABILITY[0] >= 9: # Hopper architecture (H100, etc.) - if Lq <= 256: + if Lq <= 128: BLOCK_M, BLOCK_N = (128, 64) + elif Lq <= 256: + # H100 / sm_90, head_dim == 256 (e.g. Gemma-4-26B-A4B-IT, + # which uses head_dim=256). The legacy (128, 64, w8, s1) + # tile is severely oversized for both the long-extend + # initial-prefill shape (bs=1, ext=8k) and the high-bs + # MTP verify shape (bs=32, ext=4, prefix>=1k) — see + # the microbench in the H100 SOTA run artifact dir + # ``patches/bench_extend_attn_gemma4_26b.py`` (and the + # ``patches/extend_attn_microbench_*.log`` artifacts). + # Microbench winners on bf16, num_q_heads=8, num_kv_heads=4: + # prefill long ext=8192 bs=1 2657us -> 1908us -28% (32,64,w4,s2) + # prefill chat ext=1000 bs=1 128us -> 56us -56% (32,64,w4,s2) + # verify chat ext=4 pf=1000 bs=32 616us -> 144us -77% (16,64,w4,s2) + # verify summ ext=4 pf=8000 bs=32 1076us-> 191us -82% (16,64,w4,s2) + # verify burst ext=4 pf=64 bs=32 94us -> 22us -77% (32,32,w4,s2) + # chunked-prefill ext=512 bs=8 136us -> 92us -32% (32,64,w4,s2) + # chunked-prefill ext=1024 bs=16 752us -> 559us -26% (32,64,w4,s2) + # The (16, 64, w4, s2) tile that dominates the high-bs + # *verify* path (max_len_extend = num_draft_tokens, very + # small) regresses the high-bs *chunked-prefill* path + # (max_len_extend = chunked_prefill_size_per_seq, larger) + # by ~30 %. Gate on BOTH batch_size and max_len_extend + # so chunked prefill keeps (32, 64, w4, s2). + if batch_size >= 8 and 0 < max_len_extend <= 16: + BLOCK_M, BLOCK_N = (16, 64) + num_warps = 4 + num_stages = 2 + else: + BLOCK_M, BLOCK_N = (32, 64) + num_warps = 4 + num_stages = 2 + return ( + BLOCK_DMODEL, + BLOCK_DPE, + BLOCK_DV, + BLOCK_M, + BLOCK_N, + num_warps, + num_stages, + ) else: BLOCK_M, BLOCK_N = (32, 64) elif _is_cuda and CUDA_CAPABILITY[0] >= 8: @@ -109,7 +169,7 @@ def _get_block_sizes_for_extend_attention(Lq: int, Lv: int): num_warps = 4 if Lq <= 64 else 8 - return BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps + return BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps, num_stages @triton.jit @@ -591,15 +651,19 @@ def extend_attention_fwd( v_extend.shape[-1], ) - # Get block sizes and configuration - BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = ( - _get_block_sizes_for_extend_attention(Lq, Lv) - ) - sm_scale = sm_scale or 1.0 / (Lq**0.5) batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1] kv_group_num = q_extend.shape[1] // k_extend.shape[1] + # Get block sizes and configuration. Pass batch_size + max_len_extend so + # the H100 Lq<=256 heuristic can pick the spec-decode-verify tile + # (only when extend is tiny) vs the chunked-prefill / long-extend tile. + BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps, num_stages = ( + _get_block_sizes_for_extend_attention( + Lq, Lv, batch_size=batch_size, max_len_extend=max_len_extend + ) + ) + USE_CUSTOM_MASK = custom_mask is not None # Skip custom mask for prefix part SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask @@ -607,7 +671,6 @@ def extend_attention_fwd( HAS_SINK = sinks is not None grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M)) - num_stages = 1 extra_kargs = {} if _is_hip: @@ -1001,15 +1064,19 @@ def extend_attention_fwd_unified( """ Lq, Lv = q.shape[-1], v_buffer.shape[-1] - # Get block sizes and configuration - BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = ( - _get_block_sizes_for_extend_attention(Lq, Lv) - ) - sm_scale = sm_scale or 1.0 / (Lq**0.5) batch_size, head_num = qo_indptr.shape[0] - 1, q.shape[1] kv_group_num = q.shape[1] // k_buffer.shape[1] + # Get block sizes and configuration. Pass batch_size + max_len_extend so + # the H100 Lq<=256 heuristic can pick the spec-decode-verify tile + # (only when extend is tiny) vs the chunked-prefill / long-extend tile. + BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps, num_stages = ( + _get_block_sizes_for_extend_attention( + Lq, Lv, batch_size=batch_size, max_len_extend=max_len_extend + ) + ) + USE_CUSTOM_MASK = custom_mask is not None HAS_SINK = sinks is not None @@ -1020,7 +1087,6 @@ def extend_attention_fwd_unified( window_start_pos = torch.zeros(batch_size, dtype=torch.int32, device=q.device) grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M)) - num_stages = 1 extra_kargs = {} if _is_hip: