Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions aiter/ops/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,23 @@ def concat_and_cache_mla(
kv_cache_dtype: str,
scale: Tensor,
) -> None: ...


# Quantize indexer keys per block of `quant_block_size` elements and write
# them (with their scales) into the paged KV cache at the slots given by
# `slot_mapping`. Stub only: the implementation is supplied by the compiled
# "module_cache" extension via @compile_ops.
#
# Args:
#   k: key tensor, [num_tokens, head_dim] (shape per the C++ declaration).
#   kv_cache: paged cache, [num_blocks, block_size, cache_stride].
#   slot_mapping: destination slot per token, [num_tokens].
#   quant_block_size: number of elements quantized per scale.
#   scale_fmt: scale encoding selector — accepted values are defined by the
#       kernel implementation (NOTE(review): confirm against the .cu source).
@compile_ops("module_cache")
def indexer_k_quant_and_cache(
    k: Tensor,
    kv_cache: Tensor,
    slot_mapping: Tensor,
    quant_block_size: int,
    scale_fmt: str,
) -> None: ...


# Gather quantized indexer keys and their scales out of the paged KV cache
# into contiguous per-token buffers, walking `block_table` and the cumulative
# sequence lengths in `cu_seq_lens`. Stub only: the implementation comes from
# the compiled "module_cache" extension via @compile_ops.
#
# Args (shapes per the C++ declaration):
#   kv_cache: paged cache, [num_blocks, block_size, cache_stride] (read-only).
#   dst_k: output keys, [num_tokens, head_dim].
#   dst_scale: output scales, [num_tokens, head_dim / quant_block_size * 4].
#   block_table: per-batch block indices, [batch_size, num_blocks].
#   cu_seq_lens: cumulative sequence lengths, [batch_size + 1].
@compile_ops("module_cache")
def cp_gather_indexer_k_quant_cache(
    kv_cache: Tensor,
    dst_k: Tensor,
    dst_scale: Tensor,
    block_table: Tensor,
    cu_seq_lens: Tensor,
) -> None: ...
26 changes: 19 additions & 7 deletions csrc/include/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,24 @@ void reshape_and_cache_with_block_quant_for_asm_pa(
const bool asm_layout,
const int ori_block_size = 128);

void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& scale);
// Concatenate the latent part (kv_c) and the positional-embedding part (k_pe)
// of each token's MLA key/value and store the result into the paged kv_cache
// at the positions given by slot_mapping. kv_cache_dtype selects the cache
// element format and scale carries the quantization scale
// (NOTE(review): accepted kv_cache_dtype strings are defined by the kernel
// implementation — confirm against the .cu source).
void concat_and_cache_mla(torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
                          torch::Tensor& k_pe, // [num_tokens, pe_dim]
                          torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
                                                   // pe_dim)]
                          torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
                          const std::string& kv_cache_dtype,
                          torch::Tensor& scale);

// Quantize indexer keys in groups of quant_block_size elements and write the
// quantized values (with their scales) into the paged kv_cache at the slots
// given by slot_mapping. scale_fmt selects the scale encoding
// (NOTE(review): accepted values are defined by the kernel — confirm).
void indexer_k_quant_and_cache(torch::Tensor& k, // [num_tokens, head_dim]
                               torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
                               torch::Tensor& slot_mapping, // [num_tokens]
                               int64_t quant_block_size, // quantization block size
                               const std::string& scale_fmt);

// Gather quantized indexer keys and their scales from the paged kv_cache into
// the contiguous dst_k / dst_scale buffers, using block_table to locate each
// sequence's cache blocks and cu_seq_lens for per-batch token offsets.
// dst_scale holds 4 bytes per quant block per token, as implied by its
// head_dim / quant_block_size * 4 width.
void cp_gather_indexer_k_quant_cache(
    const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
    torch::Tensor& dst_k, // [num_tokens, head_dim]
    torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
    const torch::Tensor& block_table, // [batch_size, num_blocks]
    const torch::Tensor& cu_seq_lens); // [batch_size + 1]
} // namespace aiter
20 changes: 19 additions & 1 deletion csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,25 @@
py::arg("kv_cache"), \
py::arg("slot_mapping"), \
py::arg("kv_cache_dtype"), \
py::arg("scale"));
py::arg("scale")); \
m.def("indexer_k_quant_and_cache", \
&aiter::indexer_k_quant_and_cache, \
"indexer_k_quant_and_cache(Tensor k, Tensor kv_cache," \
" Tensor slot_mapping," \
" int64_t quant_block_size," \
" std::string& scale_fmt) -> ()", \
py::arg("k"), \
py::arg("kv_cache"), \
py::arg("slot_mapping"), \
py::arg("quant_block_size"), \
py::arg("scale_fmt")); \
m.def("cp_gather_indexer_k_quant_cache", \
&aiter::cp_gather_indexer_k_quant_cache, \
py::arg("kv_cache"), \
py::arg("dst_k"), \
py::arg("dst_scale"), \
py::arg("block_table"), \
py::arg("cu_seq_lens"));

#define CUSTOM_ALL_REDUCE_PYBIND \
m.def("init_custom_ar", \
Expand Down
Loading