Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,10 @@ struct CustomAllReducePull : public CustomAllReduceBase {
RuntimeCheck(shot == 1 || shot == 2, "Invalid shot count: ", shot);
RuntimeCheck(device.device_type == kDLCUDA, "Only CUDA device is supported");
RuntimeCheck(is_type<DType>(input.dtype()), "Input dtype mismatch");
RuntimeCheck(std::bit_cast<intptr_t>(input_ptr) % 16 == 0, "Input pointer is not properly aligned");
// ``reinterpret_cast`` rather than ``std::bit_cast`` so the JIT
// builds on libstdc++ < 11 (gcc 10 ships in Debian 11). The cast
// is value-equivalent for pointer-to-integer.
RuntimeCheck(reinterpret_cast<intptr_t>(input_ptr) % 16 == 0, "Input pointer is not properly aligned");
RuntimeCheck(m_pull_ctrl.has_value(), "Controller is not initialized");
RuntimeCheck(static_cast<int64_t>(num_items) == num_items_int64, "Number of items exceeds 4G limit");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,10 @@ struct CustomAllReducePush : public CustomAllReduceBase {
RuntimeCheck(m_num_gpu == kNumGPU, "Number of GPUs mismatch");
RuntimeCheck(device.device_type == kDLCUDA, "Only CUDA device is supported");
RuntimeCheck(is_type<DType>(input.dtype()), "Input dtype mismatch");
RuntimeCheck(std::bit_cast<intptr_t>(input_ptr) % 16 == 0, "Input pointer is not properly aligned");
// ``reinterpret_cast`` rather than ``std::bit_cast`` so the JIT
// builds on libstdc++ < 11 (gcc 10 ships in Debian 11). The cast
// is value-equivalent for pointer-to-integer.
RuntimeCheck(reinterpret_cast<intptr_t>(input_ptr) % 16 == 0, "Input pointer is not properly aligned");
RuntimeCheck(m_push_ctrl.has_value(), "Controller is not initialized");
RuntimeCheck(shot == 1, "Push all-reduce only supports 1-shot, got: ", shot);
RuntimeCheck(static_cast<int64_t>(num_items) == num_items_int64, "Number of items exceeds 4G limit");
Expand Down
11 changes: 7 additions & 4 deletions python/sglang/jit_kernel/csrc/distributed/tp_qknorm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,13 @@ struct FusedParallelQKNormAcrossHead : public CustomAllReduceBase {
const auto needed_buffer_bytes = static_cast<int64_t>(num_tokens) * 2 * sizeof(float);
RuntimeCheck(m_num_gpu == kNumGPU, "Number of GPUs mismatch");
RuntimeCheck(m_push_ctrl.has_value(), "Controller is not initialized");
RuntimeCheck(std::bit_cast<intptr_t>(params.q_ptr) % 16 == 0, "q pointer is not properly aligned");
RuntimeCheck(std::bit_cast<intptr_t>(params.k_ptr) % 16 == 0, "k pointer is not properly aligned");
RuntimeCheck(std::bit_cast<intptr_t>(params.q_weight) % 16 == 0, "q_weight pointer is not properly aligned");
RuntimeCheck(std::bit_cast<intptr_t>(params.k_weight) % 16 == 0, "k_weight pointer is not properly aligned");
// ``reinterpret_cast`` rather than ``std::bit_cast`` so the JIT
// builds on libstdc++ < 11 (gcc 10 ships in Debian 11). The cast
// is value-equivalent for pointer-to-integer.
RuntimeCheck(reinterpret_cast<intptr_t>(params.q_ptr) % 16 == 0, "q pointer is not properly aligned");
RuntimeCheck(reinterpret_cast<intptr_t>(params.k_ptr) % 16 == 0, "k pointer is not properly aligned");
RuntimeCheck(reinterpret_cast<intptr_t>(params.q_weight) % 16 == 0, "q_weight pointer is not properly aligned");
RuntimeCheck(reinterpret_cast<intptr_t>(params.k_weight) % 16 == 0, "k_weight pointer is not properly aligned");
RuntimeCheck(needed_buffer_bytes <= m_push_buffer_bytes, "Push buffer is too small");

LaunchKernel(num_blocks, num_threads, device) //
Expand Down
98 changes: 82 additions & 16 deletions python/sglang/srt/layers/attention/triton_ops/extend_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,34 @@
_is_hip = is_hip()


def _get_block_sizes_for_extend_attention(Lq: int, Lv: int):
def _get_block_sizes_for_extend_attention(
Lq: int,
Lv: int,
*,
batch_size: int = 0,
max_len_extend: int = 0,
):
"""
Get block sizes and configuration for extend attention kernels.

Args:
Lq: Query head dimension
Lv: Value head dimension
batch_size: Number of sequences in the batch (kw-only). Used by the
H100 (sm_90, 128 < Lq <= 256) heuristic to pick a smaller tile for
high-bs spec-decode verify shapes where the legacy (128, 64, w8)
tile wastes work per program. ``0`` (default) is treated as "unknown"
and falls back to the long-extend tile (32, 64, w4, s2).
max_len_extend: Maximum extend length per sequence in the batch
(kw-only). Used together with batch_size to distinguish
high-bs *verify* shapes (small max_len_extend, e.g. 4 for
num_draft_tokens=4) from high-bs *chunked prefill* shapes
(larger max_len_extend). ``0`` (default) is treated as
"unknown" and falls back to the long-extend tile.

Returns:
tuple: (BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps)
tuple: (BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps,
num_stages)
"""
# Determine BLOCK_DMODEL and BLOCK_DPE based on head dimension
if Lq == 576:
Expand All @@ -59,6 +77,8 @@ def _get_block_sizes_for_extend_attention(Lq: int, Lv: int):

BLOCK_DV = triton.next_power_of_2(Lv)

num_stages = 1

# Determine BLOCK_M, BLOCK_N, and num_warps based on hardware
if _is_hip:
BLOCK_M, BLOCK_N = (64, 64)
Expand All @@ -82,8 +102,48 @@ def _get_block_sizes_for_extend_attention(Lq: int, Lv: int):
BLOCK_M, BLOCK_N = (16, 64)
elif _is_cuda and CUDA_CAPABILITY[0] >= 9:
# Hopper architecture (H100, etc.)
if Lq <= 256:
if Lq <= 128:
BLOCK_M, BLOCK_N = (128, 64)
elif Lq <= 256:
# H100 / sm_90, head_dim == 256 (e.g. Gemma-4-26B-A4B-IT,
# which uses head_dim=256). The legacy (128, 64, w8, s1)
# tile is severely oversized for both the long-extend
# initial-prefill shape (bs=1, ext=8k) and the high-bs
# MTP verify shape (bs=32, ext=4, prefix>=1k) — see
# the microbench in the H100 SOTA run artifact dir
# ``patches/bench_extend_attn_gemma4_26b.py`` (and the
# ``patches/extend_attn_microbench_*.log`` artifacts).
# Microbench winners on bf16, num_q_heads=8, num_kv_heads=4:
# prefill long ext=8192 bs=1 2657us -> 1908us -28% (32,64,w4,s2)
# prefill chat ext=1000 bs=1 128us -> 56us -56% (32,64,w4,s2)
# verify chat ext=4 pf=1000 bs=32 616us -> 144us -77% (16,64,w4,s2)
# verify summ ext=4 pf=8000 bs=32 1076us-> 191us -82% (16,64,w4,s2)
# verify burst ext=4 pf=64 bs=32 94us -> 22us -77% (32,32,w4,s2)
# chunked-prefill ext=512 bs=8 136us -> 92us -32% (32,64,w4,s2)
# chunked-prefill ext=1024 bs=16 752us -> 559us -26% (32,64,w4,s2)
# The (16, 64, w4, s2) tile that dominates the high-bs
# *verify* path (max_len_extend = num_draft_tokens, very
# small) regresses the high-bs *chunked-prefill* path
# (max_len_extend = chunked_prefill_size_per_seq, larger)
# by ~30 %. Gate on BOTH batch_size and max_len_extend
# so chunked prefill keeps (32, 64, w4, s2).
if batch_size >= 8 and 0 < max_len_extend <= 16:
BLOCK_M, BLOCK_N = (16, 64)
num_warps = 4
num_stages = 2
else:
BLOCK_M, BLOCK_N = (32, 64)
num_warps = 4
num_stages = 2
return (
BLOCK_DMODEL,
BLOCK_DPE,
BLOCK_DV,
BLOCK_M,
BLOCK_N,
num_warps,
num_stages,
)
else:
BLOCK_M, BLOCK_N = (32, 64)
elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
Expand All @@ -109,7 +169,7 @@ def _get_block_sizes_for_extend_attention(Lq: int, Lv: int):

num_warps = 4 if Lq <= 64 else 8

return BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps
return BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps, num_stages


@triton.jit
Expand Down Expand Up @@ -591,23 +651,26 @@ def extend_attention_fwd(
v_extend.shape[-1],
)

# Get block sizes and configuration
BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = (
_get_block_sizes_for_extend_attention(Lq, Lv)
)

sm_scale = sm_scale or 1.0 / (Lq**0.5)
batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
kv_group_num = q_extend.shape[1] // k_extend.shape[1]

# Get block sizes and configuration. Pass batch_size + max_len_extend so
# the H100 Lq<=256 heuristic can pick the spec-decode-verify tile
# (only when extend is tiny) vs the chunked-prefill / long-extend tile.
BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps, num_stages = (
_get_block_sizes_for_extend_attention(
Lq, Lv, batch_size=batch_size, max_len_extend=max_len_extend
)
)

USE_CUSTOM_MASK = custom_mask is not None
# Skip custom mask for prefix part
SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask

HAS_SINK = sinks is not None

grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
num_stages = 1

extra_kargs = {}
if _is_hip:
Expand Down Expand Up @@ -1001,15 +1064,19 @@ def extend_attention_fwd_unified(
"""
Lq, Lv = q.shape[-1], v_buffer.shape[-1]

# Get block sizes and configuration
BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps = (
_get_block_sizes_for_extend_attention(Lq, Lv)
)

sm_scale = sm_scale or 1.0 / (Lq**0.5)
batch_size, head_num = qo_indptr.shape[0] - 1, q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[1]

# Get block sizes and configuration. Pass batch_size + max_len_extend so
# the H100 Lq<=256 heuristic can pick the spec-decode-verify tile
# (only when extend is tiny) vs the chunked-prefill / long-extend tile.
BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV, BLOCK_M, BLOCK_N, num_warps, num_stages = (
_get_block_sizes_for_extend_attention(
Lq, Lv, batch_size=batch_size, max_len_extend=max_len_extend
)
)

USE_CUSTOM_MASK = custom_mask is not None
HAS_SINK = sinks is not None

Expand All @@ -1020,7 +1087,6 @@ def extend_attention_fwd_unified(
window_start_pos = torch.zeros(batch_size, dtype=torch.int32, device=q.device)

grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
num_stages = 1

extra_kargs = {}
if _is_hip:
Expand Down
Loading