diff --git a/aiter/ops/cache.py b/aiter/ops/cache.py index 788a4203bf..849fddd4a5 100644 --- a/aiter/ops/cache.py +++ b/aiter/ops/cache.py @@ -95,3 +95,23 @@ def concat_and_cache_mla( kv_cache_dtype: str, scale: Tensor, ) -> None: ... + + +@compile_ops("module_cache") +def indexer_k_quant_and_cache( + k: Tensor, + kv_cache: Tensor, + slot_mapping: Tensor, + quant_block_size: int, + scale_fmt: str, +) -> None: ... + + +@compile_ops("module_cache") +def cp_gather_indexer_k_quant_cache( + kv_cache: Tensor, + dst_k: Tensor, + dst_scale: Tensor, + block_table: Tensor, + cu_seq_lens: Tensor, +) -> None: ... diff --git a/csrc/include/cache.h b/csrc/include/cache.h index 7b4a786616..c118b82a1c 100644 --- a/csrc/include/cache.h +++ b/csrc/include/cache.h @@ -65,12 +65,24 @@ void reshape_and_cache_with_block_quant_for_asm_pa( const bool asm_layout, const int ori_block_size = 128); -void concat_and_cache_mla( - torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] - torch::Tensor& k_pe, // [num_tokens, pe_dim] - torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + - // pe_dim)] - torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] - const std::string& kv_cache_dtype, torch::Tensor& scale); +void concat_and_cache_mla(torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] + torch::Tensor& k_pe, // [num_tokens, pe_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + + // pe_dim)] + torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] + const std::string& kv_cache_dtype, + torch::Tensor& scale); +void indexer_k_quant_and_cache(torch::Tensor& k, // [num_tokens, head_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& slot_mapping, // [num_tokens] + int64_t quant_block_size, // quantization block size + const std::string& scale_fmt); + +void cp_gather_indexer_k_quant_cache( + const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& dst_k, // [num_tokens, head_dim] + torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4] + const torch::Tensor& block_table, // [batch_size, num_blocks] + const torch::Tensor& cu_seq_lens); // [batch_size + 1] } // namespace aiter diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index a23e27b354..eb3f7c9449 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -260,7 +260,25 @@ py::arg("kv_cache"), \ py::arg("slot_mapping"), \ py::arg("kv_cache_dtype"), \ - py::arg("scale")); + py::arg("scale")); \ + m.def("indexer_k_quant_and_cache", \ + &aiter::indexer_k_quant_and_cache, \ + "indexer_k_quant_and_cache(Tensor k, Tensor kv_cache," \ + " Tensor slot_mapping," \ + " int64_t quant_block_size," \ + " std::string& scale_fmt) -> ()", \ + py::arg("k"), \ + py::arg("kv_cache"), \ + py::arg("slot_mapping"), \ + py::arg("quant_block_size"), \ + py::arg("scale_fmt")); \ + m.def("cp_gather_indexer_k_quant_cache", \ + &aiter::cp_gather_indexer_k_quant_cache, \ + py::arg("kv_cache"), \ + py::arg("dst_k"), \ + py::arg("dst_scale"), \ + py::arg("block_table"), \ + py::arg("cu_seq_lens")); #define CUSTOM_ALL_REDUCE_PYBIND \ m.def("init_custom_ar", \ diff --git a/csrc/kernels/cache_kernels.cu b/csrc/kernels/cache_kernels.cu index 5283298369..66a73249f2 100644 --- a/csrc/kernels/cache_kernels.cu +++ b/csrc/kernels/cache_kernels.cu @@ -55,7 +55,8 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& bl char* dst_ptr = static_cast(dst.data_ptr()); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); - const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(src_device.is_cuda() ? src_device : dst_device); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard( + src_device.is_cuda() ? src_device : dst_device); const hipStream_t stream = at::hip::getCurrentHIPStream(); // NOTE(woosuk): This can be slow if the number of blocks is large. const int64_t num_blocks = block_mapping.size(0); @@ -975,140 +976,313 @@ __global__ void reshape_and_cache_with_block_quant_kernel_for_asmpa( } template __global__ void concat_and_cache_mla_kernel( - const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] - const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] - cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank - // + pe_dim)] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int block_stride, // - const int entry_stride, // - const int kv_c_stride, // - const int k_pe_stride, // - const int kv_lora_rank, // - const int pe_dim, // - const int block_size, // - const float* scale // -) { - const int64_t token_idx = blockIdx.x; - const int64_t slot_idx = slot_mapping[token_idx]; - // NOTE: slot_idx can be -1 if the token is padded - if (slot_idx < 0) { - return; - } - const int64_t block_idx = slot_idx / block_size; - const int64_t block_offset = slot_idx % block_size; - const float inverted_kscale = 1.0f / *scale; - auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst, - int src_stride, int dst_stride, int size, int offset) { - for (int i = threadIdx.x; i < size; i += blockDim.x) { - const int64_t src_idx = token_idx * src_stride + i; - const int64_t dst_idx = - block_idx * block_stride + block_offset * entry_stride + i + offset; - if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) { - dst[dst_idx] = src[src_idx]; - } else { - dst[dst_idx]= ck_tile::type_convert( - ck_tile::type_convert(src[src_idx]) * inverted_kscale); - } - } - }; - copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); - copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); + const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] + const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank + // + pe_dim)] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, // + const int entry_stride, // + const int kv_c_stride, // + const int k_pe_stride, // + const int kv_lora_rank, // + const int pe_dim, // + const int block_size, // + const float* scale // +) +{ + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if(slot_idx < 0) + { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const float inverted_kscale = 1.0f / *scale; + auto copy = [&](const scalar_t* __restrict__ src, + cache_t* __restrict__ dst, + int src_stride, + int dst_stride, + int size, + int offset) { + for(int i = threadIdx.x; i < size; i += blockDim.x) + { + const int64_t src_idx = token_idx * src_stride + i; + const int64_t dst_idx = + block_idx * block_stride + block_offset * entry_stride + i + offset; + if constexpr(kv_dt == vllm::Fp8KVCacheDataType::kAuto) + { + dst[dst_idx] = src[src_idx]; + } + else + { + dst[dst_idx] = ck_tile::type_convert( + ck_tile::type_convert(src[src_idx]) * inverted_kscale); + } + } + }; + copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); + copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); } template __global__ void concat_and_cache_mla_opt_kernel( - const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] - const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] - cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank - // + pe_dim)] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int block_stride, // - const int entry_stride, // - const int kv_c_stride, // - const int k_pe_stride, // - const int kv_lora_rank, // - const int pe_dim, // - const int block_size, // - const float* scale // -) { - const int64_t token_idx = blockIdx.x; - const int64_t slot_idx = slot_mapping[token_idx]; - // NOTE: slot_idx can be -1 if the token is padded - if (slot_idx < 0) { - return; - } - const int64_t block_idx = slot_idx / block_size; - const int64_t block_offset = slot_idx % block_size; - const float inverted_kscale = 1.0f / *scale; - static constexpr int32_t vec_size_i = std::is_same_v ? 4 : 8; - static constexpr int32_t vec_size_o = vec_size_i; - using vec_i = ck_tile::vec_t; - static constexpr int32_t ooba_i = 4 / sizeof(scalar_t); - static constexpr int32_t ooba_o = 4 / sizeof(cache_t); - auto out_offset = block_idx * block_stride + block_offset * entry_stride; - auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst, - int src_stride, int dst_stride, int size, int offset) { - const int32_t oob_i = (size + ooba_i - 1) / ooba_i * ooba_i; - const int32_t oob_o = (size + ooba_o - 1) / ooba_o * ooba_o; - auto const* ptr_i = reinterpret_cast(src + token_idx * src_stride); - auto* ptr_o = reinterpret_cast(dst + out_offset + offset); - auto buffer_i = ck_tile::make_buffer_view(ptr_i, oob_i); - buffer_i.init_raw(); - auto buffer_o = ck_tile::make_buffer_view(ptr_o, oob_o); - buffer_o.init_raw(); - - // double load core loop start - const int32_t num_vecs = (size + vec_size_i - 1) / vec_size_i; - vec_i vec_nxt; - vec_i vec_cur; - - size_t vec_idx = threadIdx.x; - size_t vec_stride = blockDim.x; - if (vec_idx < num_vecs) - { - vec_cur = buffer_i.template get(vec_idx * vec_size_i, 0, true); - } - for (vec_idx += vec_stride; vec_idx < num_vecs; vec_idx += vec_stride) - { - vec_nxt = buffer_i.template get(vec_idx * vec_size_i, 0, true); - if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) { - buffer_o.template set( - (vec_idx - vec_stride) * vec_size_o, - 0, - true, - vec_cur.template get_as()); - } else { - buffer_o.template set( - (vec_idx - vec_stride) * vec_size_o, - 0, - true, - ck_tile::vec_convert(vec_cur, inverted_kscale) - .template get_as()); + const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] + const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank + // + pe_dim)] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, // + const int entry_stride, // + const int kv_c_stride, // + const int k_pe_stride, // + const int kv_lora_rank, // + const int pe_dim, // + const int block_size, // + const float* scale // +) +{ + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if(slot_idx < 0) + { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const float inverted_kscale = 1.0f / *scale; + static constexpr int32_t vec_size_i = std::is_same_v ? 4 : 8; + static constexpr int32_t vec_size_o = vec_size_i; + using vec_i = ck_tile::vec_t; + static constexpr int32_t ooba_i = 4 / sizeof(scalar_t); + static constexpr int32_t ooba_o = 4 / sizeof(cache_t); + auto out_offset = block_idx * block_stride + block_offset * entry_stride; + auto copy = [&](const scalar_t* __restrict__ src, + cache_t* __restrict__ dst, + int src_stride, + int dst_stride, + int size, + int offset) { + const int32_t oob_i = (size + ooba_i - 1) / ooba_i * ooba_i; + const int32_t oob_o = (size + ooba_o - 1) / ooba_o * ooba_o; + auto const* ptr_i = reinterpret_cast(src + token_idx * src_stride); + auto* ptr_o = reinterpret_cast(dst + out_offset + offset); + auto buffer_i = + ck_tile::make_buffer_view(ptr_i, oob_i); + buffer_i.init_raw(); + auto buffer_o = + ck_tile::make_buffer_view(ptr_o, oob_o); + buffer_o.init_raw(); + + // double load core loop start + const int32_t num_vecs = (size + vec_size_i - 1) / vec_size_i; + vec_i vec_nxt; + vec_i vec_cur; + + size_t vec_idx = threadIdx.x; + size_t vec_stride = blockDim.x; + if(vec_idx < num_vecs) + { + vec_cur = buffer_i.template get(vec_idx * vec_size_i, 0, true); + } + for(vec_idx += vec_stride; vec_idx < num_vecs; vec_idx += vec_stride) + { + vec_nxt = buffer_i.template get(vec_idx * vec_size_i, 0, true); + if constexpr(kv_dt == vllm::Fp8KVCacheDataType::kAuto) + { + buffer_o.template set((vec_idx - vec_stride) * vec_size_o, + 0, + true, + vec_cur.template get_as()); + } + else + { + buffer_o.template set( + (vec_idx - vec_stride) * vec_size_o, + 0, + true, + ck_tile::vec_convert(vec_cur, inverted_kscale) + .template get_as()); + } + vec_cur = vec_nxt; } - vec_cur = vec_nxt; - } - if (vec_idx - vec_stride < num_vecs) - { - if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) { - buffer_o.template set( - (vec_idx - vec_stride) * vec_size_o, - 0, - true, - vec_cur.template get_as()); - } else { - buffer_o.template set( - (vec_idx - vec_stride) * vec_size_o, - 0, - true, - ck_tile::vec_convert(vec_cur, inverted_kscale) - .template get_as()); + if(vec_idx - vec_stride < num_vecs) + { + if constexpr(kv_dt == vllm::Fp8KVCacheDataType::kAuto) + { + buffer_o.template set((vec_idx - vec_stride) * vec_size_o, + 0, + true, + vec_cur.template get_as()); + } + else + { + buffer_o.template set( + (vec_idx - vec_stride) * vec_size_o, + 0, + true, + ck_tile::vec_convert(vec_cur, inverted_kscale) + .template get_as()); + } } + }; + + copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); + copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); +} + +template +__global__ void indexer_k_quant_and_cache_kernel( + const scalar_t* __restrict__ k, // [num_tokens, head_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int num_tokens, + const int head_dim, // dimension of each head + const int quant_block_size, // quantization block size + const int cache_block_size, // cache block size + const int cache_stride, // stride for each token in kv_cache + const bool use_ue8m0 // use ue8m0 scale format +) +{ + const int quant_block_per_head = head_dim / quant_block_size; + const int64_t token_idx = (blockIdx.x * BLOCK_Y_SIZE + threadIdx.y) / quant_block_per_head; + if(token_idx >= num_tokens) + return; + const int64_t slot_idx = slot_mapping[token_idx]; + const int head_dim_idx = + (blockIdx.x * BLOCK_Y_SIZE + threadIdx.y) % quant_block_per_head * quant_block_size + + threadIdx.x * VEC_SIZE; + const int64_t block_idx = slot_idx / cache_block_size; + const int64_t block_offset = slot_idx % cache_block_size; + using vec_i = ck_tile::vec_t; + using vec_o = ck_tile::vec_t; + + // NOTE: slot_idx can be -1 if the token is padded + if(slot_idx < 0 || (head_dim_idx >= head_dim)) + { + return; } - }; - copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0); - copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); + vec_i k_val = + (reinterpret_cast(k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE]; + float amax = 0.0f; + if constexpr(VEC_SIZE % 2 == 0) + { + for(int i = 0; i < VEC_SIZE; i += 2) + { + asm volatile("v_max3_f32 %0, %1, %2, %3\n" + : "=v"(amax) + : "v"(amax), + "v"(fabsf(ck_tile::type_convert(k_val[i]))), + "v"(fabsf(ck_tile::type_convert(k_val[i + 1])))); + } + } + else + { + for(int i = 0; i < VEC_SIZE; i++) + { + amax = fmaxf(amax, fabsf(ck_tile::type_convert(k_val[i]))); + } + } + + // Reduced amax + amax = multithread_reduce(amax, fmaxf, BLOCK_X_SIZE); + + float scale = + fmaxf(amax, 1e-4) / ck_tile::type_convert(ck_tile::numeric::max()); + if(use_ue8m0) + { + scale = exp2f(ceilf(log2f(scale))); + } + + const int64_t dst_offset = + block_idx * cache_block_size * cache_stride + block_offset * head_dim + head_dim_idx; + + // for(int i = 0; i < VEC_SIZE; i++) + // { + // kv_cache[dst_offset + i] = + // ck_tile::type_convert(ck_tile::type_convert(k_val[i]) / scale); + // } + if(threadIdx.x == 0) + { + const int64_t dst_scale_idx = + block_idx * cache_block_size * cache_stride + cache_block_size * head_dim + + (block_offset * head_dim + head_dim_idx) * 4 / quant_block_size; + reinterpret_cast(kv_cache)[dst_scale_idx / 4] = scale; + } + scale = 1.0f / scale; + vec_o* kv_cache_vec = reinterpret_cast(kv_cache + dst_offset); + *kv_cache_vec = ck_tile::vec_convert(k_val, scale); +} + +template +__global__ void cp_gather_indexer_k_quant_cache_kernel( + const char* __restrict__ kv_cache, // [num_blocks, block_size, + // cache_stride] + char* __restrict__ dst_k, // [num_tokens, head_dim] + char* __restrict__ dst_scale, // [num_tokens, head_dim / quant_block_size * + // 4] + const int* __restrict__ block_table, // [batch_size, num_blocks] + const int* __restrict__ cu_seq_lens, // [batch_size + 1] + const int batch_size, // batch size + const int64_t token_stride, // stride for each token in dst_k + const int64_t head_dim, // dimension of each head + const int64_t block_stride, // stride for each block in kv_cache + const int64_t cache_token_stride, // stride for each token in kv_cache + const int64_t cache_block_size, // num_tokens for each block in kv_cache + const int num_blocks, // number of blocks + const int num_tokens, // number of tokens + const int quant_block_size // quantization block size +) +{ + constexpr int VEC_SIZE = sizeof(float4) / sizeof(char); + const int token_idx = blockIdx.x * BLOCK_Y_SIZE + threadIdx.y; + const int head_idx = (blockIdx.y * BLOCK_X_SIZE + threadIdx.x) * VEC_SIZE; + // Find batch index within a block + __shared__ int batch_idx[BLOCK_Y_SIZE]; + for(int iter = 0; iter < (batch_size + BLOCK_X_SIZE - 1) / BLOCK_X_SIZE; iter++) + { + int tid = iter * BLOCK_X_SIZE + threadIdx.x; + if(tid < batch_size) + { + const int seq_start = cu_seq_lens[tid]; + const int seq_end = cu_seq_lens[tid + 1]; + if(token_idx >= seq_start && token_idx < seq_end) + { + batch_idx[threadIdx.y] = tid; + } + } + } + + if(head_idx >= head_dim || token_idx >= num_tokens) + { + return; + } + const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]]; + const int block_idx = + block_table[batch_idx[threadIdx.y] * num_blocks + inbatch_seq_idx / cache_block_size]; + const int64_t src_block_offset = block_idx * block_stride; + const int64_t cache_inblock_offset = (inbatch_seq_idx % cache_block_size) * head_dim + head_idx; + const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset; + const int64_t dst_inblock_offset = token_idx * token_stride + head_idx; + + reinterpret_cast(dst_k)[dst_inblock_offset / VEC_SIZE] = + reinterpret_cast(kv_cache)[src_inblock_offset / VEC_SIZE]; + if(threadIdx.x == 0) + { + const int64_t src_scale_offset = src_block_offset + cache_block_size * head_dim + + cache_inblock_offset * 4 / quant_block_size; + reinterpret_cast(dst_scale)[dst_inblock_offset / quant_block_size] = + reinterpret_cast(kv_cache)[src_scale_offset / 4]; + } } } // namespace aiter @@ -1378,25 +1552,71 @@ void reshape_and_cache_flash( // KV_T is the data type of key and value tensors. // CACHE_T is the stored data type of kv-cache. // KV_DTYPE is the real data type of kv-cache. -#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \ - aiter::concat_and_cache_mla_kernel \ - <<>>( \ - reinterpret_cast(kv_c.data_ptr()), \ - reinterpret_cast(k_pe.data_ptr()), \ - reinterpret_cast(kv_cache.data_ptr()), \ - slot_mapping.data_ptr(), block_stride, entry_stride, \ - kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ - reinterpret_cast(scale.data_ptr())); - -#define CALL_CONCAT_AND_CACHE_MLA_OPT(KV_T, CACHE_T, KV_DTYPE) \ - aiter::concat_and_cache_mla_opt_kernel \ - <<>>( \ - reinterpret_cast(kv_c.data_ptr()), \ - reinterpret_cast(k_pe.data_ptr()), \ - reinterpret_cast(kv_cache.data_ptr()), \ - slot_mapping.data_ptr(), block_stride, entry_stride, \ - kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ - reinterpret_cast(scale.data_ptr())); +#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \ + aiter::concat_and_cache_mla_kernel \ + <<>>(reinterpret_cast(kv_c.data_ptr()), \ + reinterpret_cast(k_pe.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), \ + block_stride, \ + entry_stride, \ + kv_c_stride, \ + k_pe_stride, \ + kv_lora_rank, \ + pe_dim, \ + block_size, \ + reinterpret_cast(scale.data_ptr())); + +#define CALL_CONCAT_AND_CACHE_MLA_OPT(KV_T, CACHE_T, KV_DTYPE) \ + aiter::concat_and_cache_mla_opt_kernel \ + <<>>(reinterpret_cast(kv_c.data_ptr()), \ + reinterpret_cast(k_pe.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), \ + block_stride, \ + entry_stride, \ + kv_c_stride, \ + k_pe_stride, \ + kv_lora_rank, \ + pe_dim, \ + block_size, \ + reinterpret_cast(scale.data_ptr())); + +// Macro to dispatch the kernel based on the data type. +#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + aiter:: \ + indexer_k_quant_and_cache_kernel \ + <<>>(reinterpret_cast(k.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), \ + num_tokens, \ + head_dim, \ + quant_block_size, \ + cache_block_size, \ + cache_stride, \ + use_ue8m0); + +// Macro to dispatch the kernel based on the data amount. +#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \ + aiter::cp_gather_indexer_k_quant_cache_kernel<8, BLOCK_Y_SIZE> \ + <<>>(reinterpret_cast(kv_cache.data_ptr()), \ + reinterpret_cast(dst_k.data_ptr()), \ + reinterpret_cast(dst_scale.data_ptr()), \ + block_table.data_ptr(), \ + cu_seq_lens.data_ptr(), \ + batch_size, \ + dst_k.stride(0), \ + dst_k.size(1), \ + kv_cache.stride(0), \ + kv_cache.stride(1), \ + kv_cache.size(1), \ + block_table.size(1), \ + num_tokens, \ + quant_block_size); namespace aiter { @@ -1652,40 +1872,123 @@ void reshape_and_cache_with_block_quant_for_asm_pa( } } -void concat_and_cache_mla( - torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] - torch::Tensor& k_pe, // [num_tokens, pe_dim] - torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + - // pe_dim)] - torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] - const std::string& kv_cache_dtype, torch::Tensor& scale) { - int num_tokens = slot_mapping.size(0); - int kv_lora_rank = kv_c.size(1); - int pe_dim = k_pe.size(1); - int block_size = kv_cache.size(1); - - TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); - int kv_c_stride = kv_c.stride(0); - int k_pe_stride = k_pe.stride(0); - int block_stride = kv_cache.stride(0); - int entry_stride = kv_cache.stride(1); - const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(kv_c)); - const hipStream_t stream = at::hip::getCurrentHIPStream(); - - if ((pe_dim & 0x7) == 0 && (kv_lora_rank & 0x7) == 0) { - dim3 grid(num_tokens); - dim3 block(std::min(kv_lora_rank, 1024) / 8); - DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, - CALL_CONCAT_AND_CACHE_MLA_OPT); +void concat_and_cache_mla(torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] + torch::Tensor& k_pe, // [num_tokens, pe_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank + + // pe_dim)] + torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] + const std::string& kv_cache_dtype, + torch::Tensor& scale) +{ + int num_tokens = slot_mapping.size(0); + int kv_lora_rank = kv_c.size(1); + int pe_dim = k_pe.size(1); + int block_size = kv_cache.size(1); + + TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); + int kv_c_stride = kv_c.stride(0); + int k_pe_stride = k_pe.stride(0); + int block_stride = kv_cache.stride(0); + int entry_stride = kv_cache.stride(1); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(kv_c)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); - } else { - dim3 grid(num_tokens); - dim3 block(std::min(kv_lora_rank, 512)); - DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, - CALL_CONCAT_AND_CACHE_MLA); - } + if((pe_dim & 0x7) == 0 && (kv_lora_rank & 0x7) == 0) + { + dim3 grid(num_tokens); + dim3 block(std::min(kv_lora_rank, 1024) / 8); + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, CALL_CONCAT_AND_CACHE_MLA_OPT); + } + else + { + dim3 grid(num_tokens); + dim3 block(std::min(kv_lora_rank, 512)); + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, CALL_CONCAT_AND_CACHE_MLA); + } +} + +// copy from vllm: https://github.com/vllm-project/vllm/blob/main/csrc/cache_kernels.cu +void indexer_k_quant_and_cache(torch::Tensor& k, // [num_tokens, head_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& slot_mapping, // [num_tokens] + int64_t quant_block_size, // quantization block size + const std::string& scale_fmt) +{ + int num_tokens = k.size(0); + int head_dim = k.size(1); + int cache_block_size = kv_cache.size(1); + int cache_stride = kv_cache.size(2); + bool use_ue8m0 = scale_fmt == "ue8m0"; + + TORCH_CHECK(k.device() == kv_cache.device(), "k and kv_cache must be on the same device"); + TORCH_CHECK(k.device() == slot_mapping.device(), + "k and slot_mapping must be on the same device"); + TORCH_CHECK(head_dim % quant_block_size == 0, "head_dim must be divisible by quant_block_size"); + + int quant_blocks = num_tokens * head_dim / quant_block_size; + const int vec_size = 16; + const int blockDimx = 8; + const int blockDimy = ck_tile::get_warp_size() / blockDimx; + dim3 grid((quant_blocks + blockDimy - 1) / (blockDimy)); + dim3 block(blockDimx, blockDimy); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(k)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); + DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3", CALL_INDEXER_K_QUANT_AND_CACHE); } +// copy from vllm: https://github.com/vllm-project/vllm/blob/main/csrc/cache_kernels.cu +void cp_gather_indexer_k_quant_cache( + const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& dst_k, // [num_tokens, head_dim] + torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size] float + const torch::Tensor& block_table, // [batch_size, num_blocks] + const torch::Tensor& cu_seq_lens // [batch_size + 1] +) +{ + int batch_size = block_table.size(0); + int num_tokens = dst_k.size(0); + int head_dim = dst_k.size(1); + int quant_block_size = head_dim / (dst_scale.size(1) * dst_scale.itemsize() / 4); + + TORCH_CHECK(kv_cache.device() == dst_k.device(), + "kv_cache and dst_k must be on the same device"); + TORCH_CHECK(kv_cache.device() == dst_scale.device(), + "kv_cache and dst_scale must be on the same device"); + TORCH_CHECK(kv_cache.device() == block_table.device(), + "kv_cache and block_table must be on the same device"); + TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(), + "kv_cache and cu_seq_lens must be on the same device"); + TORCH_CHECK(head_dim % quant_block_size == 0, "head_dim must be divisible by quant_block_size"); + + constexpr int vec_size = 16; + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(kv_cache)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); + + if(num_tokens < 32) + { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(1); + } + else if(num_tokens < 64) + { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(2); + } + else if(num_tokens < 128) + { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(4); + } + else if(num_tokens < 256) + { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(8); + } + else if(num_tokens < 512) + { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(16); + } + else + { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32); + } +} } // namespace aiter diff --git a/op_tests/test_indexer_k_quant_and_cache.py b/op_tests/test_indexer_k_quant_and_cache.py new file mode 100644 index 0000000000..d06530f3ab --- /dev/null +++ b/op_tests/test_indexer_k_quant_and_cache.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +import torch +import aiter +from aiter.test_common import checkAllclose, run_perftest, benchmark +from aiter import dtypes +from aiter import pertoken_quant, dtypes, indexer_k_quant_and_cache +import argparse +import pandas as pd + +MAX_TOKEN_SUPPORTED = 16384 +torch.set_default_device("cuda") + + +def run_torch(k, kv_cache, slot_mapping, quant_block_size, scale_fmt): + num_token, head_dim = k.shape + block_size = kv_cache.shape[1] + per_token_amax, _ = torch.max( + input=torch.abs(k.view(-1, quant_block_size)), dim=-1, keepdim=True + ) + scale = per_token_amax / torch.finfo(dtypes.fp8).max + if scale_fmt == "ue8m0": + scale = torch.pow(2.0, torch.ceil(torch.log2(scale))) + k_fp8, scale = pertoken_quant( + k.view(-1, quant_block_size), quant_dtype=dtypes.fp8, scale=scale + ) + k_fp8 = k_fp8.view(num_token, head_dim) + for i in range(num_token): + slot = slot_mapping[i].item() + blockId = slot // block_size + block_offset = slot % block_size + kv_cache[blockId, block_offset, :head_dim] = k_fp8[i] + kv_cache[blockId, block_offset, head_dim:] = scale[i].view(dtypes.fp8) + + +@benchmark() +def test_indexer_k_quant_and_cache( + num_token, block_size, quant_block_size, head_dim=128 +): + assert ( + num_token <= MAX_TOKEN_SUPPORTED + ), f"test only support max_token={MAX_TOKEN_SUPPORTED}" + block_num = (num_token + block_size - 1) // block_size + k = torch.randn((num_token, head_dim), dtype=dtypes.bf16) + slot_mapping = torch.arange(0, num_token, 1, dtype=torch.int64) + scale_fmt = "ue8m0" + kv_cache = torch.empty((block_num, block_size, head_dim + 4), dtype=dtypes.fp8) + run_torch(k, kv_cache, slot_mapping, quant_block_size, scale_fmt) + kv_cache2 = torch.empty((block_num, block_size, head_dim + 4), dtype=dtypes.fp8) + _, us = run_perftest( + indexer_k_quant_and_cache, + k, + kv_cache2, + slot_mapping, + quant_block_size, + scale_fmt, + ) + err = checkAllclose( + kv_cache.view(-1, head_dim + 4)[:num_token].to(torch.float), + kv_cache2.view(-1, head_dim + 4)[:num_token].to(torch.float), + ) + # scale = kv_cache[:, :, head_dim:].view(torch.float) + # scale2 = kv_cache2[:, :, head_dim:].view(torch.float) + ret = {"aiter us": us, "aiter err": err} + try: + from vllm import _custom_ops as ops + + kv_cache3 = torch.empty((block_num, block_size, head_dim + 4), dtype=dtypes.fp8) + _, us2 = run_perftest( + ops.indexer_k_quant_and_cache, + k, + kv_cache3, + slot_mapping, + quant_block_size, + scale_fmt, + ) + err2 = checkAllclose( + kv_cache.view(-1, head_dim + 4)[:num_token].to(torch.float), + kv_cache3.view(-1, head_dim + 4)[:num_token].to(torch.float), + ) + ret.update({"vllm us": us2, "vllm err": err2}) + except Exception: + # Ignore all exceptions here because vllm._custom_ops is optional and may not be available. + pass + return ret + + +parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + description="Test indexer_k_quant_and_cache.", +) +parser.add_argument( + "-m", + type=int, + nargs="*", + default=[1, 64, 128, 257, 1028, 16384], + help="""token num""", +) +parser.add_argument( + "-b", + "--block_size", + type=int, + nargs="*", + default=[1], + help="""block_size, default: 1""", +) + +args = parser.parse_args() +df = [] +for m in args.m: + for block_size in args.block_size: + ret = test_indexer_k_quant_and_cache(m, block_size, 128, 128) + df.append(ret) +df = pd.DataFrame(df) +aiter.logger.info(f"summary:\n{df}")