aiter/ops/cache.py: 2 additions & 0 deletions
@@ -104,6 +104,7 @@ def indexer_k_quant_and_cache(
slot_mapping: Tensor,
quant_block_size: int,
scale_fmt: str,
+ preshuffle: bool = False,
) -> None: ...


@@ -114,6 +115,7 @@ def cp_gather_indexer_k_quant_cache(
dst_scale: Tensor,
block_table: Tensor,
cu_seq_lens: Tensor,
+ preshuffle: bool = False,
) -> None: ...


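The Python stubs above expose the new flag with a False default, so existing call sites keep their old behavior. A minimal usage sketch, assuming the op is importable from aiter.ops.cache as in this file; the shapes, dtypes, and cache sizing below are illustrative assumptions, not taken from this PR (head_dim and the paged block size are multiples of 16, as the preshuffle path requires):

# Hedged sketch: shapes, dtypes, and cache layout sizing are assumptions.
import torch
from aiter.ops.cache import indexer_k_quant_and_cache

num_tokens, head_dim, quant_block_size = 32, 128, 128
num_blocks, block_size = 4, 16
cache_stride = head_dim + (head_dim // quant_block_size) * 4  # quantized bytes + fp32 scale bytes per token

k = torch.randn(num_tokens, head_dim, device="cuda", dtype=torch.bfloat16)
kv_cache = torch.zeros(num_blocks, block_size, cache_stride, device="cuda", dtype=torch.uint8)
slot_mapping = torch.arange(num_tokens, device="cuda", dtype=torch.int64)

# preshuffle=False (the default) keeps the old row-major token layout;
# preshuffle=True writes MFMA 16x16 tiled blocks instead.
indexer_k_quant_and_cache(k, kv_cache, slot_mapping, quant_block_size, "ue8m0", preshuffle=True)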
csrc/include/cache.h: 4 additions & 2 deletions
@@ -77,14 +77,16 @@ void indexer_k_quant_and_cache(torch::Tensor& k, // [num_tokens, head_dim
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size
- const std::string& scale_fmt);
+ const std::string& scale_fmt,
+ bool preshuffle = false);

void cp_gather_indexer_k_quant_cache(
const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& dst_k, // [num_tokens, head_dim]
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
const torch::Tensor& block_table, // [batch_size, num_blocks]
- const torch::Tensor& cu_seq_lens); // [batch_size + 1]
+ const torch::Tensor& cu_seq_lens, // [batch_size + 1]
+ bool preshuffle = false);

void fused_qk_rope_concat_and_cache_mla(
torch::Tensor& q_nope, // [num_tokens, num_heads, qk_lora_rank]
csrc/include/rocm_ops.hpp: 4 additions & 2 deletions
@@ -333,14 +333,16 @@ namespace py = pybind11;
py::arg("kv_cache"), \
py::arg("slot_mapping"), \
py::arg("quant_block_size"), \
py::arg("scale_fmt")); \
py::arg("scale_fmt"), \
py::arg("preshuffle") = false); \
m.def("cp_gather_indexer_k_quant_cache", \
&aiter::cp_gather_indexer_k_quant_cache, \
py::arg("kv_cache"), \
py::arg("dst_k"), \
py::arg("dst_scale"), \
py::arg("block_table"), \
py::arg("cu_seq_lens")); \
py::arg("cu_seq_lens"), \
py::arg("preshuffle") = false); \
m.def("fused_qk_rope_concat_and_cache_mla", \
&aiter::fused_qk_rope_concat_and_cache_mla, \
"fused_qk_rope_concat_and_cache_mla(" \
csrc/kernels/cache_kernels.cu: 81 additions & 17 deletions
@@ -1157,7 +1157,8 @@ __global__ void indexer_k_quant_and_cache_kernel(
const int quant_block_size, // quantization block size
const int cache_block_size, // cache block size
const int cache_stride, // stride for each token in kv_cache
- const bool use_ue8m0 // use ue8m0 scale format
+ const bool use_ue8m0, // use ue8m0 scale format
+ const bool preshuffle // use MFMA 16x16 preshuffled layout
)
{
const int quant_block_per_head = head_dim / quant_block_size;
@@ -1211,16 +1212,33 @@
scale = exp2f(ceilf(log2f(scale)));
}

- const int64_t dst_offset =
- block_idx * cache_block_size * cache_stride + block_offset * head_dim + head_dim_idx;
+ int64_t dst_offset;
+ if(preshuffle)
+ {
+ // Preshuffled layout for MFMA 16x16 tile.
+ // Works for any cache_block_size and head_dim that are multiples of 16.
+ // A paged block is split into (cache_block_size / 16) token groups; each group
+ // contains (head_dim / 16) contiguous 16x16 tiles laid out row-major within each tile.
+ constexpr int TILE = 16;
+ const int token_tile_id = block_offset / TILE;
+ const int token_in_tile = block_offset % TILE;
+ const int col_tile_id = head_dim_idx / TILE;
+ const int col_in_tile = head_dim_idx % TILE;
+ dst_offset = block_idx * cache_block_size * cache_stride
+ + token_tile_id * (TILE * head_dim)
+ + col_tile_id * (TILE * TILE)
+ + token_in_tile * TILE
+ + col_in_tile;
+ }
+ else
+ {
+ dst_offset =
+ block_idx * cache_block_size * cache_stride + block_offset * head_dim + head_dim_idx;
+ }

- // for(int i = 0; i < VEC_SIZE; i++)
- // {
- // kv_cache[dst_offset + i] =
- // opus::cast<cache_t>(static_cast<float>(k_val[i]) / scale);
- // }
if(threadIdx.x == 0)
{
+ // Scale layout is unchanged regardless of preshuffle
const int64_t dst_scale_idx =
block_idx * cache_block_size * cache_stride + cache_block_size * head_dim +
(block_offset * head_dim + head_dim_idx) * 4 / quant_block_size;
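The tile arithmetic above is easy to sanity-check off-device, and the gather kernel later in this file applies the same formula when reading back. Below is a small Python sketch of the mapping, with TILE fixed at 16 and cache_block_size / head_dim chosen only for illustration; it confirms that every (token, column) pair of a paged block lands on a unique in-block offset:

# Host-side model of the preshuffled dst_offset computation (in-block part only).
TILE = 16
cache_block_size, head_dim = 32, 64  # any multiples of 16

def preshuffled_offset(block_offset, head_dim_idx):
    token_tile_id, token_in_tile = divmod(block_offset, TILE)  # which 16-token group, row in tile
    col_tile_id, col_in_tile = divmod(head_dim_idx, TILE)      # which 16-column tile, column in tile
    return (token_tile_id * (TILE * head_dim)   # skip earlier token groups
            + col_tile_id * (TILE * TILE)       # skip earlier 16x16 tiles in this group
            + token_in_tile * TILE              # row-major position inside the tile
            + col_in_tile)

offsets = {preshuffled_offset(t, c)
           for t in range(cache_block_size)
           for c in range(head_dim)}
assert offsets == set(range(cache_block_size * head_dim))  # a bijection: no collisions, no gaps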
@@ -1248,7 +1266,8 @@ __global__ void cp_gather_indexer_k_quant_cache_kernel(
const int64_t cache_block_size, // num_tokens for each block in kv_cache
const int num_blocks, // number of blocks
const int num_tokens, // number of tokens
- const int quant_block_size // quantization block size
+ const int quant_block_size, // quantization block size
+ const bool preshuffle // source uses MFMA 16x16 preshuffled layout
)
{
constexpr int VEC_SIZE = sizeof(float4) / sizeof(char);
@@ -1278,14 +1297,36 @@
const int block_idx =
block_table[batch_idx[threadIdx.y] * num_blocks + inbatch_seq_idx / cache_block_size];
const int64_t src_block_offset = block_idx * block_stride;
- const int64_t cache_inblock_offset = (inbatch_seq_idx % cache_block_size) * head_dim + head_idx;
- const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset;
+ const int64_t block_offset = inbatch_seq_idx % cache_block_size;
const int64_t dst_inblock_offset = token_idx * token_stride + head_idx;

+ int64_t src_inblock_offset;
+ if(preshuffle)
+ {
+ // Preshuffled layout: reverse the MFMA 16x16 tile mapping.
+ // Works for any cache_block_size and head_dim that are multiples of 16.
+ constexpr int TILE = 16;
+ const int token_tile_id = block_offset / TILE;
+ const int token_in_tile = block_offset % TILE;
+ const int col_tile_id = head_idx / TILE;
+ const int col_in_tile = head_idx % TILE;
+ src_inblock_offset = src_block_offset
+ + token_tile_id * (TILE * head_dim)
+ + col_tile_id * (TILE * TILE)
+ + token_in_tile * TILE
+ + col_in_tile;
+ }
+ else
+ {
+ src_inblock_offset = src_block_offset + block_offset * head_dim + head_idx;
+ }

reinterpret_cast<float4*>(dst_k)[dst_inblock_offset / VEC_SIZE] =
reinterpret_cast<const float4*>(kv_cache)[src_inblock_offset / VEC_SIZE];
if(threadIdx.x == 0)
{
+ // Scale layout is unchanged regardless of preshuffle
+ const int64_t cache_inblock_offset = block_offset * head_dim + head_idx;
const int64_t src_scale_offset = src_block_offset + cache_block_size * head_dim +
cache_inblock_offset * 4 / quant_block_size;
reinterpret_cast<float*>(dst_scale)[dst_inblock_offset / quant_block_size] =
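Both kernels keep the scales where they always were: within each paged block, the quantized data occupies the first cache_block_size * head_dim bytes and the fp32 scales follow it, one float per quant_block_size elements, so only the data region is retiled. A short sketch of that in-block scale indexing (the sizes are illustrative assumptions):

# In-block scale byte offset, mirroring the dst_scale_idx / src_scale_offset math above.
cache_block_size, head_dim, quant_block_size = 32, 128, 128

def scale_byte_offset(block_offset, head_dim_idx):
    data_region = cache_block_size * head_dim  # quantized bytes come first
    # one 4-byte float per quant_block_size elements, indexed by flat element position
    return data_region + (block_offset * head_dim + head_dim_idx) * 4 // quant_block_size

# With head_dim == quant_block_size, token 3's single scale sits 12 bytes into the scale region:
assert scale_byte_offset(3, 0) == cache_block_size * head_dim + 3 * 4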
@@ -2896,9 +2937,9 @@ void reshape_and_cache_flash(
quant_block_size, \
cache_block_size, \
cache_stride, \
- use_ue8m0);
+ use_ue8m0, \
+ do_preshuffle);

// Macro to dispatch the kernel based on the data amount.
#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \
aiter::cp_gather_indexer_k_quant_cache_kernel<8, BLOCK_Y_SIZE> \
<<<dim3((num_tokens + BLOCK_Y_SIZE - 1) / BLOCK_Y_SIZE, \
@@ -2918,7 +2959,8 @@
kv_cache.size(1), \
block_table.size(1), \
num_tokens, \
- quant_block_size);
+ quant_block_size, \
+ do_preshuffle);

#define CALL_FUSED_QK_ROPE_CONCAT_AND_CACHE_MLA_OPT(KV_T, CACHE_T, QUERY_T, KV_DTYPE, Q_DTYPE, VEC_SIZE) \
aiter::fuse_qk_rope_concat_and_cache_mla_per_head_kernel<KV_T, CACHE_T, QUERY_T, KV_DTYPE, Q_DTYPE, VEC_SIZE> \
@@ -3296,18 +3338,29 @@ void indexer_k_quant_and_cache(torch::Tensor& k, // [num_tokens, head_dim
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size
- const std::string& scale_fmt)
+ const std::string& scale_fmt,
+ bool preshuffle)
{
int num_tokens = k.size(0);
int head_dim = k.size(1);
int cache_block_size = kv_cache.size(1);
int cache_stride = kv_cache.size(2);
bool use_ue8m0 = scale_fmt == "ue8m0";
+ bool do_preshuffle = preshuffle;

TORCH_CHECK(k.device() == kv_cache.device(), "k and kv_cache must be on the same device");
TORCH_CHECK(k.device() == slot_mapping.device(),
"k and slot_mapping must be on the same device");
TORCH_CHECK(head_dim % quant_block_size == 0, "head_dim must be divisible by quant_block_size");
+ if(preshuffle)
+ {
+ TORCH_CHECK(cache_block_size % 16 == 0,
+ "preshuffle requires cache_block_size to be a multiple of 16, got ",
+ cache_block_size);
+ TORCH_CHECK(head_dim % 16 == 0,
+ "preshuffle requires head_dim to be a multiple of 16, got ",
+ head_dim);
+ }

int quant_blocks = num_tokens * head_dim / quant_block_size;
const int vec_size = 16;
@@ -3327,13 +3380,14 @@
torch::Tensor& dst_k, // [num_tokens, head_dim]
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size] float
const torch::Tensor& block_table, // [batch_size, num_blocks]
- const torch::Tensor& cu_seq_lens // [batch_size + 1]
- )
+ const torch::Tensor& cu_seq_lens, // [batch_size + 1]
+ bool preshuffle)
{
int batch_size = block_table.size(0);
int num_tokens = dst_k.size(0);
int head_dim = dst_k.size(1);
int quant_block_size = head_dim / (dst_scale.size(1) * dst_scale.itemsize() / 4);
+ bool do_preshuffle = preshuffle;

TORCH_CHECK(kv_cache.device() == dst_k.device(),
"kv_cache and dst_k must be on the same device");
@@ -3344,6 +3398,16 @@
TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(),
"kv_cache and cu_seq_lens must be on the same device");
TORCH_CHECK(head_dim % quant_block_size == 0, "head_dim must be divisible by quant_block_size");
+ if(preshuffle)
+ {
+ int cache_block_size = kv_cache.size(1);
+ TORCH_CHECK(cache_block_size % 16 == 0,
+ "preshuffle requires cache_block_size to be a multiple of 16, got ",
+ cache_block_size);
+ TORCH_CHECK(head_dim % 16 == 0,
+ "preshuffle requires head_dim to be a multiple of 16, got ",
+ head_dim);
+ }

constexpr int vec_size = 16;
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(kv_cache));