diff --git a/aiter/ops/cache.py b/aiter/ops/cache.py
index 56293a7b77..9059a831e2 100644
--- a/aiter/ops/cache.py
+++ b/aiter/ops/cache.py
@@ -104,6 +104,7 @@ def indexer_k_quant_and_cache(
     slot_mapping: Tensor,
     quant_block_size: int,
     scale_fmt: str,
+    preshuffle: bool = False,
 ) -> None: ...


@@ -114,6 +115,7 @@ def cp_gather_indexer_k_quant_cache(
     dst_scale: Tensor,
     block_table: Tensor,
     cu_seq_lens: Tensor,
+    preshuffle: bool = False,
 ) -> None: ...


diff --git a/csrc/include/cache.h b/csrc/include/cache.h
index 6650d26c7a..7939800a00 100644
--- a/csrc/include/cache.h
+++ b/csrc/include/cache.h
@@ -77,14 +77,16 @@ void indexer_k_quant_and_cache(torch::Tensor& k, // [num_tokens, head_dim]
                                torch::Tensor& kv_cache,     // [num_blocks, block_size, cache_stride]
                                torch::Tensor& slot_mapping, // [num_tokens]
                                int64_t quant_block_size,    // quantization block size
-                               const std::string& scale_fmt);
+                               const std::string& scale_fmt,
+                               bool preshuffle = false);

 void cp_gather_indexer_k_quant_cache(
     const torch::Tensor& kv_cache,    // [num_blocks, block_size, cache_stride]
     torch::Tensor& dst_k,             // [num_tokens, head_dim]
     torch::Tensor& dst_scale,         // [num_tokens, head_dim / quant_block_size * 4]
     const torch::Tensor& block_table, // [batch_size, num_blocks]
-    const torch::Tensor& cu_seq_lens); // [batch_size + 1]
+    const torch::Tensor& cu_seq_lens,  // [batch_size + 1]
+    bool preshuffle = false);

 void fused_qk_rope_concat_and_cache_mla(
     torch::Tensor& q_nope, // [num_tokens, num_heads, qk_lora_rank]
diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp
index add0c0ad2e..35cd1f9138 100644
--- a/csrc/include/rocm_ops.hpp
+++ b/csrc/include/rocm_ops.hpp
@@ -333,14 +333,16 @@ namespace py = pybind11;
           py::arg("kv_cache"),                        \
           py::arg("slot_mapping"),                    \
           py::arg("quant_block_size"),                \
-          py::arg("scale_fmt"));                      \
+          py::arg("scale_fmt"),                       \
+          py::arg("preshuffle") = false);             \
     m.def("cp_gather_indexer_k_quant_cache",          \
           &aiter::cp_gather_indexer_k_quant_cache,    \
           py::arg("kv_cache"),                        \
           py::arg("dst_k"),                           \
           py::arg("dst_scale"),                       \
           py::arg("block_table"),                     \
-          py::arg("cu_seq_lens"));                    \
+          py::arg("cu_seq_lens"),                     \
+          py::arg("preshuffle") = false);             \
     m.def("fused_qk_rope_concat_and_cache_mla",       \
           &aiter::fused_qk_rope_concat_and_cache_mla, \
           "fused_qk_rope_concat_and_cache_mla("       \
diff --git a/csrc/kernels/cache_kernels.cu b/csrc/kernels/cache_kernels.cu
index 540f434a29..07b3fa5f32 100644
--- a/csrc/kernels/cache_kernels.cu
+++ b/csrc/kernels/cache_kernels.cu
@@ -1157,7 +1157,8 @@ __global__ void indexer_k_quant_and_cache_kernel(
     const int quant_block_size, // quantization block size
     const int cache_block_size, // cache block size
     const int cache_stride,     // stride for each token in kv_cache
-    const bool use_ue8m0        // use ue8m0 scale format
+    const bool use_ue8m0,       // use ue8m0 scale format
+    const bool preshuffle       // use MFMA 16x16 preshuffled layout
 )
 {
     const int quant_block_per_head = head_dim / quant_block_size;
@@ -1211,16 +1212,33 @@ __global__ void indexer_k_quant_and_cache_kernel(
         scale = exp2f(ceilf(log2f(scale)));
     }

-    const int64_t dst_offset =
-        block_idx * cache_block_size * cache_stride + block_offset * head_dim + head_dim_idx;
+    int64_t dst_offset;
+    if(preshuffle)
+    {
+        // Preshuffled layout for MFMA 16x16 tiles.
+        // Works for any cache_block_size and head_dim that are multiples of 16.
+        // A paged block is split into (cache_block_size / 16) token groups; each group
+        // contains (head_dim / 16) contiguous 16x16 tiles, stored row-major within the
+        // tile (rows are tokens, columns are head_dim elements).
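+        // e.g. with cache_block_size = 64 and head_dim = 128, the element at
+        // block_offset = 17, head_dim_idx = 35 lands in token tile 1, column tile 2,
+        // row 1, column 3 of that tile: 1*(16*128) + 2*(16*16) + 1*16 + 3 = 2579
+        // bytes into the block's K region.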
+        constexpr int TILE = 16;
+        const int token_tile_id = block_offset / TILE;
+        const int token_in_tile = block_offset % TILE;
+        const int col_tile_id   = head_dim_idx / TILE;
+        const int col_in_tile   = head_dim_idx % TILE;
+        dst_offset = block_idx * cache_block_size * cache_stride +
+                     token_tile_id * (TILE * head_dim) +
+                     col_tile_id * (TILE * TILE) +
+                     token_in_tile * TILE +
+                     col_in_tile;
+    }
+    else
+    {
+        dst_offset =
+            block_idx * cache_block_size * cache_stride + block_offset * head_dim + head_dim_idx;
+    }

-    // for(int i = 0; i < VEC_SIZE; i++)
-    // {
-    //     kv_cache[dst_offset + i] =
-    //         opus::cast(static_cast(k_val[i]) / scale);
-    // }
     if(threadIdx.x == 0)
     {
+        // Scale layout is unchanged regardless of preshuffle
         const int64_t dst_scale_idx =
             block_idx * cache_block_size * cache_stride + cache_block_size * head_dim +
             (block_offset * head_dim + head_dim_idx) * 4 / quant_block_size;
@@ -1248,7 +1266,8 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
     const int64_t cache_block_size, // num_tokens for each block in kv_cache
     const int num_blocks,           // number of blocks
     const int num_tokens,           // number of tokens
-    const int quant_block_size      // quantization block size
+    const int quant_block_size,     // quantization block size
+    const bool preshuffle           // source uses MFMA 16x16 preshuffled layout
 )
 {
     constexpr int VEC_SIZE = sizeof(float4) / sizeof(char);
@@ -1278,14 +1297,36 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
     const int block_idx =
         block_table[batch_idx[threadIdx.y] * num_blocks + inbatch_seq_idx / cache_block_size];
     const int64_t src_block_offset = block_idx * block_stride;
-    const int64_t cache_inblock_offset = (inbatch_seq_idx % cache_block_size) * head_dim + head_idx;
-    const int64_t src_inblock_offset   = src_block_offset + cache_inblock_offset;
+    const int64_t block_offset       = inbatch_seq_idx % cache_block_size;
     const int64_t dst_inblock_offset = token_idx * token_stride + head_idx;

+    int64_t src_inblock_offset;
+    if(preshuffle)
+    {
+        // Preshuffled layout: reverse the MFMA 16x16 tile mapping.
+        // Works for any cache_block_size and head_dim that are multiples of 16.
+        constexpr int TILE = 16;
+        const int token_tile_id = block_offset / TILE;
+        const int token_in_tile = block_offset % TILE;
+        const int col_tile_id   = head_idx / TILE;
+        const int col_in_tile   = head_idx % TILE;
+        src_inblock_offset = src_block_offset +
+                             token_tile_id * (TILE * head_dim) +
+                             col_tile_id * (TILE * TILE) +
+                             token_in_tile * TILE +
+                             col_in_tile;
+    }
+    else
+    {
+        src_inblock_offset = src_block_offset + block_offset * head_dim + head_idx;
+    }
+
     reinterpret_cast<float4*>(dst_k)[dst_inblock_offset / VEC_SIZE] =
         reinterpret_cast<const float4*>(kv_cache)[src_inblock_offset / VEC_SIZE];
     if(threadIdx.x == 0)
     {
+        // Scale layout is unchanged regardless of preshuffle
+        const int64_t cache_inblock_offset = block_offset * head_dim + head_idx;
         const int64_t src_scale_offset =
             src_block_offset + cache_block_size * head_dim +
             cache_inblock_offset * 4 / quant_block_size;
         reinterpret_cast<float*>(dst_scale)[dst_inblock_offset / quant_block_size] =
@@ -2896,9 +2937,9 @@ void reshape_and_cache_flash(
             quant_block_size,                                      \
             cache_block_size,                                      \
             cache_stride,                                          \
-            use_ue8m0);
+            use_ue8m0,                                             \
+            do_preshuffle);

-// Macro to dispatch the kernel based on the data amount.
 #define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE)          \
     aiter::cp_gather_indexer_k_quant_cache_kernel<8, BLOCK_Y_SIZE>  \
         <<                                                           \
@@ -3296,18 +3338,29 @@ void indexer_k_quant_and_cache(torch::Tensor& k, // [num_tokens, head_dim]
                                torch::Tensor& kv_cache,     // [num_blocks, block_size, cache_stride]
                                torch::Tensor& slot_mapping, // [num_tokens]
                                int64_t quant_block_size,    // quantization block size
-                               const std::string& scale_fmt)
+                               const std::string& scale_fmt,
+                               bool preshuffle)
 {
     int num_tokens       = k.size(0);
     int head_dim         = k.size(1);
     int cache_block_size = kv_cache.size(1);
     int cache_stride     = kv_cache.size(2);
     bool use_ue8m0       = scale_fmt == "ue8m0";
+    bool do_preshuffle   = preshuffle;

     TORCH_CHECK(k.device() == kv_cache.device(), "k and kv_cache must be on the same device");
     TORCH_CHECK(k.device() == slot_mapping.device(),
                 "k and slot_mapping must be on the same device");
     TORCH_CHECK(head_dim % quant_block_size == 0,
                 "head_dim must be divisible by quant_block_size");
+    if(preshuffle)
+    {
+        TORCH_CHECK(cache_block_size % 16 == 0,
+                    "preshuffle requires cache_block_size to be a multiple of 16, got ",
+                    cache_block_size);
+        TORCH_CHECK(head_dim % 16 == 0,
+                    "preshuffle requires head_dim to be a multiple of 16, got ",
+                    head_dim);
+    }

     int quant_blocks   = num_tokens * head_dim / quant_block_size;
     const int vec_size = 16;
@@ -3327,13 +3380,14 @@ void cp_gather_indexer_k_quant_cache(
     torch::Tensor& dst_k,             // [num_tokens, head_dim]
     torch::Tensor& dst_scale,         // [num_tokens, head_dim / quant_block_size] float
     const torch::Tensor& block_table, // [batch_size, num_blocks]
-    const torch::Tensor& cu_seq_lens  // [batch_size + 1]
-)
+    const torch::Tensor& cu_seq_lens, // [batch_size + 1]
+    bool preshuffle)
 {
     int batch_size       = block_table.size(0);
     int num_tokens       = dst_k.size(0);
     int head_dim         = dst_k.size(1);
     int quant_block_size = head_dim / (dst_scale.size(1) * dst_scale.itemsize() / 4);
+    bool do_preshuffle   = preshuffle;

     TORCH_CHECK(kv_cache.device() == dst_k.device(),
                 "kv_cache and dst_k must be on the same device");
@@ -3344,6 +3398,16 @@ void cp_gather_indexer_k_quant_cache(
     TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(),
                 "kv_cache and cu_seq_lens must be on the same device");
     TORCH_CHECK(head_dim % quant_block_size == 0,
                 "head_dim must be divisible by quant_block_size");
+    if(preshuffle)
+    {
+        int cache_block_size = kv_cache.size(1);
+        TORCH_CHECK(cache_block_size % 16 == 0,
+                    "preshuffle requires cache_block_size to be a multiple of 16, got ",
+                    cache_block_size);
+        TORCH_CHECK(head_dim % 16 == 0,
+                    "preshuffle requires head_dim to be a multiple of 16, got ",
+                    head_dim);
+    }

     constexpr int vec_size = 16;
     const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(kv_cache));
diff --git a/op_tests/test_indexer_k_quant_and_cache.py b/op_tests/test_indexer_k_quant_and_cache.py
index a238648a5c..8d7fbf2b70 100644
--- a/op_tests/test_indexer_k_quant_and_cache.py
+++ b/op_tests/test_indexer_k_quant_and_cache.py
@@ -4,50 +4,116 @@
 import torch
 import aiter
 from aiter.test_common import checkAllclose, run_perftest, benchmark
-from aiter import dtypes
-from aiter import pertoken_quant, dtypes, indexer_k_quant_and_cache
+from aiter import (
+    pertoken_quant,
+    dtypes,
+    indexer_k_quant_and_cache,
+    cp_gather_indexer_k_quant_cache,
+)
 import argparse
 import pandas as pd

 MAX_TOKEN_SUPPORTED = 16384
+TILE = 16  # MFMA 16x16 tile size used by the preshuffle layout
 torch.set_default_device("cuda")


-def run_torch(k, kv_cache, slot_mapping, quant_block_size, scale_fmt):
-    num_token, head_dim = k.shape
-    block_size = kv_cache.shape[1]
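+# Note: the preshuffle layout written by _write_block_preshuffle below is a pure
+# permutation of each block's K bytes: viewed as
+# [block_size // TILE, head_dim // TILE, TILE, TILE], a permute(0, 2, 1, 3) plus a
+# reshape recovers the plain [block_size, head_dim] layout.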
+def _split_k_scale(kv_cache, head_dim):
+    """Split a kv_cache tensor into its K-data bytes and scale float32 regions.
+
+    kv_cache shape: [block_num, block_size, head_dim + head_dim / quant_block_size * 4] (fp8).
+    Both write and gather kernels treat each paged block as a block-major packed
+    byte buffer: first `block_size * head_dim` bytes for K, then the rest for scales.
+    """
+    block_num, block_size, cache_stride = kv_cache.shape
+    flat = kv_cache.view(block_num, block_size * cache_stride)
+    k_bytes = flat[:, : block_size * head_dim].contiguous()
+    scale_region = flat[:, block_size * head_dim :].contiguous()
+    return k_bytes, scale_region.view(torch.float32)
+
+
+def _write_block_preshuffle(block_flat, k_fp8_row, block_offset, head_dim):
+    """Write one token's FP8 K values into a block using the MFMA 16x16 preshuffle layout."""
+    token_tile_id = block_offset // TILE
+    token_in_tile = block_offset % TILE
+    for col_tile_id in range(head_dim // TILE):
+        col_base = col_tile_id * TILE
+        tile_base = (
+            token_tile_id * TILE * head_dim
+            + col_tile_id * TILE * TILE
+            + token_in_tile * TILE
+        )
+        block_flat[tile_base : tile_base + TILE] = k_fp8_row[col_base : col_base + TILE]
+
+
+def _compute_ref_scale(k_flat_quant_blocks, scale_fmt):
+    """Replicate the kernel's fp32 scale computation exactly.
+
+    The kernel works in fp32 throughout; doing the ue8m0 log2/ceil in bf16
+    loses precision near power-of-two boundaries and can make the reference
+    scale differ from the kernel's by a factor of 2. Cast to fp32 first.
+    """
     per_token_amax, _ = torch.max(
-        input=torch.abs(k.view(-1, quant_block_size)), dim=-1, keepdim=True
+        input=torch.abs(k_flat_quant_blocks.to(torch.float32)), dim=-1, keepdim=True
     )
     scale = per_token_amax / torch.finfo(dtypes.fp8).max
     if scale_fmt == "ue8m0":
         scale = torch.pow(2.0, torch.ceil(torch.log2(scale)))
+    return scale
+
+
+def run_torch(k, kv_cache, slot_mapping, quant_block_size, scale_fmt, preshuffle=False):
+    num_token, head_dim = k.shape
+    block_num, block_size, cache_stride = kv_cache.shape
+    scale = _compute_ref_scale(k.view(-1, quant_block_size), scale_fmt)
     k_fp8, scale = pertoken_quant(
         k.view(-1, quant_block_size), quant_dtype=dtypes.fp8, scale=scale
     )
     k_fp8 = k_fp8.view(num_token, head_dim)
+    n_scale_bytes = head_dim // quant_block_size * 4
+    kv_flat = kv_cache.view(block_num, block_size * cache_stride)
     for i in range(num_token):
         slot = slot_mapping[i].item()
-        blockId = slot // block_size
+        if slot < 0:
+            continue
+        block_id = slot // block_size
         block_offset = slot % block_size
-        kv_cache[blockId, block_offset, :head_dim] = k_fp8[i]
-        kv_cache[blockId, block_offset, head_dim:] = scale[i].view(dtypes.fp8)
+        block_flat = kv_flat[block_id]
+        if preshuffle:
+            _write_block_preshuffle(block_flat, k_fp8[i], block_offset, head_dim)
+        else:
+            # Block-major packed layout to match the C++ kernel:
+            # [K slot 0 | K slot 1 | ... | K slot (B-1) | Scale slot 0 | ... | Scale slot (B-1)]
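+            # e.g. block_size=64, head_dim=128, quant_block_size=128: each block's flat
+            # buffer is 64 * 132 = 8448 bytes, K occupies [0, 8192) and the 4-byte
+            # per-token scales occupy [8192, 8448).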
+            k_offset = block_offset * head_dim
+            block_flat[k_offset : k_offset + head_dim] = k_fp8[i]
+            scale_offset = block_size * head_dim + block_offset * n_scale_bytes
+            block_flat[scale_offset : scale_offset + n_scale_bytes] = (
+                scale[i].view(dtypes.fp8).reshape(-1)
+            )


 @benchmark()
 def test_indexer_k_quant_and_cache(
-    num_token, block_size, quant_block_size, head_dim=128
+    num_token, block_size, quant_block_size, head_dim=128, preshuffle=False
 ):
     assert (
         num_token <= MAX_TOKEN_SUPPORTED
     ), f"test only support max_token={MAX_TOKEN_SUPPORTED}"
+    if preshuffle:
+        assert block_size % TILE == 0 and head_dim % TILE == 0, (
+            f"preshuffle requires block_size and head_dim multiples of {TILE}, "
+            f"got block_size={block_size}, head_dim={head_dim}"
+        )
     block_num = (num_token + block_size - 1) // block_size
     k = torch.randn((num_token, head_dim), dtype=dtypes.bf16)
     slot_mapping = torch.arange(0, num_token, 1, dtype=torch.int64)
     scale_fmt = "ue8m0"
-    kv_cache = torch.empty((block_num, block_size, head_dim + 4), dtype=dtypes.fp8)
-    run_torch(k, kv_cache, slot_mapping, quant_block_size, scale_fmt)
-    kv_cache2 = torch.empty((block_num, block_size, head_dim + 4), dtype=dtypes.fp8)
+    # Zero-init so unwritten padding slots (if any) match between ref and kernel.
+    kv_cache = torch.zeros((block_num, block_size, head_dim + 4), dtype=dtypes.fp8)
+    run_torch(
+        k, kv_cache, slot_mapping, quant_block_size, scale_fmt, preshuffle=preshuffle
+    )
+    kv_cache2 = torch.zeros((block_num, block_size, head_dim + 4), dtype=dtypes.fp8)
     _, us = run_perftest(
         indexer_k_quant_and_cache,
         k,
@@ -55,37 +121,98 @@ def test_indexer_k_quant_and_cache(
         slot_mapping,
         quant_block_size,
         scale_fmt,
+        preshuffle,
     )
-    err = checkAllclose(
-        kv_cache.view(-1, head_dim + 4)[:num_token].to(torch.float),
-        kv_cache2.view(-1, head_dim + 4)[:num_token].to(torch.float),
-    )
-    # scale = kv_cache[:, :, head_dim:].view(torch.float)
-    # scale2 = kv_cache2[:, :, head_dim:].view(torch.float)
-    ret = {"aiter us": us, "aiter err": err}
-    try:
-        from vllm import _custom_ops as ops
-
-        kv_cache3 = torch.empty((block_num, block_size, head_dim + 4), dtype=dtypes.fp8)
-        _, us2 = run_perftest(
-            ops.indexer_k_quant_and_cache,
-            k,
-            kv_cache3,
-            slot_mapping,
-            quant_block_size,
-            scale_fmt,
-        )
-        err2 = checkAllclose(
-            kv_cache.view(-1, head_dim + 4)[:num_token].to(torch.float),
-            kv_cache3.view(-1, head_dim + 4)[:num_token].to(torch.float),
-        )
-        ret.update({"vllm us": us2, "vllm err": err2})
-    except Exception:
-        # Ignore all exceptions here because vllm._custom_ops is optional and may not be available.
-        pass
+    # Compare K bytes (as FP8) and scale float32 regions separately to avoid the
+    # FP8-bit-reinterpretation artifact when a float32 scale is viewed as 4 FP8 bytes.
+    k_ref, s_ref = _split_k_scale(kv_cache, head_dim)
+    k_got, s_got = _split_k_scale(kv_cache2, head_dim)
+    err_k = checkAllclose(k_ref.to(torch.float), k_got.to(torch.float))
+    err_s = checkAllclose(s_ref, s_got)
+    ret = {"aiter us": us, "aiter k_err": err_k, "aiter s_err": err_s}
+    if not preshuffle:
+        # vllm reference op does not support preshuffle mode.
+        try:
+            from vllm import _custom_ops as ops
+
+            kv_cache3 = torch.zeros(
+                (block_num, block_size, head_dim + 4), dtype=dtypes.fp8
+            )
+            _, us2 = run_perftest(
+                ops.indexer_k_quant_and_cache,
+                k,
+                kv_cache3,
+                slot_mapping,
+                quant_block_size,
+                scale_fmt,
+            )
+            k_vllm, s_vllm = _split_k_scale(kv_cache3, head_dim)
+            err2_k = checkAllclose(k_ref.to(torch.float), k_vllm.to(torch.float))
+            err2_s = checkAllclose(s_ref, s_vllm)
+            ret.update({"vllm us": us2, "vllm k_err": err2_k, "vllm s_err": err2_s})
+        except Exception:
+            # Ignore all exceptions here because vllm._custom_ops is optional and may not be available.
+            pass
     return ret


+@benchmark()
+def test_cp_gather_indexer_k_quant_cache(
+    num_token, block_size, quant_block_size, head_dim=128, preshuffle=False
+):
+    """Round-trip: write with indexer_k_quant_and_cache(preshuffle=P),
+    read back with cp_gather_indexer_k_quant_cache(preshuffle=P), and compare
+    to the direct pertoken-quant reference. Verifies write+gather layouts are
+    internally consistent and match the expected quantized values."""
+    assert (
+        num_token <= MAX_TOKEN_SUPPORTED
+    ), f"test only support max_token={MAX_TOKEN_SUPPORTED}"
+    if preshuffle:
+        assert block_size % TILE == 0 and head_dim % TILE == 0, (
+            f"preshuffle requires block_size and head_dim multiples of {TILE}, "
+            f"got block_size={block_size}, head_dim={head_dim}"
+        )
+    block_num = (num_token + block_size - 1) // block_size
+    k = torch.randn((num_token, head_dim), dtype=dtypes.bf16)
+    slot_mapping = torch.arange(0, num_token, 1, dtype=torch.int64)
+    scale_fmt = "ue8m0"
+
+    # Reference quantized values (layout-agnostic). Use the same fp32 scale
+    # helper as run_torch so we match the kernel's fp32 precision exactly.
+    ref_scale = _compute_ref_scale(k.view(-1, quant_block_size), scale_fmt)
+    ref_k_fp8, ref_scale = pertoken_quant(
+        k.view(-1, quant_block_size), quant_dtype=dtypes.fp8, scale=ref_scale
+    )
+    ref_k_fp8 = ref_k_fp8.view(num_token, head_dim)
+    ref_scale = ref_scale.view(num_token, head_dim // quant_block_size)
+
+    # Write phase.
+    kv_cache = torch.zeros((block_num, block_size, head_dim + 4), dtype=dtypes.fp8)
+    indexer_k_quant_and_cache(
+        k, kv_cache, slot_mapping, quant_block_size, scale_fmt, preshuffle
+    )
+
+    # Gather phase: batch_size=1, linear block_table covering every slot in order.
+    block_table = torch.arange(0, block_num, dtype=torch.int32).view(1, -1)
+    cu_seq_lens = torch.tensor([0, num_token], dtype=torch.int32)
+    dst_k = torch.empty((num_token, head_dim), dtype=dtypes.fp8)
+    dst_scale = torch.empty(
+        (num_token, head_dim // quant_block_size), dtype=torch.float32
+    )
+    _, us = run_perftest(
+        cp_gather_indexer_k_quant_cache,
+        kv_cache,
+        dst_k,
+        dst_scale,
+        block_table,
+        cu_seq_lens,
+        preshuffle,
+    )
+    err_k = checkAllclose(dst_k.to(torch.float), ref_k_fp8.to(torch.float))
+    err_s = checkAllclose(dst_scale, ref_scale)
+    return {"aiter us": us, "k err": err_k, "scale err": err_s}
+
+
 parser = argparse.ArgumentParser(
     formatter_class=argparse.RawTextHelpFormatter,
     description="Test indexer_k_quant_and_cache.",
 )
@@ -105,13 +232,43 @@ def test_indexer_k_quant_and_cache(
     default=[1],
     help="""block_size, default: 1""",
 )
+parser.add_argument(
+    "-p",
+    "--preshuffle",
+    action="store_true",
+    help="""Also run preshuffle=True.
+Requires block_size and head_dim to be multiples of 16; combos that don't meet this are silently skipped.""",
+)
+parser.add_argument(
+    "-g",
+    "--gather",
+    action="store_true",
+    help="""Also run cp_gather_indexer_k_quant_cache round-trip tests.""",
+)
 args = parser.parse_args()
+
+preshuffle_modes = [False] + ([True] if args.preshuffle else [])
+
 df = []
+gather_df = []
 for m in args.m:
     for block_size in args.block_size:
-        ret = test_indexer_k_quant_and_cache(m, block_size, 128, 128)
-        df.append(ret)
+        for preshuffle in preshuffle_modes:
+            if preshuffle and (block_size % TILE != 0):
+                continue
+            ret = test_indexer_k_quant_and_cache(m, block_size, 128, 128, preshuffle)
+            df.append(ret)
+            if args.gather:
+                gret = test_cp_gather_indexer_k_quant_cache(
+                    m, block_size, 128, 128, preshuffle
+                )
+                gather_df.append(gret)

 df = pd.DataFrame(df)
 df_md = df.to_markdown(index=False)
 aiter.logger.info("indexer_k_quant_and_cache summary (markdown):\n%s", df_md)
+if args.gather:
+    gather_df = pd.DataFrame(gather_df)
+    aiter.logger.info(
+        "cp_gather_indexer_k_quant_cache round-trip summary (markdown):\n%s",
+        gather_df.to_markdown(index=False),
+    )
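+# Example: `python3 op_tests/test_indexer_k_quant_and_cache.py --preshuffle --gather`
+# adds preshuffle=True runs (for block sizes that are multiples of 16) and the
+# cp_gather_indexer_k_quant_cache round-trip checks to the logged summaries.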