diff --git a/aiter/ops/triton/_triton_kernels/gather_kv_b_proj.py b/aiter/ops/triton/_triton_kernels/gather_kv_b_proj.py index dfdcbd5862..cecfa6c8c7 100644 --- a/aiter/ops/triton/_triton_kernels/gather_kv_b_proj.py +++ b/aiter/ops/triton/_triton_kernels/gather_kv_b_proj.py @@ -5,45 +5,55 @@ import triton.language as tl +def _next_pow2(n): + """Return the smallest power of 2 >= n (Python-side helper, not a JIT function).""" + return 1 << (n - 1).bit_length() + + @triton.jit def _load_unshuffle_segment( base_ptr, seg_idx, - QkNopeHeadDim: tl.constexpr, + HeadDim: tl.constexpr, + PaddedHeadDim: tl.constexpr, KV_CDim: tl.constexpr, ScaleKGranularity: tl.constexpr, ): - """Load one [QkNopeHeadDim, ScaleKGranularity] weight segment from a + """Load one [PaddedHeadDim, ScaleKGranularity] weight segment from a preshuffled weight matrix via coalesced row-major loads, then unshuffle - in registers. - - Each n_blk (16 original N rows) occupies KV_CDim//32 shuffled rows. - A ScaleK segment of 128 K values spans SegKBlocks=4 consecutive rows - within each n_blk. We gather these rows across all n_blks (with row - stride KV_CDim), producing a [NumNBlk*SegKBlocks, KV_CDim] tensor, - then reshape + permute to recover [QkNopeHeadDim, ScaleKGranularity]. + in registers. PaddedHeadDim is HeadDim rounded up to the next power of 2. + Out-of-range rows are zero-filled so dot-products stay correct. """ - NumNBlk: tl.constexpr = QkNopeHeadDim // 16 + NumNBlk: tl.constexpr = HeadDim // 16 + PaddedNumNBlk: tl.constexpr = PaddedHeadDim // 16 SegKBlocks: tl.constexpr = ScaleKGranularity // 32 NumKBlkTotal: tl.constexpr = KV_CDim // 32 - TotalRows: tl.constexpr = NumNBlk * SegKBlocks + PaddedTotalRows: tl.constexpr = PaddedNumNBlk * SegKBlocks - offs_nb = tl.arange(0, NumNBlk) + offs_nb = tl.arange(0, PaddedNumNBlk) offs_kb = tl.arange(0, SegKBlocks) row_indices = ( offs_nb[:, None] * NumKBlkTotal + seg_idx * SegKBlocks + offs_kb[None, :] ) - row_indices_flat = tl.reshape(row_indices, (TotalRows,)) + row_indices_flat = tl.reshape(row_indices, (PaddedTotalRows,)) + mask_flat = tl.reshape( + (offs_nb[:, None] < NumNBlk).broadcast_to(PaddedNumNBlk, SegKBlocks), + (PaddedTotalRows,), + ) offs_col = tl.arange(0, KV_CDim) - raw = tl.load(base_ptr + row_indices_flat[:, None] * KV_CDim + offs_col[None, :]) + raw = tl.load( + base_ptr + row_indices_flat[:, None] * KV_CDim + offs_col[None, :], + mask=mask_flat[:, None], + other=0.0, + ) w = tl.reshape( tl.permute( - tl.reshape(raw, (NumNBlk, SegKBlocks, 2, 16, 16)), + tl.reshape(raw, (PaddedNumNBlk, SegKBlocks, 2, 16, 16)), (0, 3, 1, 2, 4), ), - (QkNopeHeadDim, ScaleKGranularity), + (PaddedHeadDim, ScaleKGranularity), ) return w @@ -56,22 +66,25 @@ def _triton_gather_kv_b_proj( kv_indptr, # [batch_size + 1] kv_indices, # [total_kv] kv_prefix_sum_context_lens, # [batch_size + 1] - kv_proj_weight, # [tp_k_head_num * 2 * qk_nope_head_dim, kv_c_dim] - kv_proj_scale, # block: [n//128, k//128]; per-row: [tp_heads * 2 * qk_nope_head_dim] - k_prefix, # [total_kv, tp_k_head_num * qk_nope_head_dim + kv_pe_dim] - v_prefix, # [total_kv, tp_k_head_num * qk_nope_head_dim] + kv_proj_weight, # [tp_k_head_num * (qk_nope_head_dim + v_head_dim), kv_c_dim] + kv_proj_scale, # block: [n//128, k//128]; per-row: [weight_n] or [weight_n, 1] + k_prefix, # [total_kv, tp_k_head_num, qk_nope_head_dim + kv_pe_dim] + v_prefix, # [total_kv, tp_k_head_num, v_head_dim] KBlockSize: tl.constexpr, TpNumHeads: tl.constexpr, QkNopeHeadDim: tl.constexpr, + VHeadDim: tl.constexpr, KV_CDim: tl.constexpr, KV_PeDim: tl.constexpr, ChunkK: tl.constexpr, + PaddedK: tl.constexpr, + PaddedV: tl.constexpr, WEIGHT_PRESHUFFLE: tl.constexpr = False, PER_ROW_SCALE: tl.constexpr = False, ): stride_k_buffer: tl.constexpr = KBlockSize * (KV_CDim + KV_PeDim) stride_k_prefix: tl.constexpr = TpNumHeads * (QkNopeHeadDim + KV_PeDim) - stride_v_prefix: tl.constexpr = TpNumHeads * QkNopeHeadDim + stride_v_prefix: tl.constexpr = TpNumHeads * VHeadDim ScaleKGranularity: tl.constexpr = 128 ScaleNGranularity: tl.constexpr = 128 @@ -103,112 +116,150 @@ def _triton_gather_kv_b_proj( else: k_scalar_scale = tl.load(k_scale) - offs_n = tl.arange(0, QkNopeHeadDim) + offs_n_k = tl.arange(0, PaddedK) + offs_n_v = tl.arange(0, PaddedV) + mask_k = offs_n_k < QkNopeHeadDim + mask_v = offs_n_v < VHeadDim offs_k = tl.arange(0, ScaleKGranularity) - k_head_base = kv_proj_weight + pid_head * 2 * QkNopeHeadDim * KV_CDim + k_head_base = kv_proj_weight + pid_head * (QkNopeHeadDim + VHeadDim) * KV_CDim v_head_base = k_head_base + QkNopeHeadDim * KV_CDim if PER_ROW_SCALE: - k_row0 = pid_head * 2 * QkNopeHeadDim - k_nope_scale_vec = tl.load(kv_proj_scale + k_row0 + offs_n).to(tl.float32) - v_nope_scale_vec = tl.load(kv_proj_scale + k_row0 + QkNopeHeadDim + offs_n).to( - tl.float32 - ) + k_row0 = pid_head * (QkNopeHeadDim + VHeadDim) + k_nope_scale_vec = tl.load( + kv_proj_scale + k_row0 + offs_n_k, mask=mask_k, other=1.0 + ).to(tl.float32) + v_nope_scale_vec = tl.load( + kv_proj_scale + k_row0 + QkNopeHeadDim + offs_n_v, mask=mask_v, other=1.0 + ).to(tl.float32) else: - k_nope_scale_base_offset = ( - kv_proj_scale - + pid_head - * 2 - * QkNopeHeadDim - * KV_CDim - // ScaleKGranularity - // ScaleNGranularity - + tl.arange(0, QkNopeHeadDim // ScaleNGranularity) - * (KV_CDim // ScaleKGranularity) - ) + num_scale_cols: tl.constexpr = KV_CDim // ScaleKGranularity + k_abs_rows = pid_head * (QkNopeHeadDim + VHeadDim) + offs_n_k + k_scale_n_idx = k_abs_rows // ScaleNGranularity + v_abs_rows = pid_head * (QkNopeHeadDim + VHeadDim) + QkNopeHeadDim + offs_n_v + v_scale_n_idx = v_abs_rows // ScaleNGranularity if WEIGHT_PRESHUFFLE: + # _load_unshuffle_segment returns [PaddedHeadDim, ScaleKGranularity] + # with zero-filled rows beyond HeadDim k_nope_weight_0 = _load_unshuffle_segment( - k_head_base, 0, QkNopeHeadDim, KV_CDim, ScaleKGranularity + k_head_base, 0, QkNopeHeadDim, PaddedK, KV_CDim, ScaleKGranularity ).to(k_type) k_nope_weight_1 = _load_unshuffle_segment( - k_head_base, 1, QkNopeHeadDim, KV_CDim, ScaleKGranularity + k_head_base, 1, QkNopeHeadDim, PaddedK, KV_CDim, ScaleKGranularity ).to(k_type) k_nope_weight_2 = _load_unshuffle_segment( - k_head_base, 2, QkNopeHeadDim, KV_CDim, ScaleKGranularity + k_head_base, 2, QkNopeHeadDim, PaddedK, KV_CDim, ScaleKGranularity ).to(k_type) k_nope_weight_3 = _load_unshuffle_segment( - k_head_base, 3, QkNopeHeadDim, KV_CDim, ScaleKGranularity + k_head_base, 3, QkNopeHeadDim, PaddedK, KV_CDim, ScaleKGranularity ).to(k_type) v_nope_weight_0 = _load_unshuffle_segment( - v_head_base, 0, QkNopeHeadDim, KV_CDim, ScaleKGranularity + v_head_base, 0, VHeadDim, PaddedV, KV_CDim, ScaleKGranularity ).to(k_type) v_nope_weight_1 = _load_unshuffle_segment( - v_head_base, 1, QkNopeHeadDim, KV_CDim, ScaleKGranularity + v_head_base, 1, VHeadDim, PaddedV, KV_CDim, ScaleKGranularity ).to(k_type) v_nope_weight_2 = _load_unshuffle_segment( - v_head_base, 2, QkNopeHeadDim, KV_CDim, ScaleKGranularity + v_head_base, 2, VHeadDim, PaddedV, KV_CDim, ScaleKGranularity ).to(k_type) v_nope_weight_3 = _load_unshuffle_segment( - v_head_base, 3, QkNopeHeadDim, KV_CDim, ScaleKGranularity + v_head_base, 3, VHeadDim, PaddedV, KV_CDim, ScaleKGranularity ).to(k_type) else: k_nope_weight_base_offset = ( - k_head_base + offs_n[:, None] * KV_CDim + offs_k[None, :] - ) - k_nope_weight_0 = tl.load(k_nope_weight_base_offset + 0 * ScaleKGranularity).to( - k_type - ) - k_nope_weight_1 = tl.load(k_nope_weight_base_offset + 1 * ScaleKGranularity).to( - k_type - ) - k_nope_weight_2 = tl.load(k_nope_weight_base_offset + 2 * ScaleKGranularity).to( - k_type - ) - k_nope_weight_3 = tl.load(k_nope_weight_base_offset + 3 * ScaleKGranularity).to( - k_type + k_head_base + offs_n_k[:, None] * KV_CDim + offs_k[None, :] ) + k_mask_2d = mask_k[:, None] + k_nope_weight_0 = tl.load( + k_nope_weight_base_offset + 0 * ScaleKGranularity, + mask=k_mask_2d, + other=0.0, + ).to(k_type) + k_nope_weight_1 = tl.load( + k_nope_weight_base_offset + 1 * ScaleKGranularity, + mask=k_mask_2d, + other=0.0, + ).to(k_type) + k_nope_weight_2 = tl.load( + k_nope_weight_base_offset + 2 * ScaleKGranularity, + mask=k_mask_2d, + other=0.0, + ).to(k_type) + k_nope_weight_3 = tl.load( + k_nope_weight_base_offset + 3 * ScaleKGranularity, + mask=k_mask_2d, + other=0.0, + ).to(k_type) + v_nope_weight_base_offset = ( + v_head_base + offs_n_v[:, None] * KV_CDim + offs_k[None, :] + ) + v_mask_2d = mask_v[:, None] v_nope_weight_0 = tl.load( - k_nope_weight_base_offset + QkNopeHeadDim * KV_CDim + 0 * ScaleKGranularity + v_nope_weight_base_offset + 0 * ScaleKGranularity, + mask=v_mask_2d, + other=0.0, ).to(k_type) v_nope_weight_1 = tl.load( - k_nope_weight_base_offset + QkNopeHeadDim * KV_CDim + 1 * ScaleKGranularity + v_nope_weight_base_offset + 1 * ScaleKGranularity, + mask=v_mask_2d, + other=0.0, ).to(k_type) v_nope_weight_2 = tl.load( - k_nope_weight_base_offset + QkNopeHeadDim * KV_CDim + 2 * ScaleKGranularity + v_nope_weight_base_offset + 2 * ScaleKGranularity, + mask=v_mask_2d, + other=0.0, ).to(k_type) v_nope_weight_3 = tl.load( - k_nope_weight_base_offset + QkNopeHeadDim * KV_CDim + 3 * ScaleKGranularity + v_nope_weight_base_offset + 3 * ScaleKGranularity, + mask=v_mask_2d, + other=0.0, ).to(k_type) if not PER_ROW_SCALE: - k_nope_scale_0 = tl.load(k_nope_scale_base_offset + 0) - k_nope_scale_1 = tl.load(k_nope_scale_base_offset + 1) - k_nope_scale_2 = tl.load(k_nope_scale_base_offset + 2) - k_nope_scale_3 = tl.load(k_nope_scale_base_offset + 3) + k_nope_scale_0 = tl.load( + kv_proj_scale + k_scale_n_idx * num_scale_cols + 0, + mask=mask_k, + other=0.0, + ).to(tl.float32) + k_nope_scale_1 = tl.load( + kv_proj_scale + k_scale_n_idx * num_scale_cols + 1, + mask=mask_k, + other=0.0, + ).to(tl.float32) + k_nope_scale_2 = tl.load( + kv_proj_scale + k_scale_n_idx * num_scale_cols + 2, + mask=mask_k, + other=0.0, + ).to(tl.float32) + k_nope_scale_3 = tl.load( + kv_proj_scale + k_scale_n_idx * num_scale_cols + 3, + mask=mask_k, + other=0.0, + ).to(tl.float32) v_nope_scale_0 = tl.load( - k_nope_scale_base_offset - + QkNopeHeadDim * KV_CDim // ScaleNGranularity // ScaleKGranularity - + 0 - ) + kv_proj_scale + v_scale_n_idx * num_scale_cols + 0, + mask=mask_v, + other=0.0, + ).to(tl.float32) v_nope_scale_1 = tl.load( - k_nope_scale_base_offset - + QkNopeHeadDim * KV_CDim // ScaleNGranularity // ScaleKGranularity - + 1 - ) + kv_proj_scale + v_scale_n_idx * num_scale_cols + 1, + mask=mask_v, + other=0.0, + ).to(tl.float32) v_nope_scale_2 = tl.load( - k_nope_scale_base_offset - + QkNopeHeadDim * KV_CDim // ScaleNGranularity // ScaleKGranularity - + 2 - ) + kv_proj_scale + v_scale_n_idx * num_scale_cols + 2, + mask=mask_v, + other=0.0, + ).to(tl.float32) v_nope_scale_3 = tl.load( - k_nope_scale_base_offset - + QkNopeHeadDim * KV_CDim // ScaleNGranularity // ScaleKGranularity - + 3 - ) + kv_proj_scale + v_scale_n_idx * num_scale_cols + 3, + mask=mask_v, + other=0.0, + ).to(tl.float32) for chunk_id in range(total_kv_chunk): kv_block_idx = tl.load( @@ -225,8 +276,8 @@ def _triton_gather_kv_b_proj( + tl.arange(0, ScaleKGranularity)[None, :] ) # [ChunkK, kv_c_dim] - accum_k = tl.zeros((ChunkK, QkNopeHeadDim), dtype=tl.float32) - accum_v = tl.zeros((ChunkK, QkNopeHeadDim), dtype=tl.float32) + accum_k = tl.zeros((ChunkK, PaddedK), dtype=tl.float32) + accum_v = tl.zeros((ChunkK, PaddedV), dtype=tl.float32) kv_c_data_0 = tl.load(k_buffer + kv_c_data_base_offset + 0 * ScaleKGranularity) kv_c_data_1 = tl.load(k_buffer + kv_c_data_base_offset + 1 * ScaleKGranularity) @@ -266,14 +317,14 @@ def _triton_gather_kv_b_proj( tl.dot(kv_c_data_3, v_nope_weight_3.T) * v_nope_scale_vec[None, :] ) else: - accum_k += tl.dot(kv_c_data_0, k_nope_weight_0.T) * k_nope_scale_0 - accum_v += tl.dot(kv_c_data_0, v_nope_weight_0.T) * v_nope_scale_0 - accum_k += tl.dot(kv_c_data_1, k_nope_weight_1.T) * k_nope_scale_1 - accum_v += tl.dot(kv_c_data_1, v_nope_weight_1.T) * v_nope_scale_1 - accum_k += tl.dot(kv_c_data_2, k_nope_weight_2.T) * k_nope_scale_2 - accum_v += tl.dot(kv_c_data_2, v_nope_weight_2.T) * v_nope_scale_2 - accum_k += tl.dot(kv_c_data_3, k_nope_weight_3.T) * k_nope_scale_3 - accum_v += tl.dot(kv_c_data_3, v_nope_weight_3.T) * v_nope_scale_3 + accum_k += tl.dot(kv_c_data_0, k_nope_weight_0.T) * k_nope_scale_0[None, :] + accum_v += tl.dot(kv_c_data_0, v_nope_weight_0.T) * v_nope_scale_0[None, :] + accum_k += tl.dot(kv_c_data_1, k_nope_weight_1.T) * k_nope_scale_1[None, :] + accum_v += tl.dot(kv_c_data_1, v_nope_weight_1.T) * v_nope_scale_1[None, :] + accum_k += tl.dot(kv_c_data_2, k_nope_weight_2.T) * k_nope_scale_2[None, :] + accum_v += tl.dot(kv_c_data_2, v_nope_weight_2.T) * v_nope_scale_2[None, :] + accum_k += tl.dot(kv_c_data_3, k_nope_weight_3.T) * k_nope_scale_3[None, :] + accum_v += tl.dot(kv_c_data_3, v_nope_weight_3.T) * v_nope_scale_3[None, :] accum_k *= k_scalar_scale accum_v *= k_scalar_scale @@ -297,16 +348,16 @@ def _triton_gather_kv_b_proj( + (context_start + chunk_id * ChunkK + tl.arange(0, ChunkK))[:, None] * stride_k_prefix + pid_head * (QkNopeHeadDim + KV_PeDim) - + tl.arange(0, QkNopeHeadDim)[None, :], + + offs_n_k[None, :], accum_k, - mask=context_mask[:, None], + mask=context_mask[:, None] & mask_k[None, :], ) tl.store( v_prefix + (context_start + chunk_id * ChunkK + tl.arange(0, ChunkK))[:, None] * stride_v_prefix - + pid_head * QkNopeHeadDim - + tl.arange(0, QkNopeHeadDim)[None, :], + + pid_head * VHeadDim + + offs_n_v[None, :], accum_v, - mask=context_mask[:, None], + mask=context_mask[:, None] & mask_v[None, :], ) diff --git a/aiter/ops/triton/gather_kv_b_proj.py b/aiter/ops/triton/gather_kv_b_proj.py index 81db9d2edb..f9310bd159 100644 --- a/aiter/ops/triton/gather_kv_b_proj.py +++ b/aiter/ops/triton/gather_kv_b_proj.py @@ -4,6 +4,7 @@ import torch from aiter.ops.triton._triton_kernels.gather_kv_b_proj import ( + _next_pow2, _triton_gather_kv_b_proj, ) @@ -14,17 +15,19 @@ def gather_kv_b_proj( kv_indptr: torch.Tensor, # [batch_size + 1] kv_indices: torch.Tensor, # len(kv_indices) = kv_indptr[-1] kv_prefix_sum_context_lens: torch.Tensor, # [batch_size + 1] - kv_proj_weight: torch.Tensor, # [2 * 128 // TP * 128, 512] + kv_proj_weight: torch.Tensor, # [tp_heads * (qk_nope_head_dim + v_head_dim), kv_c_dim] kv_proj_scale: torch.Tensor, # [weight_n] per-output-row, or [N//128, K//128] block k_prefix: torch.Tensor, # [total_kv, tp_k_head_num, qk_nope_head_dim + kv_pe_dim] - v_prefix: torch.Tensor, # [total_kv, tp_k_head_num, qk_nope_head_dim] + v_prefix: torch.Tensor, # [total_kv, tp_k_head_num, v_head_dim] weight_preshuffle: bool = False, ): num_block, block_size, hidden_dim = k_buffer.shape batch_size = kv_indptr.shape[0] - 1 weight_n, weight_k = kv_proj_weight.shape total_kv_k, tp_k_head_num_k, qk_nope_pe_dim = k_prefix.shape - total_kv_v, tp_k_head_num_v, qk_nope_dim = v_prefix.shape + total_kv_v, tp_k_head_num_v, v_head_dim = v_prefix.shape + + qk_nope_head_dim = weight_n // tp_k_head_num_k - v_head_dim per_row_scale = kv_proj_scale.dim() == 1 or ( kv_proj_scale.dim() == 2 and kv_proj_scale.shape[1] == 1 @@ -47,6 +50,9 @@ def gather_kv_b_proj( assert tp_k_head_num_k == tp_k_head_num_v assert ChunkK % block_size == 0 + padded_k = _next_pow2(qk_nope_head_dim) + padded_v = _next_pow2(v_head_dim) + grid = (batch_size * tp_k_head_num_k,) _triton_gather_kv_b_proj[grid]( batch_size, @@ -61,10 +67,13 @@ def gather_kv_b_proj( v_prefix, KBlockSize=block_size, TpNumHeads=tp_k_head_num_k, - QkNopeHeadDim=qk_nope_dim, + QkNopeHeadDim=qk_nope_head_dim, + VHeadDim=v_head_dim, KV_CDim=weight_k, - KV_PeDim=qk_nope_pe_dim - qk_nope_dim, + KV_PeDim=qk_nope_pe_dim - qk_nope_head_dim, ChunkK=ChunkK, + PaddedK=padded_k, + PaddedV=padded_v, WEIGHT_PRESHUFFLE=weight_preshuffle, PER_ROW_SCALE=per_row_scale, num_stages=3, diff --git a/op_tests/triton_tests/test_gather_kv_b_proj.py b/op_tests/triton_tests/test_gather_kv_b_proj.py index d667fe2817..6fc964ec19 100644 --- a/op_tests/triton_tests/test_gather_kv_b_proj.py +++ b/op_tests/triton_tests/test_gather_kv_b_proj.py @@ -18,9 +18,14 @@ def ref_gather_kv_b_proj( kv_indptr: torch.Tensor, # [batch_size + 1] kv_indices: torch.Tensor, # len(kv_indices) = kv_indptr[-1] kv_prefix_sum_context_lens: torch.Tensor, # [batch_size + 1] - kv_proj_weight: torch.Tensor, # [2 * 128 // TP * 128, 512] + kv_proj_weight: torch.Tensor, # [tp_heads * (qk_nope_head_dim + v_head_dim), kv_c_dim] kv_proj_scale: torch.Tensor, # [weight_n] per-row or [N//128, K//128] block + qk_nope_head_dim: int = 128, + v_head_dim: int = None, ): + if v_head_dim is None: + v_head_dim = qk_nope_head_dim + batch_size = kv_indptr.shape[0] - 1 kv_c_dim = 512 @@ -41,9 +46,7 @@ def ref_gather_kv_b_proj( scale_granularity_k = weight_k // kv_proj_scale.shape[1] assert scale_granularity_k == 128 - num_tp = 2 * 128 * 128 // weight_n - tp_k_head_num = 128 // num_tp - qk_nope_head_dim = 128 + tp_k_head_num = weight_n // (qk_nope_head_dim + v_head_dim) kv_c, k_pe = k_buffer.split( [kv_c_dim, kv_pe_dim], dim=-1 @@ -56,12 +59,12 @@ def ref_gather_kv_b_proj( dtype=torch.bfloat16, ) v_prefix = torch.zeros( - (total_kv, tp_k_head_num * qk_nope_head_dim), + (total_kv, tp_k_head_num * v_head_dim), device=k_buffer.device, dtype=torch.bfloat16, ) k_prefix_tp = k_prefix.view(total_kv, tp_k_head_num, qk_nope_head_dim + kv_pe_dim) - v_prefix_tp = v_prefix.view(total_kv, tp_k_head_num, qk_nope_head_dim) + v_prefix_tp = v_prefix.view(total_kv, tp_k_head_num, v_head_dim) if not per_row_scale: kv_proj_scale_repeat = kv_proj_scale.repeat_interleave( @@ -114,19 +117,17 @@ def ref_gather_kv_b_proj( ] .to(torch.float32) .T - ) # [batch_kv, 2 * 128 // TP * 128] + ) kv_proj += kv_proj_tmp * kv_proj_scale_repeat[:, i].unsqueeze(0) kv_proj_tp = kv_proj.view( - context_end - context_start, tp_k_head_num, qk_nope_head_dim * 2 - ) # [batch_kv, tp_k_head_num, 2 * 128 // TP * 128] + context_end - context_start, tp_k_head_num, qk_nope_head_dim + v_head_dim + ) if k_buffer.dtype != torch.bfloat16: kv_proj_tp *= k_scale.unsqueeze(0).unsqueeze(1) - k_proj_tp, v_proj_tp = kv_proj_tp.split( - [qk_nope_head_dim, qk_nope_head_dim], dim=-1 - ) # [batch_kv, tp_k_head_num, 128 // TP * 128] + k_proj_tp, v_proj_tp = kv_proj_tp.split([qk_nope_head_dim, v_head_dim], dim=-1) k_prefix_tp[ context_start:context_end, @@ -138,6 +139,65 @@ def ref_gather_kv_b_proj( return (k_prefix, v_prefix) +def _make_kv_test_data( + batch_size, + block_size, + avg_kv_length, + kv_c_dim, + kv_pe_dim, + k_buffer_type, + device="cuda", +): + """Create common test data: k_buffer, k_scale, kv_indptr, kv_indices, etc.""" + num_block = 2 * avg_kv_length // block_size + + k_buffer = torch.randn( + (num_block, block_size, kv_c_dim + kv_pe_dim), + device=device, + dtype=torch.float32, + ).to(k_buffer_type) + k_scale = torch.randn(1, device=device, dtype=torch.float32).abs() + + var_ratio = 0.2 + context_lens = ( + torch.randint( + int((1 - var_ratio) * avg_kv_length), + int(((1 + var_ratio)) * avg_kv_length) + 1, + (batch_size,), + ) + .cuda() + .to(torch.int32) + ) + context_blocks = torch.div( + context_lens + block_size - 1, block_size, rounding_mode="trunc" + ) + + kv_indptr = torch.zeros((batch_size + 1,), device="cuda", dtype=torch.int32) + kv_indptr[1:] = torch.cumsum(context_blocks, dim=0) + + kv_prefix_sum_context_lens = torch.zeros( + (batch_size + 1,), device="cuda", dtype=torch.int32 + ) + kv_prefix_sum_context_lens[1:] = torch.cumsum(context_lens, dim=0) + + kv_indices = torch.zeros(kv_indptr[-1], device="cuda", dtype=torch.int32) + for b in range(batch_size): + ctx_len = int(context_blocks[b].item()) + kv_indices[kv_indptr[b] : kv_indptr[b + 1]] = torch.randperm( + num_block, device="cuda" + )[:ctx_len] + + return ( + k_buffer, + k_scale, + kv_indptr, + kv_indices, + kv_prefix_sum_context_lens, + context_lens, + num_block, + ) + + @pytest.mark.parametrize( "batch_size, block_size, num_tp, k_buffer_type, avg_kv_length", [ @@ -157,6 +217,7 @@ def test_gather_kv_b_proj( kv_c_dim = 512 kv_pe_dim = 64 qk_nope_head_dim = 128 + v_head_dim = 128 tp_k_head_num = 128 // num_tp num_block = 2 * avg_kv_length // block_size @@ -204,13 +265,14 @@ def test_gather_kv_b_proj( )[:ctx_len] # Generate random kv_proj_weight and kv_proj_scale + weight_n = tp_k_head_num * (qk_nope_head_dim + v_head_dim) kv_proj_weight = torch.randn( - (2 * 128 // num_tp * qk_nope_head_dim, kv_c_dim), + (weight_n, kv_c_dim), device=device, dtype=torch.float32, ).to(weight_dtype) kv_proj_scale = torch.randn( - (2 * 128 // num_tp, 4), device=device, dtype=torch.float32 + (weight_n // 128, 4), device=device, dtype=torch.float32 ).abs() # Reference implementation @@ -222,6 +284,8 @@ def test_gather_kv_b_proj( kv_prefix_sum_context_lens, kv_proj_weight, kv_proj_scale, + qk_nope_head_dim=qk_nope_head_dim, + v_head_dim=v_head_dim, ) k_prefix = torch.zeros( @@ -233,7 +297,7 @@ def test_gather_kv_b_proj( dtype=torch.bfloat16, ) v_prefix = torch.zeros( - (kv_prefix_sum_context_lens[-1].item(), tp_k_head_num * qk_nope_head_dim), + (kv_prefix_sum_context_lens[-1].item(), tp_k_head_num * v_head_dim), device=device, dtype=torch.bfloat16, ) @@ -250,7 +314,7 @@ def test_gather_kv_b_proj( kv_proj_weight, kv_proj_scale, k_prefix.view(-1, tp_k_head_num, qk_nope_head_dim + kv_pe_dim), - v_prefix.view(-1, tp_k_head_num, qk_nope_head_dim), + v_prefix.view(-1, tp_k_head_num, v_head_dim), weight_preshuffle=weight_preshuffle, ) @@ -269,14 +333,14 @@ def test_gather_kv_b_proj( kv_proj_weight, kv_proj_scale, k_prefix.view(-1, tp_k_head_num, qk_nope_head_dim + kv_pe_dim), - v_prefix.view(-1, tp_k_head_num, qk_nope_head_dim), + v_prefix.view(-1, tp_k_head_num, v_head_dim), weight_preshuffle=weight_preshuffle, ) total_float_operations = ( 2 * context_lens.float().sum().item() - * (2 * tp_k_head_num * qk_nope_head_dim) - * kv_c_dim # gemm_m # gemm_n # gemm_k + * (tp_k_head_num * (qk_nope_head_dim + v_head_dim)) + * kv_c_dim ) tflops = total_float_operations / elapsed_us * 1e-6 @@ -304,12 +368,13 @@ def test_gather_kv_b_proj_per_row_scale( kv_c_dim = 512 kv_pe_dim = 64 qk_nope_head_dim = 128 + v_head_dim = 128 tp_k_head_num = 128 // num_tp num_block = 2 * avg_kv_length // block_size weight_preshuffle = True device = "cuda" weight_dtype = dtypes.fp8 - weight_n = 2 * 128 // num_tp * qk_nope_head_dim + weight_n = tp_k_head_num * (qk_nope_head_dim + v_head_dim) k_buffer = torch.randn( (num_block, block_size, kv_c_dim + kv_pe_dim), @@ -364,6 +429,8 @@ def test_gather_kv_b_proj_per_row_scale( kv_prefix_sum_context_lens, kv_proj_weight, kv_proj_scale, + qk_nope_head_dim=qk_nope_head_dim, + v_head_dim=v_head_dim, ) k_prefix = torch.zeros( @@ -375,7 +442,7 @@ def test_gather_kv_b_proj_per_row_scale( dtype=torch.bfloat16, ) v_prefix = torch.zeros( - (kv_prefix_sum_context_lens[-1].item(), tp_k_head_num * qk_nope_head_dim), + (kv_prefix_sum_context_lens[-1].item(), tp_k_head_num * v_head_dim), device=device, dtype=torch.bfloat16, ) @@ -392,7 +459,7 @@ def test_gather_kv_b_proj_per_row_scale( kv_proj_weight, kv_proj_scale, k_prefix.view(-1, tp_k_head_num, qk_nope_head_dim + kv_pe_dim), - v_prefix.view(-1, tp_k_head_num, qk_nope_head_dim), + v_prefix.view(-1, tp_k_head_num, v_head_dim), weight_preshuffle=weight_preshuffle, ) @@ -410,13 +477,13 @@ def test_gather_kv_b_proj_per_row_scale( kv_proj_weight, kv_proj_scale, k_prefix.view(-1, tp_k_head_num, qk_nope_head_dim + kv_pe_dim), - v_prefix.view(-1, tp_k_head_num, qk_nope_head_dim), + v_prefix.view(-1, tp_k_head_num, v_head_dim), weight_preshuffle=weight_preshuffle, ) total_float_operations = ( 2 * context_lens.float().sum().item() - * (2 * tp_k_head_num * qk_nope_head_dim) + * (tp_k_head_num * (qk_nope_head_dim + v_head_dim)) * kv_c_dim ) tflops = total_float_operations / elapsed_us * 1e-6 @@ -428,6 +495,277 @@ def test_gather_kv_b_proj_per_row_scale( ) +@pytest.mark.parametrize( + "batch_size, block_size, num_tp, k_buffer_type, avg_kv_length, scale_mode", + [ + (4, 1, 4, torch.bfloat16, 512, "block"), + (8, 16, 4, torch.bfloat16, 1024, "block"), + (4, 1, 4, dtypes.fp8, 512, "block"), + (8, 16, 4, dtypes.fp8, 1024, "block"), + (4, 1, 4, torch.bfloat16, 512, "per_row"), + (8, 16, 4, torch.bfloat16, 1024, "per_row"), + (4, 1, 4, dtypes.fp8, 512, "per_row"), + (8, 16, 4, dtypes.fp8, 1024, "per_row"), + ], +) +def test_gather_kv_b_proj_bf16_weight( + batch_size, block_size, num_tp, k_buffer_type, avg_kv_length, scale_mode, perf=False +): + """Test gather_kv_b_proj with bf16 weight (no quantization on weight). + + When weight is bf16, weight_scale is set to all-ones so the matmul result + is not scaled — matching the behavior of an unquantized kv_b_proj. + """ + torch.manual_seed(0) + random.seed(0) + kv_c_dim = 512 + kv_pe_dim = 64 + qk_nope_head_dim = 128 + v_head_dim = 128 + tp_k_head_num = 128 // num_tp + num_block = 2 * avg_kv_length // block_size + weight_preshuffle = True + device = "cuda" + weight_dtype = torch.bfloat16 + weight_n = tp_k_head_num * (qk_nope_head_dim + v_head_dim) + + k_buffer = torch.randn( + (num_block, block_size, kv_c_dim + kv_pe_dim), + device=device, + dtype=torch.float32, + ).to(k_buffer_type) + k_scale = torch.randn(1, device=device, dtype=torch.float32).abs() + + var_ratio = 0.2 + context_lens = ( + torch.randint( + int((1 - var_ratio) * avg_kv_length), + int(((1 + var_ratio)) * avg_kv_length) + 1, + (batch_size,), + ) + .cuda() + .to(torch.int32) + ) + context_blocks = torch.div( + context_lens + block_size - 1, block_size, rounding_mode="trunc" + ) + + kv_indptr = torch.zeros((batch_size + 1,), device="cuda", dtype=torch.int32) + kv_indptr[1:] = torch.cumsum(context_blocks, dim=0) + + kv_prefix_sum_context_lens = torch.zeros( + (batch_size + 1,), device="cuda", dtype=torch.int32 + ) + kv_prefix_sum_context_lens[1:] = torch.cumsum(context_lens, dim=0) + + kv_indices = torch.zeros(kv_indptr[-1], device="cuda", dtype=torch.int32) + for b in range(batch_size): + ctx_len = int(context_blocks[b].item()) + kv_indices[kv_indptr[b] : kv_indptr[b + 1]] = torch.randperm( + num_block, device="cuda" + )[:ctx_len] + + # bf16 weight — no quantization + kv_proj_weight = torch.randn( + (weight_n, kv_c_dim), + device=device, + dtype=torch.float32, + ).to(weight_dtype) + + # Use all-ones scale to simulate no weight quantization + if scale_mode == "per_row": + kv_proj_scale = torch.ones((weight_n, 1), device=device, dtype=torch.float32) + else: + kv_proj_scale = torch.ones( + (weight_n // 128, kv_c_dim // 128), device=device, dtype=torch.float32 + ) + + k_ref, v_ref = ref_gather_kv_b_proj( + k_buffer, + k_scale, + kv_indptr, + kv_indices, + kv_prefix_sum_context_lens, + kv_proj_weight, + kv_proj_scale, + qk_nope_head_dim=qk_nope_head_dim, + v_head_dim=v_head_dim, + ) + + total_kv = kv_prefix_sum_context_lens[-1].item() + k_prefix = torch.zeros( + (total_kv, tp_k_head_num * (qk_nope_head_dim + kv_pe_dim)), + device=device, + dtype=torch.bfloat16, + ) + v_prefix = torch.zeros( + (total_kv, tp_k_head_num * v_head_dim), + device=device, + dtype=torch.bfloat16, + ) + + if weight_preshuffle: + kv_proj_weight = shuffle_weight(kv_proj_weight) + + gather_kv_b_proj( + k_buffer, + k_scale, + kv_indptr, + kv_indices, + kv_prefix_sum_context_lens, + kv_proj_weight, + kv_proj_scale, + k_prefix.view(-1, tp_k_head_num, qk_nope_head_dim + kv_pe_dim), + v_prefix.view(-1, tp_k_head_num, v_head_dim), + weight_preshuffle=weight_preshuffle, + ) + + checkAllclose(k_ref, k_prefix, atol=1e-2, rtol=1e-2) + checkAllclose(v_ref, v_prefix, atol=1e-2, rtol=1e-2) + + if perf: + _, elapsed_us = run_perftest( + gather_kv_b_proj, + k_buffer, + k_scale, + kv_indptr, + kv_indices, + kv_prefix_sum_context_lens, + kv_proj_weight, + kv_proj_scale, + k_prefix.view(-1, tp_k_head_num, qk_nope_head_dim + kv_pe_dim), + v_prefix.view(-1, tp_k_head_num, v_head_dim), + weight_preshuffle=weight_preshuffle, + ) + total_float_operations = ( + 2 + * context_lens.float().sum().item() + * (tp_k_head_num * (qk_nope_head_dim + v_head_dim)) + * kv_c_dim + ) + tflops = total_float_operations / elapsed_us * 1e-6 + + print(">>> Performance gather_kv_b_proj_bf16_weight:") + print( + f">>> batch {batch_size}, block_size {block_size}, tp_k_head_num {tp_k_head_num}, " + f"kv_c_dim {kv_c_dim}, qk_nope_head_dim {qk_nope_head_dim}, kv_length {avg_kv_length}, " + f"scale_mode {scale_mode}\n" + f">>> elapsed={elapsed_us:.2f}us, TFLOPS={tflops:.2f}" + ) + + +@pytest.mark.parametrize( + "batch_size, block_size, k_buffer_type, avg_kv_length, qk_nope_head_dim, v_head_dim, scale_mode", + [ + # GLM-5 dims (192/256), per-row scale, tp=4 → 32 heads + (4, 1, dtypes.fp8, 512, 192, 256, "per_row"), + (8, 16, dtypes.fp8, 1024, 192, 256, "per_row"), + (4, 1, torch.bfloat16, 512, 192, 256, "per_row"), + # GLM-5 dims, bf16 weight + per-row all-ones scale + (4, 1, dtypes.fp8, 512, 192, 256, "bf16_weight"), + (8, 16, torch.bfloat16, 1024, 192, 256, "bf16_weight"), + # Symmetric (DeepSeek-like) as sanity check + (4, 1, dtypes.fp8, 512, 128, 128, "per_row"), + (4, 1, dtypes.fp8, 512, 128, 128, "bf16_weight"), + ], +) +def test_gather_kv_b_proj_asymmetric_dims( + batch_size, + block_size, + k_buffer_type, + avg_kv_length, + qk_nope_head_dim, + v_head_dim, + scale_mode, +): + """Test gather_kv_b_proj with qk_nope_head_dim != v_head_dim (e.g. GLM-5: 192/256).""" + torch.manual_seed(0) + random.seed(0) + kv_c_dim = 512 + kv_pe_dim = 64 + num_tp = 4 + tp_k_head_num = 128 // num_tp + weight_preshuffle = True + device = "cuda" + weight_n = tp_k_head_num * (qk_nope_head_dim + v_head_dim) + + ( + k_buffer, + k_scale, + kv_indptr, + kv_indices, + kv_prefix_sum_context_lens, + context_lens, + num_block, + ) = _make_kv_test_data( + batch_size, + block_size, + avg_kv_length, + kv_c_dim, + kv_pe_dim, + k_buffer_type, + device, + ) + + if scale_mode == "bf16_weight": + weight_dtype = torch.bfloat16 + kv_proj_scale = torch.ones(weight_n, device=device, dtype=torch.float32) + else: + weight_dtype = dtypes.fp8 + kv_proj_scale = torch.randn( + (weight_n, 1), device=device, dtype=torch.float32 + ).abs() + + kv_proj_weight = torch.randn( + (weight_n, kv_c_dim), + device=device, + dtype=torch.float32, + ).to(weight_dtype) + + k_ref, v_ref = ref_gather_kv_b_proj( + k_buffer, + k_scale, + kv_indptr, + kv_indices, + kv_prefix_sum_context_lens, + kv_proj_weight, + kv_proj_scale, + qk_nope_head_dim=qk_nope_head_dim, + v_head_dim=v_head_dim, + ) + + total_kv = kv_prefix_sum_context_lens[-1].item() + k_prefix = torch.zeros( + (total_kv, tp_k_head_num * (qk_nope_head_dim + kv_pe_dim)), + device=device, + dtype=torch.bfloat16, + ) + v_prefix = torch.zeros( + (total_kv, tp_k_head_num * v_head_dim), + device=device, + dtype=torch.bfloat16, + ) + + if weight_preshuffle: + kv_proj_weight = shuffle_weight(kv_proj_weight) + + gather_kv_b_proj( + k_buffer, + k_scale, + kv_indptr, + kv_indices, + kv_prefix_sum_context_lens, + kv_proj_weight, + kv_proj_scale, + k_prefix.view(-1, tp_k_head_num, qk_nope_head_dim + kv_pe_dim), + v_prefix.view(-1, tp_k_head_num, v_head_dim), + weight_preshuffle=weight_preshuffle, + ) + + checkAllclose(k_ref, k_prefix, atol=1e-2, rtol=1e-2) + checkAllclose(v_ref, v_prefix, atol=1e-2, rtol=1e-2) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-B", "--batch", type=int, default=16, help="Batch size.")