diff --git a/CMakeLists.txt b/CMakeLists.txt index 3db7ff0bbda2..abb458a05e3a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -301,6 +301,8 @@ set(VLLM_EXT_SRC "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/activation_kernels.cu" + "csrc/quantization/turboquant/polarquant_kernels.cu" + "csrc/quantization/turboquant/turboquant_attention_kernels.cu" "csrc/cuda_utils_kernels.cu" "csrc/custom_all_reduce.cu" "csrc/torch_bindings.cpp") diff --git a/csrc/cache.h b/csrc/cache.h index 0188a568edc7..0f34ca2a8ed3 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -74,6 +74,35 @@ void indexer_k_quant_and_cache( int64_t quant_block_size, // quantization block size const std::string& scale_fmt); +// TurboQuant KV cache quantization (PolarQuant + QJL) +void reshape_and_cache_turboquant( + torch::Tensor key, torch::Tensor value, torch::Tensor key_cache, + torch::Tensor value_cache, torch::Tensor slot_mapping, int64_t num_kv_heads, + int64_t head_size, int64_t block_size, const std::string& tq_type, + int64_t layer_seed, int64_t qjl_proj_dim); + +void turboquant_encode(torch::Tensor kv_data, torch::Tensor angles_out, + torch::Tensor radii_out, torch::Tensor qjl_out, + torch::Tensor residual_norms_out, + int64_t num_kv_heads, int64_t head_size, + const std::string& tq_type, int64_t layer_seed, + int64_t qjl_proj_dim); + +void turboquant_decode(torch::Tensor angles, torch::Tensor radii, + torch::Tensor qjl_bits, torch::Tensor residual_norms, + torch::Tensor kv_out, + int64_t num_kv_heads, int64_t head_size, + const std::string& tq_type, int64_t layer_seed, + int64_t qjl_proj_dim); + +void paged_attention_turboquant( + torch::Tensor output, torch::Tensor query, torch::Tensor key_cache, + torch::Tensor value_cache, torch::Tensor block_tables, + torch::Tensor context_lens, double scale, int64_t num_heads, + int64_t num_kv_heads, int64_t head_size, int64_t block_size, + int64_t max_blocks_per_seq, 
const std::string& tq_type, + int64_t layer_seed, int64_t qjl_proj_dim); + // Concatenate query nope and rope for MLA/DSA attention void concat_mla_q( torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim] diff --git a/csrc/quantization/turboquant/polarquant_kernels.cu b/csrc/quantization/turboquant/polarquant_kernels.cu new file mode 100644 index 000000000000..9a28cce8967b --- /dev/null +++ b/csrc/quantization/turboquant/polarquant_kernels.cu @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project +// +// PolarQuant + QJL encode/decode CUDA kernels for TurboQuant KV cache +// quantization. +// +// Reference: Zandieh et al., "TurboQuant: Redefining AI Efficiency with +// Extreme Compression", ICLR 2026 (arXiv:2504.19874) + +#include +#include +#include + +#include "turboquant_utils.cuh" + +using namespace vllm::turboquant; + +// ============================================================================ +// Encode kernel: KV vector → packed TurboQuant representation +// ============================================================================ +// +// Each thread block handles one (token, head) pair. 
+// Grid: (num_tokens, num_kv_heads)
+//
+// Only thread 0 of each block does any work. All per-(token, head) offsets are
+// derived from a 64-bit flat index so they cannot wrap for large inputs.
+
+template <TQDataType DT>
+__global__ void turboquant_encode_kernel(
+    const float* __restrict__ kv_data,  // [num_tokens, num_kv_heads, head_size]
+    uint8_t* __restrict__ angles_out,   // packed angle storage
+    half* __restrict__ radii_out,       // [num_tokens, num_kv_heads]
+    uint8_t* __restrict__ qjl_out,  // packed QJL sign bits (nullptr if PQ4)
+    half* __restrict__ residual_norms_out,  // [num_tokens, num_kv_heads] (QJL)
+    int num_kv_heads, int head_size, uint32_t layer_seed, int qjl_proj_dim) {
+  // Single-thread encode (thread 0 of each block)
+  // TODO: Parallelize across threads within a block for larger head sizes
+  if (threadIdx.x != 0) return;
+
+  const int head_idx = blockIdx.y;
+  // Flat (token, head) index, widened before any multiplication.
+  const int64_t head_offset =
+      static_cast<int64_t>(blockIdx.x) * num_kv_heads + head_idx;
+
+  // Per-head sizes of the packed outputs
+  constexpr int BITS = angle_bits(DT);
+  const int num_angles = head_size - 1;
+  const int angle_bytes_per_head = (num_angles * BITS + 7) / 8;
+  const int qjl_bytes_per_head = has_qjl(DT) ? (qjl_proj_dim + 7) / 8 : 0;
+
+  // Input vector and output slots of this head
+  const float* vec = kv_data + head_offset * head_size;
+  uint8_t* angles_ptr = angles_out + head_offset * angle_bytes_per_head;
+  half* radius_ptr = radii_out + head_offset;
+  uint8_t* qjl_ptr =
+      qjl_out ? (qjl_out + head_offset * qjl_bytes_per_head) : nullptr;
+  half* residual_norm_ptr =
+      residual_norms_out ? (residual_norms_out + head_offset) : nullptr;
+
+  // Seeds depend on the layer and the head only, so every token of a head
+  // uses the same rotation and the same QJL projections.
+  const uint32_t rotation_seed = derive_rotation_seed(layer_seed, head_idx);
+  const uint32_t qjl_seed = derive_qjl_seed(layer_seed, head_idx);
+
+  turboquant_encode_head<DT>(vec, head_size, rotation_seed, qjl_seed,
+                             angles_ptr, radius_ptr, qjl_ptr,
+                             residual_norm_ptr, qjl_proj_dim);
+}
+
+// ============================================================================
+// Decode kernel: packed TurboQuant → reconstructed KV vector
+// ============================================================================
+//
+// Mirror image of the encode kernel: same grid, same per-head offsets, and
+// again only thread 0 of each block does any work.
+
+template <TQDataType DT>
+__global__ void turboquant_decode_kernel(
+    const uint8_t* __restrict__ angles,
+    const half* __restrict__ radii,
+    const uint8_t* __restrict__ qjl_bits,
+    const half* __restrict__ residual_norms,  // [num_tokens, num_kv_heads]
+    float* __restrict__ kv_out,  // [num_tokens, num_kv_heads, head_size]
+    int num_kv_heads, int head_size, uint32_t layer_seed, int qjl_proj_dim) {
+  if (threadIdx.x != 0) return;
+
+  const int head_idx = blockIdx.y;
+  const int64_t head_offset =
+      static_cast<int64_t>(blockIdx.x) * num_kv_heads + head_idx;
+
+  constexpr int BITS = angle_bits(DT);
+  const int num_angles = head_size - 1;
+  const int angle_bytes_per_head = (num_angles * BITS + 7) / 8;
+  const int qjl_bytes_per_head = has_qjl(DT) ? (qjl_proj_dim + 7) / 8 : 0;
+
+  const uint8_t* angles_ptr = angles + head_offset * angle_bytes_per_head;
+  const half radius = radii[head_offset];
+  const uint8_t* qjl_ptr =
+      qjl_bits ? (qjl_bits + head_offset * qjl_bytes_per_head) : nullptr;
+  // A missing norm tensor (PQ4) is passed on as a zero norm.
+  const half residual_norm =
+      residual_norms ? residual_norms[head_offset] : __float2half(0.0f);
+
+  float* out_ptr = kv_out + head_offset * head_size;
+
+  const uint32_t rotation_seed = derive_rotation_seed(layer_seed, head_idx);
+  const uint32_t qjl_seed = derive_qjl_seed(layer_seed, head_idx);
+
+  turboquant_decode_head<DT>(angles_ptr, radius, qjl_ptr, residual_norm,
+                             head_size, rotation_seed, qjl_seed, qjl_proj_dim,
+                             out_ptr);
+}
+
+// ============================================================================
+// Reshape-and-cache kernel for TurboQuant
+// ============================================================================
+//
+// This integrates with vLLM's PagedAttention block layout.
+// Instead of storing raw KV values, we store the polar-encoded representation.
+//
+// Layout actually produced by the addressing code below (byte offsets):
+//   cache block  = block_size token slots, bytes_per_token bytes each
+//   token slot   = num_kv_heads head records, bpt_per_head bytes each
+//   head record  = [angles (padded_angle_bytes) | radius (fp16)
+//                   | QJL sign bits | residual norm (fp16)]
+//   (the last two fields exist only in the QJL modes)
+//
+// Grid: (num_tokens, num_kv_heads, 2); blockIdx.z selects key (0) or value (1).
+
+template <typename scalar_t, TQDataType DT>
+__global__ void reshape_and_cache_turboquant_kernel(
+    const scalar_t* __restrict__ key,    // [num_tokens, num_kv_heads,
+                                         // head_size]
+    const scalar_t* __restrict__ value,  // [num_tokens, num_kv_heads,
+                                         // head_size]
+    uint8_t* __restrict__ key_cache,     // [num_blocks, block_bytes_k]
+    uint8_t* __restrict__ value_cache,   // [num_blocks, block_bytes_v]
+    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
+    int num_kv_heads, int head_size, int block_size, uint32_t layer_seed,
+    int qjl_proj_dim) {
+  if (threadIdx.x != 0) return;
+
+  const int token_idx = blockIdx.x;
+  const int head_idx = blockIdx.y;
+  const bool is_key = (blockIdx.z == 0);
+
+  const int64_t slot = slot_mapping[token_idx];
+  if (slot < 0) return;  // Padding token
+
+  // Keep the block index 64-bit: it is multiplied by the block byte size.
+  const int64_t block_idx = slot / block_size;
+  const int block_offset = static_cast<int>(slot % block_size);
+
+  // Select key or value input and the matching cache
+  const scalar_t* kv_data = is_key ? key : value;
+  uint8_t* cache = is_key ? key_cache : value_cache;
+
+  const scalar_t* vec =
+      kv_data +
+      (static_cast<int64_t>(token_idx) * num_kv_heads + head_idx) * head_size;
+
+  // Convert input to float. The host wrapper guarantees
+  // head_size <= MAX_HEAD_SIZE, so this buffer cannot overflow.
+  float vec_f[MAX_HEAD_SIZE];
+  for (int i = 0; i < head_size; i++) {
+    vec_f[i] = static_cast<float>(vec[i]);
+  }
+
+  // Storage layout of one head record
+  constexpr int BITS = angle_bits(DT);
+  const int angle_bytes = padded_angle_bytes(head_size, BITS);
+  const int radius_bytes = 2;  // fp16
+  const int qjl_bytes = has_qjl(DT) ? (qjl_proj_dim + 7) / 8 : 0;
+  const int residual_norm_bytes = has_qjl(DT) ? 2 : 0;
+
+  const int bpt_per_head =
+      angle_bytes + radius_bytes + qjl_bytes + residual_norm_bytes;
+  const int bytes_per_token = num_kv_heads * bpt_per_head;
+  const int64_t block_total_bytes =
+      static_cast<int64_t>(block_size) * bytes_per_token;
+
+  // Offset within the block for this (token_in_block, head)
+  const int token_head_offset =
+      block_offset * bytes_per_token + head_idx * bpt_per_head;
+
+  uint8_t* angles_ptr = cache + block_idx * block_total_bytes + token_head_offset;
+  // NOTE(review): both fp16 fields are accessed through half*, which needs
+  // 2-byte alignment. That holds only if padded_angle_bytes() returns an even
+  // count (not visible here — confirm) and qjl_bytes is even (checked by the
+  // host wrapper).
+  half* radius_ptr = reinterpret_cast<half*>(angles_ptr + angle_bytes);
+  uint8_t* qjl_ptr =
+      has_qjl(DT) ? (angles_ptr + angle_bytes + radius_bytes) : nullptr;
+  half* residual_norm_ptr =
+      has_qjl(DT) ? reinterpret_cast<half*>(angles_ptr + angle_bytes +
+                                            radius_bytes + qjl_bytes)
+                  : nullptr;
+
+  const uint32_t rotation_seed = derive_rotation_seed(layer_seed, head_idx);
+  const uint32_t qjl_seed = derive_qjl_seed(layer_seed, head_idx);
+
+  turboquant_encode_head<DT>(vec_f, head_size, rotation_seed, qjl_seed,
+                             angles_ptr, radius_ptr, qjl_ptr,
+                             residual_norm_ptr, qjl_proj_dim);
+}
+
+// ============================================================================
+// Host-callable wrappers
+// ============================================================================
+
+namespace {
+
+// Raises unless `tq_type` names one of the supported modes and the shape
+// parameters are ones the device code can handle. Without this an unknown
+// mode launched nothing and an oversized head overran the kernels' stack
+// buffers.
+void check_turboquant_args(const std::string& tq_type, int64_t num_kv_heads,
+                           int64_t head_size, int64_t qjl_proj_dim) {
+  TORCH_CHECK(tq_type == "pq4" || tq_type == "tq3" || tq_type == "tq2",
+              "Unsupported TurboQuant type '", tq_type,
+              "' (expected 'pq4', 'tq3' or 'tq2')");
+  TORCH_CHECK(num_kv_heads > 0, "num_kv_heads must be positive, got ",
+              num_kv_heads);
+  TORCH_CHECK(head_size >= 2 && head_size <= MAX_HEAD_SIZE &&
+                  (head_size & (head_size - 1)) == 0,
+              "TurboQuant requires head_size to be a power of 2 in [2, ",
+              MAX_HEAD_SIZE, "], got ", head_size);
+  // The QJL decode scale divides by the projection count.
+  TORCH_CHECK(tq_type == "pq4" || qjl_proj_dim > 0,
+              "qjl_proj_dim must be positive for '", tq_type, "', got ",
+              qjl_proj_dim);
+}
+
+}  // namespace
+
+// Encodes kv_data [num_tokens, num_kv_heads, head_size] (fp32) into the flat
+// angle / radius / QJL tensors. `qjl_out` and `residual_norms_out` are not
+// touched in "pq4" mode. Raises on an unknown `tq_type`, an unsupported
+// head_size, or a non-fp16 radius/norm tensor.
+void turboquant_encode(torch::Tensor kv_data, torch::Tensor angles_out,
+                       torch::Tensor radii_out, torch::Tensor qjl_out,
+                       torch::Tensor residual_norms_out,
+                       int64_t num_kv_heads, int64_t head_size,
+                       const std::string& tq_type, int64_t layer_seed,
+                       int64_t qjl_proj_dim) {
+  check_turboquant_args(tq_type, num_kv_heads, head_size, qjl_proj_dim);
+  TORCH_CHECK(radii_out.scalar_type() == at::ScalarType::Half,
+              "radii_out must be a float16 tensor");
+
+  const int64_t num_tokens = kv_data.size(0);
+  if (num_tokens == 0) return;  // a zero-sized grid is a launch error
+
+  const dim3 grid(num_tokens, num_kv_heads);
+  const dim3 block(1);  // Single thread per (token, head) for now
+  const auto stream = at::cuda::getCurrentCUDAStream();
+
+  const float* in_ptr = kv_data.data_ptr<float>();
+  uint8_t* angles_ptr = angles_out.data_ptr<uint8_t>();
+  half* radii_ptr = reinterpret_cast<half*>(radii_out.data_ptr());
+  const int kv_heads = static_cast<int>(num_kv_heads);
+  const int hs = static_cast<int>(head_size);
+  const uint32_t seed = static_cast<uint32_t>(layer_seed);
+  const int proj_dim = static_cast<int>(qjl_proj_dim);
+
+  if (tq_type == "pq4") {
+    turboquant_encode_kernel<TQDataType::kPQ4><<<grid, block, 0, stream>>>(
+        in_ptr, angles_ptr, radii_ptr, nullptr, nullptr, kv_heads, hs, seed,
+        proj_dim);
+    return;
+  }
+
+  TORCH_CHECK(residual_norms_out.scalar_type() == at::ScalarType::Half,
+              "residual_norms_out must be a float16 tensor");
+  uint8_t* qjl_ptr = qjl_out.data_ptr<uint8_t>();
+  half* norms_ptr = reinterpret_cast<half*>(residual_norms_out.data_ptr());
+  if (tq_type == "tq3") {
+    turboquant_encode_kernel<TQDataType::kTQ3><<<grid, block, 0, stream>>>(
+        in_ptr, angles_ptr, radii_ptr, qjl_ptr, norms_ptr, kv_heads, hs, seed,
+        proj_dim);
+  } else {  // "tq2" (validated above)
+    turboquant_encode_kernel<TQDataType::kTQ2><<<grid, block, 0, stream>>>(
+        in_ptr, angles_ptr, radii_ptr, qjl_ptr, norms_ptr, kv_heads, hs, seed,
+        proj_dim);
+  }
+}
+
+// Inverse of turboquant_encode: reconstructs kv_out
+// [num_tokens, num_kv_heads, head_size] (fp32) from the flat tensors.
+// `qjl_bits` and `residual_norms` are not touched in "pq4" mode.
+void turboquant_decode(torch::Tensor angles, torch::Tensor radii,
+                       torch::Tensor qjl_bits, torch::Tensor residual_norms,
+                       torch::Tensor kv_out,
+                       int64_t num_kv_heads, int64_t head_size,
+                       const std::string& tq_type, int64_t layer_seed,
+                       int64_t qjl_proj_dim) {
+  check_turboquant_args(tq_type, num_kv_heads, head_size, qjl_proj_dim);
+  TORCH_CHECK(radii.scalar_type() == at::ScalarType::Half,
+              "radii must be a float16 tensor");
+
+  const int64_t num_tokens = kv_out.size(0);
+  if (num_tokens == 0) return;
+
+  const dim3 grid(num_tokens, num_kv_heads);
+  const dim3 block(1);
+  const auto stream = at::cuda::getCurrentCUDAStream();
+
+  const uint8_t* angles_ptr = angles.data_ptr<uint8_t>();
+  const half* radii_ptr = reinterpret_cast<const half*>(radii.data_ptr());
+  float* out_ptr = kv_out.data_ptr<float>();
+  const int kv_heads = static_cast<int>(num_kv_heads);
+  const int hs = static_cast<int>(head_size);
+  const uint32_t seed = static_cast<uint32_t>(layer_seed);
+  const int proj_dim = static_cast<int>(qjl_proj_dim);
+
+  if (tq_type == "pq4") {
+    turboquant_decode_kernel<TQDataType::kPQ4><<<grid, block, 0, stream>>>(
+        angles_ptr, radii_ptr, nullptr, nullptr, out_ptr, kv_heads, hs, seed,
+        proj_dim);
+    return;
+  }
+
+  TORCH_CHECK(residual_norms.scalar_type() == at::ScalarType::Half,
+              "residual_norms must be a float16 tensor");
+  const uint8_t* qjl_ptr = qjl_bits.data_ptr<uint8_t>();
+  const half* norms_ptr =
+      reinterpret_cast<const half*>(residual_norms.data_ptr());
+  if (tq_type == "tq3") {
+    turboquant_decode_kernel<TQDataType::kTQ3><<<grid, block, 0, stream>>>(
+        angles_ptr, radii_ptr, qjl_ptr, norms_ptr, out_ptr, kv_heads, hs, seed,
+        proj_dim);
+  } else {  // "tq2" (validated above)
+    turboquant_decode_kernel<TQDataType::kTQ2><<<grid, block, 0, stream>>>(
+        angles_ptr, radii_ptr, qjl_ptr, norms_ptr, out_ptr, kv_heads, hs, seed,
+        proj_dim);
+  }
+}
+
+// Encodes key/value [num_tokens, num_kv_heads, head_size] and writes each
+// head record into the paged cache slot named by slot_mapping (negative slots
+// are skipped). The input dtype is dispatched on key.scalar_type().
+void reshape_and_cache_turboquant(
+    torch::Tensor key, torch::Tensor value, torch::Tensor key_cache,
+    torch::Tensor value_cache, torch::Tensor slot_mapping, int64_t num_kv_heads,
+    int64_t head_size, int64_t block_size, const std::string& tq_type,
+    int64_t layer_seed, int64_t qjl_proj_dim) {
+  check_turboquant_args(tq_type, num_kv_heads, head_size, qjl_proj_dim);
+  TORCH_CHECK(block_size > 0, "block_size must be positive, got ", block_size);
+  // The residual norm sits right after the QJL bytes and is stored through a
+  // half*, so an odd QJL byte count would misalign it.
+  TORCH_CHECK(tq_type == "pq4" || ((qjl_proj_dim + 7) / 8) % 2 == 0,
+              "qjl_proj_dim must pack into an even number of bytes, got ",
+              qjl_proj_dim);
+
+  const int64_t num_tokens = key.size(0);
+  if (num_tokens == 0) return;
+
+  // Grid: (num_tokens, num_kv_heads, 2) for key and value
+  const dim3 grid(num_tokens, num_kv_heads, 2);
+  const dim3 block(1);
+  const auto stream = at::cuda::getCurrentCUDAStream();
+
+  uint8_t* key_cache_ptr = key_cache.data_ptr<uint8_t>();
+  uint8_t* value_cache_ptr = value_cache.data_ptr<uint8_t>();
+  const int64_t* slots_ptr = slot_mapping.data_ptr<int64_t>();
+  const int kv_heads = static_cast<int>(num_kv_heads);
+  const int hs = static_cast<int>(head_size);
+  const int bs = static_cast<int>(block_size);
+  const uint32_t seed = static_cast<uint32_t>(layer_seed);
+  const int proj_dim = static_cast<int>(qjl_proj_dim);
+
+  // Dispatch by input dtype and TQ type
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16, key.scalar_type(),
+      "reshape_and_cache_turboquant", [&] {
+        const scalar_t* key_ptr = key.data_ptr<scalar_t>();
+        const scalar_t* value_ptr = value.data_ptr<scalar_t>();
+        if (tq_type == "pq4") {
+          reshape_and_cache_turboquant_kernel<scalar_t, TQDataType::kPQ4>
+              <<<grid, block, 0, stream>>>(key_ptr, value_ptr, key_cache_ptr,
+                                           value_cache_ptr, slots_ptr,
+                                           kv_heads, hs, bs, seed, proj_dim);
+        } else if (tq_type == "tq3") {
+          reshape_and_cache_turboquant_kernel<scalar_t, TQDataType::kTQ3>
+              <<<grid, block, 0, stream>>>(key_ptr, value_ptr, key_cache_ptr,
+                                           value_cache_ptr, slots_ptr,
+                                           kv_heads, hs, bs, seed, proj_dim);
+        } else {  // "tq2" (validated above)
+          reshape_and_cache_turboquant_kernel<scalar_t, TQDataType::kTQ2>
+              <<<grid, block, 0, stream>>>(key_ptr, value_ptr, key_cache_ptr,
+                                           value_cache_ptr, slots_ptr,
+                                           kv_heads, hs, bs, seed, proj_dim);
+        }
+      });
+}
diff --git a/csrc/quantization/turboquant/turboquant_attention_kernels.cu b/csrc/quantization/turboquant/turboquant_attention_kernels.cu
new file mode 100644
index 000000000000..66f7e0749777
--- /dev/null
+++ b/csrc/quantization/turboquant/turboquant_attention_kernels.cu
@@ -0,0 +1,230 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+//
+// Fused PagedAttention kernels for TurboQuant KV cache.
+//
+// These kernels extend vLLM's PagedAttention v2 to dequantize KV cache entries
+// on-the-fly during the dot-product computation. The polar decode and optional
+// QJL bias correction are performed in registers/shared memory — full-precision
+// KV tensors are never materialized in HBM.
+//
+// Reference: Zandieh et al., "TurboQuant: Redefining AI Efficiency with
+//            Extreme Compression", ICLR 2026 (arXiv:2504.19874)
+
+#include
+#include
+#include
+
+#include "turboquant_utils.cuh"
+
+using namespace vllm::turboquant;
+
+// ============================================================================
+// Fused Paged Attention with TurboQuant dequantization
+// ============================================================================
+//
+// This kernel performs single-query attention (one query per sequence) with
+// on-the-fly dequantization of the KV cache from TurboQuant format.
+//
+// Layout assumptions:
+//   query:        [num_seqs, num_heads, head_size]
+//   key_cache:    [num_blocks, block_bytes_k] (packed TurboQuant)
+//   value_cache:  [num_blocks, block_bytes_v] (packed TurboQuant)
+//   block_tables: [num_seqs, max_num_blocks_per_seq]
+//   context_lens: [num_seqs]
+//
+// Each (sequence, query head) pair is handled by one thread. 
We iterate over KV blocks, dequantize
+// each KV token, compute QK dot products, and accumulate the softmax-weighted
+// value output.
+//
+// Grid: (num_seqs, num_heads). Only thread 0 of each block does any work.
+// Cache byte offsets are computed in 64 bits so they cannot wrap for large
+// caches. Sequences with context_len <= 0 leave their output row untouched.
+
+template <TQDataType DT>
+__global__ void paged_attention_turboquant_kernel(
+    float* __restrict__ output,             // [num_seqs, num_heads, head_size]
+    const float* __restrict__ query,        // [num_seqs, num_heads, head_size]
+    const uint8_t* __restrict__ key_cache,  // [num_blocks, block_bytes]
+    const uint8_t* __restrict__ value_cache,
+    const int* __restrict__ block_tables,  // [num_seqs, max_blocks_per_seq]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    float scale,                           // attention scale (1/sqrt(d))
+    int num_heads, int num_kv_heads, int head_size, int block_size,
+    int max_blocks_per_seq, uint32_t layer_seed, int qjl_proj_dim) {
+  if (threadIdx.x != 0) return;
+
+  const int seq_idx = blockIdx.x;
+  const int head_idx = blockIdx.y;
+
+  // GQA: map query head to KV head (the host wrapper guarantees that
+  // num_heads is a positive multiple of num_kv_heads)
+  const int kv_head_idx = head_idx / (num_heads / num_kv_heads);
+
+  // A non-positive length would otherwise reach the final 1 / sum_exp with
+  // sum_exp == 0 and write NaNs.
+  const int context_len = context_lens[seq_idx];
+  if (context_len <= 0) return;
+
+  const int num_blocks = (context_len + block_size - 1) / block_size;
+
+  // Load the query vector and clear the output accumulator
+  const int64_t row_offset =
+      (static_cast<int64_t>(seq_idx) * num_heads + head_idx) * head_size;
+  float q[MAX_HEAD_SIZE];
+  float acc[MAX_HEAD_SIZE];
+  for (int i = 0; i < head_size; i++) {
+    q[i] = query[row_offset + i];
+    acc[i] = 0.0f;
+  }
+
+  // Block layout sizes (using padded alignment from turboquant_utils.cuh)
+  constexpr int BITS = angle_bits(DT);
+  const int angle_bytes = padded_angle_bytes(head_size, BITS);
+  const int radius_bytes = 2;
+  const int qjl_bytes = has_qjl(DT) ? (qjl_proj_dim + 7) / 8 : 0;
+  const int residual_norm_bytes = has_qjl(DT) ? 2 : 0;
+  const int bpt_per_head =
+      angle_bytes + radius_bytes + qjl_bytes + residual_norm_bytes;
+  const int bytes_per_token = num_kv_heads * bpt_per_head;
+  const int64_t block_total_bytes =
+      static_cast<int64_t>(block_size) * bytes_per_token;
+
+  // Per-head seeds (the same pair is used for keys and values)
+  const uint32_t rotation_seed = derive_rotation_seed(layer_seed, kv_head_idx);
+  const uint32_t qjl_seed = derive_qjl_seed(layer_seed, kv_head_idx);
+
+  // Decodes the head record starting at `record` into `out`.
+  // Record layout: [angles | radius (fp16) | QJL bits | residual norm (fp16)].
+  auto decode_record = [&](const uint8_t* record, float* out) {
+    const half radius = *reinterpret_cast<const half*>(record + angle_bytes);
+    const uint8_t* qjl_ptr =
+        has_qjl(DT) ? (record + angle_bytes + radius_bytes) : nullptr;
+    const half residual_norm =
+        has_qjl(DT) ? *reinterpret_cast<const half*>(
+                          record + angle_bytes + radius_bytes + qjl_bytes)
+                    : __float2half(0.0f);
+    turboquant_decode_head<DT>(record, radius, qjl_ptr, residual_norm,
+                               head_size, rotation_seed, qjl_seed,
+                               qjl_proj_dim, out);
+  };
+
+  // Online-softmax state
+  float max_logit = -1e20f;
+  float sum_exp = 0.0f;
+
+  // One scratch vector, used first for the key and then for the value
+  float kv_vec[MAX_HEAD_SIZE];
+
+  const int* block_table =
+      block_tables + static_cast<int64_t>(seq_idx) * max_blocks_per_seq;
+
+  for (int block_i = 0; block_i < num_blocks; block_i++) {
+    const int64_t block_base =
+        static_cast<int64_t>(block_table[block_i]) * block_total_bytes;
+    // Only the last block of a sequence can be partially filled.
+    const int tokens_in_block =
+        min(block_size, context_len - block_i * block_size);
+
+    for (int token_i = 0; token_i < tokens_in_block; token_i++) {
+      const int64_t record_offset = block_base +
+                                    token_i * bytes_per_token +
+                                    kv_head_idx * bpt_per_head;
+
+      // --- Decode key and compute the scaled QK dot product ---
+      decode_record(key_cache + record_offset, kv_vec);
+      float logit = 0.0f;
+      for (int i = 0; i < head_size; i++) {
+        logit += q[i] * kv_vec[i];
+      }
+      logit *= scale;
+
+      // --- Online softmax update ---
+      // `correction` rescales everything accumulated under the old maximum.
+      const float new_max = fmaxf(max_logit, logit);
+      const float correction = expf(max_logit - new_max);
+      const float weight = expf(logit - new_max);
+      sum_exp = sum_exp * correction + weight;
+      max_logit = new_max;
+
+      // --- Decode value and accumulate ---
+      decode_record(value_cache + record_offset, kv_vec);
+      for (int i = 0; i < head_size; i++) {
+        acc[i] = acc[i] * correction + weight * kv_vec[i];
+      }
+    }
+  }
+
+  // Normalize by the softmax denominator (>= 1 term, so sum_exp > 0)
+  const float inv_sum = 1.0f / sum_exp;
+  for (int i = 0; i < head_size; i++) {
+    output[row_offset + i] = acc[i] * inv_sum;
+  }
+}
+
+// ============================================================================
+// Host-callable wrapper
+// ============================================================================
+
+// Launches the fused attention kernel for the mode named by `tq_type`
+// ("pq4", "tq3" or "tq2"). Raises on an unknown mode, an unsupported
+// head_size, or a head count that is not a multiple of num_kv_heads.
+void paged_attention_turboquant(
+    torch::Tensor output, torch::Tensor query, torch::Tensor key_cache,
+    torch::Tensor value_cache, torch::Tensor block_tables,
+    torch::Tensor context_lens, double scale, int64_t num_heads,
+    int64_t num_kv_heads, int64_t head_size, int64_t block_size,
+    int64_t max_blocks_per_seq, const std::string& tq_type,
+    int64_t layer_seed, int64_t qjl_proj_dim) {
+  const bool is_pq4 = (tq_type == "pq4");
+  TORCH_CHECK(is_pq4 || tq_type == "tq3" || tq_type == "tq2",
+              "Unsupported TurboQuant type '", tq_type,
+              "' (expected 'pq4', 'tq3' or 'tq2')");
+  // The kernel divides by (num_heads / num_kv_heads).
+  TORCH_CHECK(num_kv_heads > 0 && num_heads >= num_kv_heads &&
+                  num_heads % num_kv_heads == 0,
+              "num_heads (", num_heads,
+              ") must be a positive multiple of num_kv_heads (", num_kv_heads,
+              ")");
+  // The kernel's stack buffers hold MAX_HEAD_SIZE floats.
+  TORCH_CHECK(head_size >= 2 && head_size <= MAX_HEAD_SIZE &&
+                  (head_size & (head_size - 1)) == 0,
+              "TurboQuant requires head_size to be a power of 2 in [2, ",
+              MAX_HEAD_SIZE, "], got ", head_size);
+  TORCH_CHECK(block_size > 0, "block_size must be positive, got ", block_size);
+  // The residual norm is read through a half* right after the QJL bytes.
+  TORCH_CHECK(is_pq4 ||
+                  (qjl_proj_dim > 0 && ((qjl_proj_dim + 7) / 8) % 2 == 0),
+              "qjl_proj_dim must be positive and pack into an even number of "
+              "bytes, got ",
+              qjl_proj_dim);
+
+  const int64_t num_seqs = query.size(0);
+  if (num_seqs == 0) return;  // a zero-sized grid is a launch error
+
+  const dim3 grid(num_seqs, num_heads);
+  const dim3 block(1);  // Single thread per (seq, head) — to be optimized with
+                        // warp-level parallelism in future PRs
+  const auto stream = at::cuda::getCurrentCUDAStream();
+
+  float* out_ptr = output.data_ptr<float>();
+  const float* query_ptr = query.data_ptr<float>();
+  const uint8_t* key_cache_ptr = key_cache.data_ptr<uint8_t>();
+  const uint8_t* value_cache_ptr = value_cache.data_ptr<uint8_t>();
+  const int* block_tables_ptr = block_tables.data_ptr<int>();
+  const int* context_lens_ptr = context_lens.data_ptr<int>();
+  const float scale_f = static_cast<float>(scale);
+  const int heads = static_cast<int>(num_heads);
+  const int kv_heads = static_cast<int>(num_kv_heads);
+  const int hs = static_cast<int>(head_size);
+  const int bs = static_cast<int>(block_size);
+  const int max_blocks = static_cast<int>(max_blocks_per_seq);
+  const uint32_t seed = static_cast<uint32_t>(layer_seed);
+  const int proj_dim = static_cast<int>(qjl_proj_dim);
+
+  if (is_pq4) {
+    paged_attention_turboquant_kernel<TQDataType::kPQ4>
+        <<<grid, block, 0, stream>>>(
+            out_ptr, query_ptr, key_cache_ptr, value_cache_ptr,
+            block_tables_ptr, context_lens_ptr, scale_f, heads, kv_heads, hs,
+            bs, max_blocks, seed, proj_dim);
+  } else if (tq_type == "tq3") {
+    paged_attention_turboquant_kernel<TQDataType::kTQ3>
+        <<<grid, block, 0, stream>>>(
+            out_ptr, query_ptr, key_cache_ptr, value_cache_ptr,
+            block_tables_ptr, context_lens_ptr, scale_f, heads, kv_heads, hs,
+            bs, max_blocks, seed, proj_dim);
+  } else {  // "tq2" (validated above)
+    paged_attention_turboquant_kernel<TQDataType::kTQ2>
+        <<<grid, block, 0, stream>>>(
+            out_ptr, query_ptr, key_cache_ptr, value_cache_ptr,
+            block_tables_ptr, context_lens_ptr, scale_f, heads, kv_heads, hs,
+            bs, max_blocks, seed, proj_dim);
+  }
+}
diff --git a/csrc/quantization/turboquant/turboquant_utils.cuh b/csrc/quantization/turboquant/turboquant_utils.cuh
new file mode 100644
index 000000000000..0e31d87f6b9f
--- /dev/null
+++ b/csrc/quantization/turboquant/turboquant_utils.cuh
@@ -0,0 +1,503 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#pragma once
+
+#include
+#include
+#include
+#include
+
+namespace vllm {
+namespace turboquant {
+
+// ============================================================================
+// TurboQuant KV cache data types
+// ============================================================================
+
+enum class TQDataType {
+  kPQ4 = 0,  // PolarQuant 4-bit (no QJL)
+  kTQ3 = 1,  // PolarQuant 3-bit + 1-bit QJL residual
+  kTQ2 = 2,  // PolarQuant 2-bit + 1-bit QJL residual
+};
+
+// Bits per angle element for each mode
+__host__ __device__ constexpr int angle_bits(TQDataType dt) {
+  return dt == TQDataType::kPQ4 ? 4 : (dt == TQDataType::kTQ3 ? 3 : 2);
+}
+
+// Whether QJL residual correction is enabled
+__host__ __device__ constexpr bool has_qjl(TQDataType dt) {
+  return dt != TQDataType::kPQ4;
+}
+
+// Maximum supported head dimension. All stack buffers use this size.
+// head_size must be a power of 2 and <= MAX_HEAD_SIZE.
+constexpr int MAX_HEAD_SIZE = 256;
+
+// ============================================================================
+// Fast Walsh-Hadamard Transform (WHT)
+// ============================================================================
+
+// In-place normalized WHT on a vector of length `n` (must be power of 2).
+// Each thread handles one element. 
NOTE: the loop nest below runs entirely in the
+// calling thread; it issues no warp shuffles and touches no shared memory.
+template <typename scalar_t>
+__device__ void wht_inplace(scalar_t* __restrict__ vec, int n) {
+  // Butterfly passes, O(n log n) in total. `stride` is the distance between
+  // the two partners of a butterfly and doubles after every pass.
+  for (int stride = 1; stride < n; stride *= 2) {
+    const int span = 2 * stride;
+    for (int base = 0; base < n; base += span) {
+      for (int k = base; k < base + stride; ++k) {
+        const scalar_t lo = vec[k];
+        const scalar_t hi = vec[k + stride];
+        vec[k] = lo + hi;
+        vec[k + stride] = lo - hi;
+      }
+    }
+  }
+
+  // Scale by 1/sqrt(n); the arithmetic is done in float regardless of
+  // scalar_t.
+  const float inv_sqrt_n = rsqrtf(static_cast<float>(n));
+  for (int k = 0; k < n; ++k) {
+    vec[k] = static_cast<scalar_t>(static_cast<float>(vec[k]) * inv_sqrt_n);
+  }
+}
+
+// ============================================================================
+// Seeded random sign generation for WHT rotation
+// ============================================================================
+
+// Deterministic pseudo-random sign for element `idx` under `seed`.
+// The seed is XOR-ed with the index, pushed through a Murmur-style integer
+// mixer, and the lowest bit of the result picks +1 (bit set) or -1.
+__device__ __forceinline__ float random_sign(uint32_t seed, int idx) {
+  uint32_t mixed = seed ^ static_cast<uint32_t>(idx);
+  mixed ^= mixed >> 16;
+  mixed *= 0x85ebca6bu;
+  mixed ^= mixed >> 13;
+  mixed *= 0xc2b2ae35u;
+  mixed ^= mixed >> 16;
+  const bool positive = (mixed & 1u) != 0u;
+  return positive ? 1.0f : -1.0f;
+}
+
+// ============================================================================
+// Randomized Hadamard rotation (diagonal * WHT)
+// ============================================================================
+
+// Multiply vec by a random diagonal sign matrix D, then apply the WHT H,
+// i.e. vec <- H * (D * vec).
+// This concentrates the angle distribution for better uniform quantization. 
+template <typename scalar_t>
+__device__ void randomized_hadamard(scalar_t* __restrict__ vec, int n,
+                                    uint32_t seed) {
+  // D: flip each element by its seeded sign (multiplication done in float).
+  for (int k = 0; k < n; ++k) {
+    const float flipped = static_cast<float>(vec[k]) * random_sign(seed, k);
+    vec[k] = static_cast<scalar_t>(flipped);
+  }
+  // H: normalized Walsh-Hadamard transform.
+  wht_inplace(vec, n);
+}
+
+// Inverse of randomized_hadamard. The forward map is H * D, so the inverse is
+// D^-1 * H^-1 = D * H: transform first, then apply the same signs (the
+// normalized H and the sign matrix D are each their own inverse).
+template <typename scalar_t>
+__device__ void inverse_randomized_hadamard(scalar_t* __restrict__ vec, int n,
+                                            uint32_t seed) {
+  // H first ...
+  wht_inplace(vec, n);
+  // ... then the same sign pattern that the forward rotation used.
+  for (int k = 0; k < n; ++k) {
+    const float flipped = static_cast<float>(vec[k]) * random_sign(seed, k);
+    vec[k] = static_cast<scalar_t>(flipped);
+  }
+}
+
+// ============================================================================
+// Polar coordinate transforms
+// ============================================================================
+
+// Convert the pair (x1, x2) to polar form.
+//   r     <- Euclidean length of (x1, x2)
+//   theta <- atan2(x2, x1) mapped linearly from [-pi, pi] onto [0, 1] and
+//            clamped to at most 0.999999, ready for uniform quantization
+__device__ __forceinline__ void cartesian_to_polar(float x1, float x2,
+                                                   float& r, float& theta) {
+  const float angle = atan2f(x2, x1);  // in [-pi, pi]
+  const float unit = (angle + CUDART_PI_F) / (2.0f * CUDART_PI_F);
+  r = hypotf(x1, x2);
+  // Keep theta strictly below 1 so that angle == +pi does not map to 1.0.
+  theta = fminf(fmaxf(unit, 0.0f), 0.999999f);
+}
+
+// Convert polar coordinates back to Cartesian. 
+// `theta` is the normalized angle in [0, 1); it is mapped back to
+// [-pi, pi) before taking sine and cosine. `__sincosf` is the fast
+// approximate intrinsic.
+__device__ __forceinline__ void polar_to_cartesian(float r, float theta,
+                                                   float& x1, float& x2) {
+  const float angle = theta * (2.0f * CUDART_PI_F) - CUDART_PI_F;
+  float s, c;
+  __sincosf(angle, &s, &c);
+  x1 = r * c;
+  x2 = r * s;
+}
+
+// ============================================================================
+// Uniform quantization of angles
+// ============================================================================
+
+// Quantize a [0, 1) value to a `bits`-bit bin index.
+//
+// Bin q covers [q / levels, (q + 1) / levels), so the index is the value
+// scaled by `levels` and rounded DOWN. dequantize_angle() reconstructs at the
+// bin midpoint (q + 0.5) / levels, which keeps the error within half a bin.
+// Rounding to nearest here would shift every bin edge by half a step and let
+// the reconstruction error grow to a full bin.
+__device__ __forceinline__ uint8_t quantize_angle(float theta, int bits) {
+  const int levels = 1 << bits;
+  int q = __float2int_rd(theta * levels);
+  // Guard against inputs at or outside the [0, 1) boundaries.
+  q = min(max(q, 0), levels - 1);
+  return static_cast<uint8_t>(q);
+}
+
+// Dequantize a `bits`-bit bin index back to the midpoint of its bin in [0, 1).
+__device__ __forceinline__ float dequantize_angle(uint8_t q, int bits) {
+  const int levels = 1 << bits;
+  return (static_cast<float>(q) + 0.5f) / static_cast<float>(levels);
+}
+
+// ============================================================================
+// Recursive polar folding
+// ============================================================================
+
+// PolarQuant encodes a d-dimensional vector by:
+//   1. Pairing consecutive elements → d/2 (r, theta) pairs
+//   2. Recursively pairing radii → d/4 (R, phi) pairs, etc.
+//   3. Until a single radius remains
+//
+// Total angles stored: d/2 + d/4 + ... + 1 = d - 1 angles + 1 radius
+// For head_size=128: 127 angles + 1 radius per head per token
+
+// Encode: convert d floats → (d-1) quantized angles + 1 fp16 radius.
+// `angles_out` must have space for (d-1) entries.
+// Returns the final radius. 
+//
+// Angle ranges differ by level:
+//   - leaf level: the inputs are signed, so the angle spans the full circle
+//     and goes through cartesian_to_polar().
+//   - upper levels: the inputs are radii (>= 0), so atan2 always lands in
+//     [0, pi/2]. That quarter circle is stretched over the whole [0, 1) code
+//     range; pushing it through the full-circle mapping would use only a
+//     quarter of the available codes. polar_decode() applies the matching
+//     inverse.
+//
+// NOTE(review): every angle code is written to its own byte of `angles_out`
+// (d - 1 bytes in total), while the kernels in polarquant_kernels.cu size the
+// per-head angle region as ceil((d - 1) * BITS / 8) bytes. Confirm where the
+// bit-packing is meant to happen; as written the regions of adjacent heads
+// would overlap.
+template <int BITS>
+__device__ float polar_encode(const float* __restrict__ vec, int d,
+                              uint8_t* __restrict__ angles_out) {
+  // Nothing to pair: a single element is its own "radius".
+  if (d < 2) {
+    return d == 1 ? vec[0] : 0.0f;
+  }
+
+  // Radii of the level currently being folded (at most d/2 entries in use)
+  float radii[MAX_HEAD_SIZE];
+
+  int angle_idx = 0;
+  int n = d;
+
+  // First level: pair original (signed) elements
+  for (int i = 0; i < n / 2; i++) {
+    float r, theta;
+    cartesian_to_polar(vec[2 * i], vec[2 * i + 1], r, theta);
+    angles_out[angle_idx++] = quantize_angle(theta, BITS);
+    radii[i] = r;
+  }
+  n /= 2;
+
+  // Recursive folding: pair radii. Both inputs are read before radii[i] is
+  // overwritten, so the fold can run in place.
+  while (n > 1) {
+    for (int i = 0; i < n / 2; i++) {
+      const float a = radii[2 * i];
+      const float b = radii[2 * i + 1];
+      // a, b >= 0  →  atan2f(b, a) in [0, pi/2]; normalize by pi/2.
+      float theta = atan2f(b, a) / (0.5f * CUDART_PI_F);
+      theta = fminf(fmaxf(theta, 0.0f), 0.999999f);
+      angles_out[angle_idx++] = quantize_angle(theta, BITS);
+      radii[i] = hypotf(a, b);
+    }
+    n /= 2;
+  }
+
+  return radii[0];  // Single remaining radius
+}
+
+// Decode: reconstruct d floats from (d-1) quantized angles + 1 radius.
+//
+// The angles are stored leaf level first (d/2 codes), then d/4, and so on up
+// to the single root angle. Reconstruction runs the other way: start from the
+// root radius and split every value of a level into its two children.
+// The expansion happens in place inside `vec_out` (walking each level from
+// its last pair to its first, so a parent is always read before its slot is
+// overwritten), which avoids two MAX_HEAD_SIZE scratch arrays.
+template <int BITS>
+__device__ void polar_decode(const uint8_t* __restrict__ angles,
+                             float radius, int d,
+                             float* __restrict__ vec_out) {
+  if (d <= 0) return;
+
+  // Number of folding levels and total number of stored angles.
+  // Level l holds (d >> (l + 1)) angles.
+  int num_levels = 0;
+  int level_end = 0;  // one past the last angle of the level being processed
+  for (int n = d; n > 1; n /= 2) {
+    num_levels++;
+    level_end += n / 2;
+  }
+
+  vec_out[0] = radius;
+
+  // Reconstruct from the top level down
+  for (int lvl = num_levels - 1; lvl >= 0; lvl--) {
+    const int n_pairs = d >> (lvl + 1);
+    const int level_begin = level_end - n_pairs;
+
+    for (int i = n_pairs - 1; i >= 0; i--) {
+      const float parent = vec_out[i];
+      const float theta = dequantize_angle(angles[level_begin + i], BITS);
+      float x1, x2;
+      if (lvl == 0) {
+        // Leaf level: full-circle angle, signed outputs
+        polar_to_cartesian(parent, theta, x1, x2);
+      } else {
+        // Upper level: quarter-circle angle, both children are radii >= 0
+        float s, c;
+        __sincosf(theta * (0.5f * CUDART_PI_F), &s, &c);
+        x1 = parent * c;
+        x2 = parent * s;
+      }
+      vec_out[2 * i] = x1;
+      vec_out[2 * i + 1] = x2;
+    }
+
+    level_end = level_begin;
+  }
+}
+
+// ============================================================================
+// QJL sign-bit projection
+// ============================================================================
+
+// QJL (Quantized Johnson-Lindenstrauss) applies a random sign-bit projection
+// to the quantization residual for bias correction.
+//
+// Encode: sign_bits[j] = sign(dot(residual, random_vector_j))
+// Decode: correction = (||r|| * sqrt(pi/2) / m) * sum_j(sign_bits[j] *
+//                      random_vector_j)
+//
+// The random vectors are generated deterministically from a seed, so they
+// don't need to be stored.
+
+// Compute the sign bit for one projection dimension.
+// The +-1 projection vector is regenerated on the fly from `seed` and
+// `proj_idx`; returns true when the dot product with `residual` is >= 0.
+__device__ __forceinline__ bool qjl_sign_bit(const float* __restrict__ residual,
+                                             int d, uint32_t seed,
+                                             int proj_idx) {
+  // Per-projection seed: the base seed XOR a multiple of the golden-ratio
+  // constant, so each projection gets its own sign sequence.
+  const uint32_t proj_seed =
+      seed ^ (static_cast<uint32_t>(proj_idx) * 0x9e3779b9u);
+  float dot = 0.0f;
+  for (int i = 0; i < d; i++) {
+    dot += residual[i] * random_sign(proj_seed, i);
+  }
+  return dot >= 0.0f;
+}
+
+// Pack sign bits into bytes. 
// `num_proj` sign bits → ceil(num_proj/8) bytes.
// Bit `b` of output byte `k` holds the sign of projection 8*k + b; unused
// high bits of the last byte are written as zero.
__device__ void qjl_encode(const float* __restrict__ residual, int d,
                           uint32_t seed, int num_proj,
                           uint8_t* __restrict__ sign_bits_out) {
  const int num_bytes = (num_proj + 7) / 8;
  for (int byte_idx = 0; byte_idx < num_bytes; ++byte_idx) {
    const int first_proj = byte_idx * 8;
    uint8_t byte_val = 0;
    for (int bit = 0; bit < 8 && first_proj + bit < num_proj; ++bit) {
      if (qjl_sign_bit(residual, d, seed, first_proj + bit)) {
        byte_val |= static_cast<uint8_t>(1u << bit);
      }
    }
    sign_bits_out[byte_idx] = byte_val;
  }
}

// Reconstruct the bias correction vector from sign bits.
//
// Per the QJL estimator (Zandieh et al., AAAI 2025), the unbiased
// reconstruction is:
//   correction[i] = (||r|| * sqrt(pi/2) / m) * sum_j(sign_j * g_j[i])
//
// where ||r|| is the L2 norm of the quantization residual (stored as fp16
// alongside the sign bits), m is the number of projections, and g_j is the
// j-th random sign vector.
//
// `correction_out` must hold d floats; it is fully overwritten. With
// num_proj <= 0 the correction is all zeros.
__device__ void qjl_decode_correction(const uint8_t* __restrict__ sign_bits,
                                      float residual_norm,
                                      int d, uint32_t seed, int num_proj,
                                      float* __restrict__ correction_out) {
  // Initialize correction to zero
  for (int i = 0; i < d; ++i) {
    correction_out[i] = 0.0f;
  }

  // No projections means no correction; bail out before dividing by zero.
  if (num_proj <= 0) {
    return;
  }

  // sqrt(pi/2) ≈ 1.2533141
  constexpr float SQRT_PI_OVER_2 = 1.2533141f;
  const float scale =
      residual_norm * SQRT_PI_OVER_2 / static_cast<float>(num_proj);

  for (int proj_idx = 0; proj_idx < num_proj; ++proj_idx) {
    // Read the sign bit (same packing as qjl_encode) and fold it into the
    // scale once per projection instead of once per element.
    const bool positive = ((sign_bits[proj_idx / 8] >> (proj_idx % 8)) & 1u) != 0;
    const float signed_scale = positive ? scale : -scale;

    // Regenerate the random sign vector and accumulate it, scaled.
    const uint32_t proj_seed =
        seed ^ (static_cast<uint32_t>(proj_idx) * 0x9e3779b9u);
    for (int i = 0; i < d; ++i) {
      correction_out[i] += random_sign(proj_seed, i) * signed_scale;
    }
  }
}

// ============================================================================
// Combined TurboQuant encode/decode (single head, single token)
// ============================================================================

// Full TurboQuant encode for one KV head vector.
//   Input:  vec[head_size] as fp32 (this function reads `const float*`)
//   Output: angle entries, fp16 radius, and for QJL modes the packed sign bits
//           plus the fp16 residual norm
//
// IMPORTANT: head_size must be a power of 2 and <= MAX_HEAD_SIZE (256).
// This is required by the Walsh-Hadamard Transform and the fixed-size stack
// buffers used throughout the encode/decode pipeline.
//
// `qjl_out` and `residual_norm_out` are only written when has_qjl(DT).

template <TQDataType DT>
__device__ void turboquant_encode_head(
    const float* __restrict__ vec, int head_size,
    uint32_t rotation_seed, uint32_t qjl_seed,
    uint8_t* __restrict__ angles_out,      // (head_size - 1) angle entries
    half* __restrict__ radius_out,         // 1 fp16 radius
    uint8_t* __restrict__ qjl_out,         // QJL sign bits (if enabled)
    half* __restrict__ residual_norm_out,  // QJL residual L2 norm (if enabled)
    int qjl_proj_dim) {
  constexpr int BITS = angle_bits(DT);

  // Validate head_size at runtime (positive power of 2 within buffer limits).
  // The explicit > 0 check is needed because 0 passes the bit trick below.
  assert(head_size > 0 && "head_size must be positive");
  assert(head_size <= MAX_HEAD_SIZE && "head_size exceeds MAX_HEAD_SIZE");
  assert((head_size & (head_size - 1)) == 0 && "head_size must be power of 2");

  // Working buffer
  float buf[MAX_HEAD_SIZE];
  for (int i = 0; i < head_size; ++i) {
    buf[i] = vec[i];
  }

  // Step 1: Apply randomized Hadamard rotation
  randomized_hadamard(buf, head_size, rotation_seed);

  // Step 2: Polar encode
  const float radius = polar_encode<BITS>(buf, head_size, angles_out);
  const half radius_fp16 = __float2half(radius);
  *radius_out = radius_fp16;

  // Step 3: QJL residual correction (for tq2/tq3 modes)
  if constexpr (has_qjl(DT)) {
    // Reconstruct exactly what the decoder will see. turboquant_decode_head()
    // starts from the fp16-rounded radius, so the residual has to be measured
    // against that value rather than the unrounded fp32 radius.
    float residual[MAX_HEAD_SIZE];
    polar_decode<BITS>(angles_out, __half2float(radius_fp16), head_size,
                       residual);

    // residual = original_rotated - dequantized, computed in place, together
    // with its squared L2 norm (needed for unbiased QJL decode).
    float norm_sq = 0.0f;
    for (int i = 0; i < head_size; ++i) {
      const float r = buf[i] - residual[i];
      residual[i] = r;
      norm_sq += r * r;
    }
    *residual_norm_out = __float2half(sqrtf(norm_sq));

    // Encode residual as sign bits
    qjl_encode(residual, head_size, qjl_seed, qjl_proj_dim, qjl_out);
  }
}

// Full TurboQuant decode for one KV head vector.
// Writes head_size floats to `vec_out`. `qjl_bits` and `residual_norm_fp16`
// are only read when has_qjl(DT).
template <TQDataType DT>
__device__ void turboquant_decode_head(
    const uint8_t* __restrict__ angles, half radius_fp16,
    const uint8_t* __restrict__ qjl_bits, half residual_norm_fp16,
    int head_size, uint32_t rotation_seed, uint32_t qjl_seed,
    int qjl_proj_dim,
    float* __restrict__ vec_out) {
  constexpr int BITS = angle_bits(DT);

  // Step 1: Polar decode
  polar_decode<BITS>(angles, __half2float(radius_fp16), head_size, vec_out);

  // Step 2: Add QJL correction (for tq2/tq3 modes)
  if constexpr (has_qjl(DT)) {
    float correction[MAX_HEAD_SIZE];
    qjl_decode_correction(qjl_bits, __half2float(residual_norm_fp16),
                          head_size, qjl_seed, qjl_proj_dim, correction);
    for (int i = 0; i < head_size; ++i) {
      vec_out[i] += correction[i];
    }
  }

  // Step 3: Inverse randomized Hadamard rotation
  inverse_randomized_hadamard(vec_out, head_size, rotation_seed);
}

// ============================================================================
// Block layout utilities
// ============================================================================

// Per-token-per-head layout (contiguous in memory):
//
//   [angles: padded_angle_bytes] [radius: 2B fp16] [qjl_bits: qjl_bytes]
//   [residual_norm: 2B fp16 (QJL modes only)]
//
// The angle bytes are padded to the next
// even number to ensure the radius
// fp16 is always 2-byte aligned.
//
//   PQ4: angles(padded) + radius(2)
//   TQ3: angles(padded) + radius(2) + qjl_bits + residual_norm(2)
//   TQ2: angles(padded) + radius(2) + qjl_bits + residual_norm(2)

// Compute the padded angle byte count (rounded up to even for alignment).
// `bits` is the per-angle width; there are (head_size - 1) angles per head.
__host__ __device__ inline int padded_angle_bytes(int head_size, int bits) {
  const int raw = ((head_size - 1) * bits + 7) / 8;  // ceil(total_bits / 8)
  return (raw + 1) & ~1;  // Round up to next even number
}

// Calculate total bytes per token per head for a TurboQuant mode.
//
// `qjl_proj_dim` is the number of QJL sign-bit projections; values <= 0 (the
// default) mean "same as head_size". qjl_encode() writes
// ceil(qjl_proj_dim / 8) bytes, so the size reported here has to follow the
// projection count rather than head_size, otherwise a non-default projection
// dimension would disagree with what the encoder writes.
//
// NOTE(review): an odd qjl byte count leaves the trailing fp16 residual norm
// at an odd offset — confirm callers only use projection dims that are
// multiples of 16, or pad here and in the Python-side size computation
// together.
__host__ __device__ inline int turboquant_bytes_per_token_per_head(
    TQDataType dt, int head_size, int qjl_proj_dim = 0) {
  const int angle_bytes = padded_angle_bytes(head_size, angle_bits(dt));
  const int radius_bytes = 2;  // fp16
  if (!has_qjl(dt)) {
    return angle_bytes + radius_bytes;
  }
  const int proj_dim = qjl_proj_dim > 0 ? qjl_proj_dim : head_size;
  const int qjl_bytes = (proj_dim + 7) / 8;  // 1 bit per projection
  const int residual_norm_bytes = 2;         // fp16
  return angle_bytes + radius_bytes + qjl_bytes + residual_norm_bytes;
}

// Calculate total bytes per block for a TurboQuant mode.
+__host__ __device__ int turboquant_block_bytes(TQDataType dt, int num_kv_heads, + int head_size, int block_size) { + return num_kv_heads * block_size * + turboquant_bytes_per_token_per_head(dt, head_size); +} + +// Per-head seed derivation from layer-level seed +__device__ __forceinline__ uint32_t derive_rotation_seed(uint32_t layer_seed, + int head_idx) { + return layer_seed ^ (static_cast(head_idx) * 2654435761u); +} + +__device__ __forceinline__ uint32_t derive_qjl_seed(uint32_t layer_seed, + int head_idx) { + return layer_seed ^ (static_cast(head_idx) * 2246822519u); +} + +} // namespace turboquant +} // namespace vllm diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 4f42477b2f6c..2e86b0891952 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -542,6 +542,50 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { cache_ops.impl("concat_and_cache_mla_rope_fused", torch::kCUDA, &concat_and_cache_mla_rope_fused); + // TurboQuant: reshape and cache with PolarQuant + QJL quantization. + cache_ops.def( + "reshape_and_cache_turboquant(Tensor key, Tensor value," + " Tensor! key_cache, Tensor! value_cache," + " Tensor slot_mapping," + " int num_kv_heads, int head_size," + " int block_size, str tq_type," + " int layer_seed, int qjl_proj_dim) -> ()"); + cache_ops.impl("reshape_and_cache_turboquant", torch::kCUDA, + &reshape_and_cache_turboquant); + + // TurboQuant: standalone encode (for testing). + cache_ops.def( + "turboquant_encode(Tensor kv_data, Tensor! angles_out," + " Tensor! radii_out, Tensor! qjl_out," + " Tensor! residual_norms_out," + " int num_kv_heads, int head_size," + " str tq_type, int layer_seed," + " int qjl_proj_dim) -> ()"); + cache_ops.impl("turboquant_encode", torch::kCUDA, &turboquant_encode); + + // TurboQuant: standalone decode (for testing). + cache_ops.def( + "turboquant_decode(Tensor angles, Tensor radii," + " Tensor qjl_bits, Tensor residual_norms," + " Tensor! 
kv_out," + " int num_kv_heads, int head_size," + " str tq_type, int layer_seed," + " int qjl_proj_dim) -> ()"); + cache_ops.impl("turboquant_decode", torch::kCUDA, &turboquant_decode); + + // TurboQuant: fused paged attention with on-the-fly dequantization. + cache_ops.def( + "paged_attention_turboquant(Tensor! output, Tensor query," + " Tensor key_cache, Tensor value_cache," + " Tensor block_tables, Tensor context_lens," + " float scale, int num_heads," + " int num_kv_heads, int head_size," + " int block_size, int max_blocks_per_seq," + " str tq_type, int layer_seed," + " int qjl_proj_dim) -> ()"); + cache_ops.impl("paged_attention_turboquant", torch::kCUDA, + &paged_attention_turboquant); + // Convert the key and value cache to fp8 data type. cache_ops.def( "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, " diff --git a/tests/kernels/quantization/test_turboquant.py b/tests/kernels/quantization/test_turboquant.py new file mode 100644 index 000000000000..d4ba2f31e6db --- /dev/null +++ b/tests/kernels/quantization/test_turboquant.py @@ -0,0 +1,319 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for TurboQuant KV cache quantization (PolarQuant + QJL). + +Tests cover: + 1. Numerical round-trip accuracy (encode -> decode) + 2. QJL unbiasedness (mean error converges to zero) + 3. Deterministic reconstruction (same seed = identical output) + 4. 
TurboQuantConfig utilities +""" + +import pytest +import torch + +from vllm.model_executor.layers.quantization.turboquant import ( + TurboQuantConfig, + is_turboquant_kv_cache, +) + + +# ============================================================================ +# Unit tests for TurboQuantConfig (no GPU required) +# ============================================================================ + + +class TestTurboQuantConfig: + def test_from_kv_cache_dtype_pq4(self): + cfg = TurboQuantConfig.from_kv_cache_dtype("pq4") + assert cfg.angle_bits == 4 + assert cfg.qjl_residual is False + assert cfg.qjl_projection_dim is None + + def test_from_kv_cache_dtype_tq3(self): + cfg = TurboQuantConfig.from_kv_cache_dtype("tq3") + assert cfg.angle_bits == 3 + assert cfg.qjl_residual is True + + def test_from_kv_cache_dtype_tq2(self): + cfg = TurboQuantConfig.from_kv_cache_dtype("tq2") + assert cfg.angle_bits == 2 + assert cfg.qjl_residual is True + + def test_invalid_dtype_raises(self): + with pytest.raises(ValueError, match="Unknown TurboQuant dtype"): + TurboQuantConfig.from_kv_cache_dtype("fp8") + + def test_effective_bits_pq4(self): + cfg = TurboQuantConfig.from_kv_cache_dtype("pq4") + # For head_size=128: (127*4 + 16) / 128 = 4.09375 + bits = cfg.effective_bits_per_element(128) + assert 4.0 < bits < 4.2 + + def test_effective_bits_tq3(self): + cfg = TurboQuantConfig.from_kv_cache_dtype("tq3") + # For head_size=128: (127*3 + 16 + 128) / 128 = 4.1015625 + bits = cfg.effective_bits_per_element(128) + assert 4.0 < bits < 4.2 + + def test_effective_bits_tq2(self): + cfg = TurboQuantConfig.from_kv_cache_dtype("tq2") + # For head_size=128: (127*2 + 16 + 128) / 128 = 3.109375 + bits = cfg.effective_bits_per_element(128) + assert 3.0 < bits < 3.2 + + def test_bytes_per_token_per_head(self): + cfg = TurboQuantConfig.from_kv_cache_dtype("tq3") + # For head_size=128: angle_bytes=48 (127*3=381 bits -> 48 bytes, + # padded to 48 which is already even), radius=2, qjl=16 (128/8), + # 
residual_norm=2 -> total=68 + bpt = cfg.bytes_per_token_per_head(128) + assert bpt == 48 + 2 + 16 + 2 # 68 bytes + + def test_block_bytes(self): + cfg = TurboQuantConfig.from_kv_cache_dtype("tq3") + # 8 kv heads, head_size=128, block_size=16 + bpt = cfg.bytes_per_token_per_head(128) + bb = cfg.block_bytes(num_kv_heads=8, head_size=128, block_size=16) + assert bb == 8 * 16 * bpt + + def test_derive_layer_seed_deterministic(self): + cfg = TurboQuantConfig.from_kv_cache_dtype("tq3", rotation_seed=42) + s1 = cfg.derive_layer_seed(0) + s2 = cfg.derive_layer_seed(0) + assert s1 == s2 + # Different layers get different seeds + s3 = cfg.derive_layer_seed(1) + assert s1 != s3 + + def test_is_turboquant_kv_cache(self): + assert is_turboquant_kv_cache("tq2") is True + assert is_turboquant_kv_cache("tq3") is True + assert is_turboquant_kv_cache("pq4") is True + assert is_turboquant_kv_cache("fp8") is False + assert is_turboquant_kv_cache("auto") is False + assert is_turboquant_kv_cache("float16") is False + + +# ============================================================================ +# CUDA kernel tests (require GPU) +# ============================================================================ + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestTurboQuantKernels: + """Tests for the CUDA encode/decode kernels. + + These tests use the standalone turboquant_encode/decode ops for + round-trip testing. They require the vLLM CUDA extensions to be compiled. + """ + + def _try_import_ops(self): + """Try to import vLLM custom ops. 
Skip if not compiled.""" + try: + from vllm import _custom_ops as ops + + # Check if the turboquant ops are available + if not hasattr(torch.ops, "_C_cache_ops"): + pytest.skip("vLLM CUDA extensions not compiled") + if not hasattr(torch.ops._C_cache_ops, "turboquant_encode"): + pytest.skip("TurboQuant CUDA kernels not compiled") + return ops + except (ImportError, AttributeError): + pytest.skip("vLLM CUDA extensions not available") + + @pytest.mark.parametrize("tq_type", ["pq4", "tq3", "tq2"]) + @pytest.mark.parametrize("head_size", [64, 128]) + def test_roundtrip_accuracy(self, tq_type, head_size): + """Encode then decode should produce values close to the original.""" + ops = self._try_import_ops() + + num_tokens = 4 + num_kv_heads = 8 + layer_seed = 42 + qjl_proj_dim = head_size + + # Random KV data + kv_data = torch.randn( + num_tokens, num_kv_heads, head_size, + device="cuda", dtype=torch.float32, + ) + + # Allocate output buffers + num_angles = head_size - 1 + angle_bits_map = {"pq4": 4, "tq3": 3, "tq2": 2} + bits = angle_bits_map[tq_type] + angle_bytes = (num_angles * bits + 7) // 8 + has_qjl = tq_type.startswith("tq") + qjl_bytes = (qjl_proj_dim + 7) // 8 if has_qjl else 0 + + angles = torch.zeros( + num_tokens * num_kv_heads, angle_bytes, + device="cuda", dtype=torch.uint8, + ) + radii = torch.zeros( + num_tokens * num_kv_heads, + device="cuda", dtype=torch.float16, + ) + qjl_out = torch.zeros( + num_tokens * num_kv_heads, max(qjl_bytes, 1), + device="cuda", dtype=torch.uint8, + ) + residual_norms = torch.zeros( + num_tokens * num_kv_heads, + device="cuda", dtype=torch.float16, + ) + kv_out = torch.zeros_like(kv_data) + + # Encode + ops.turboquant_encode( + kv_data, angles, radii, qjl_out, residual_norms, + num_kv_heads, head_size, tq_type, layer_seed, qjl_proj_dim, + ) + + # Decode + ops.turboquant_decode( + angles, radii, qjl_out, residual_norms, kv_out, + num_kv_heads, head_size, tq_type, layer_seed, qjl_proj_dim, + ) + + # Check reconstruction error + 
max_abs_error = (kv_data - kv_out).abs().max().item() + mean_abs_error = (kv_data - kv_out).abs().mean().item() + + # Error bounds depend on quantization level + if tq_type == "pq4": + assert max_abs_error < 0.5, ( + f"pq4 max abs error {max_abs_error} too large" + ) + elif tq_type == "tq3": + assert max_abs_error < 1.0, ( + f"tq3 max abs error {max_abs_error} too large" + ) + elif tq_type == "tq2": + assert max_abs_error < 2.0, ( + f"tq2 max abs error {max_abs_error} too large" + ) + + # Mean error should be reasonable + assert mean_abs_error < 0.5, ( + f"{tq_type} mean abs error {mean_abs_error} too large" + ) + + @pytest.mark.parametrize("tq_type", ["tq3", "tq2"]) + def test_qjl_unbiasedness(self, tq_type): + """QJL residual should produce an unbiased correction on average.""" + ops = self._try_import_ops() + + num_tokens = 100 + num_kv_heads = 4 + head_size = 64 + layer_seed = 42 + qjl_proj_dim = head_size + + # Generate random data + kv_data = torch.randn( + num_tokens, num_kv_heads, head_size, + device="cuda", dtype=torch.float32, + ) + + # Encode + decode + angle_bits_map = {"tq3": 3, "tq2": 2} + bits = angle_bits_map[tq_type] + num_angles = head_size - 1 + angle_bytes = (num_angles * bits + 7) // 8 + qjl_bytes = (qjl_proj_dim + 7) // 8 + + angles = torch.zeros( + num_tokens * num_kv_heads, angle_bytes, + device="cuda", dtype=torch.uint8, + ) + radii = torch.zeros( + num_tokens * num_kv_heads, + device="cuda", dtype=torch.float16, + ) + qjl_out = torch.zeros( + num_tokens * num_kv_heads, qjl_bytes, + device="cuda", dtype=torch.uint8, + ) + residual_norms = torch.zeros( + num_tokens * num_kv_heads, + device="cuda", dtype=torch.float16, + ) + kv_out = torch.zeros_like(kv_data) + + ops.turboquant_encode( + kv_data, angles, radii, qjl_out, residual_norms, + num_kv_heads, head_size, tq_type, layer_seed, qjl_proj_dim, + ) + ops.turboquant_decode( + angles, radii, qjl_out, residual_norms, kv_out, + num_kv_heads, head_size, tq_type, layer_seed, qjl_proj_dim, + ) + + 
# Mean error across all tokens should converge toward zero + mean_error = (kv_data - kv_out).mean().item() + assert abs(mean_error) < 0.1, ( + f"QJL mean error {mean_error} suggests bias (should be ~0)" + ) + + @pytest.mark.parametrize("tq_type", ["pq4", "tq3", "tq2"]) + def test_deterministic_reconstruction(self, tq_type): + """Same input + same seed should produce identical output.""" + ops = self._try_import_ops() + + num_tokens = 4 + num_kv_heads = 4 + head_size = 64 + layer_seed = 123 + qjl_proj_dim = head_size + + kv_data = torch.randn( + num_tokens, num_kv_heads, head_size, + device="cuda", dtype=torch.float32, + ) + + angle_bits_map = {"pq4": 4, "tq3": 3, "tq2": 2} + bits = angle_bits_map[tq_type] + num_angles = head_size - 1 + angle_bytes = (num_angles * bits + 7) // 8 + has_qjl = tq_type.startswith("tq") + qjl_bytes = (qjl_proj_dim + 7) // 8 if has_qjl else 0 + + def encode_decode(): + angles = torch.zeros( + num_tokens * num_kv_heads, angle_bytes, + device="cuda", dtype=torch.uint8, + ) + radii = torch.zeros( + num_tokens * num_kv_heads, + device="cuda", dtype=torch.float16, + ) + qjl_out = torch.zeros( + num_tokens * num_kv_heads, max(qjl_bytes, 1), + device="cuda", dtype=torch.uint8, + ) + residual_norms = torch.zeros( + num_tokens * num_kv_heads, + device="cuda", dtype=torch.float16, + ) + kv_out = torch.zeros_like(kv_data) + + ops.turboquant_encode( + kv_data, angles, radii, qjl_out, residual_norms, + num_kv_heads, head_size, tq_type, layer_seed, qjl_proj_dim, + ) + ops.turboquant_decode( + angles, radii, qjl_out, residual_norms, kv_out, + num_kv_heads, head_size, tq_type, layer_seed, qjl_proj_dim, + ) + return kv_out + + out1 = encode_decode() + out2 = encode_decode() + + assert torch.equal(out1, out2), ( + f"Deterministic reconstruction failed for {tq_type}: " + f"max diff = {(out1 - out2).abs().max().item()}" + ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index c55f5b92327f..a541384a2232 100644 --- a/vllm/_custom_ops.py +++ 
b/vllm/_custom_ops.py @@ -2534,6 +2534,122 @@ def reshape_and_cache_flash( ) +def reshape_and_cache_turboquant( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + num_kv_heads: int, + head_size: int, + block_size: int, + tq_type: str, + layer_seed: int, + qjl_proj_dim: int, +) -> None: + torch.ops._C_cache_ops.reshape_and_cache_turboquant( + key, + value, + key_cache, + value_cache, + slot_mapping, + num_kv_heads, + head_size, + block_size, + tq_type, + layer_seed, + qjl_proj_dim, + ) + + +def turboquant_encode( + kv_data: torch.Tensor, + angles_out: torch.Tensor, + radii_out: torch.Tensor, + qjl_out: torch.Tensor, + residual_norms_out: torch.Tensor, + num_kv_heads: int, + head_size: int, + tq_type: str, + layer_seed: int, + qjl_proj_dim: int, +) -> None: + torch.ops._C_cache_ops.turboquant_encode( + kv_data, + angles_out, + radii_out, + qjl_out, + residual_norms_out, + num_kv_heads, + head_size, + tq_type, + layer_seed, + qjl_proj_dim, + ) + + +def turboquant_decode( + angles: torch.Tensor, + radii: torch.Tensor, + qjl_bits: torch.Tensor, + residual_norms: torch.Tensor, + kv_out: torch.Tensor, + num_kv_heads: int, + head_size: int, + tq_type: str, + layer_seed: int, + qjl_proj_dim: int, +) -> None: + torch.ops._C_cache_ops.turboquant_decode( + angles, + radii, + qjl_bits, + residual_norms, + kv_out, + num_kv_heads, + head_size, + tq_type, + layer_seed, + qjl_proj_dim, + ) + + +def paged_attention_turboquant( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + scale: float, + num_heads: int, + num_kv_heads: int, + head_size: int, + block_size: int, + max_blocks_per_seq: int, + tq_type: str, + layer_seed: int, + qjl_proj_dim: int, +) -> None: + torch.ops._C_cache_ops.paged_attention_turboquant( + output, + query, + key_cache, + value_cache, + block_tables, + context_lens, + 
scale, + num_heads, + num_kv_heads, + head_size, + block_size, + max_blocks_per_seq, + tq_type, + layer_seed, + qjl_proj_dim, + ) + + def concat_and_cache_mla( kv_c: torch.Tensor, k_pe: torch.Tensor, diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 49c8868e709f..c39494263c8f 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -20,6 +20,9 @@ "fp8_e5m2", "fp8_inc", "fp8_ds_mla", + "tq2", + "tq3", + "pq4", ] MambaDType = Literal["auto", "float32", "float16"] MambaCacheMode = Literal["all", "align", "none"] @@ -243,6 +246,18 @@ def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType: "Meanwhile, it may cause accuracy drop without a proper " "scaling factor." ) + elif cache_dtype in ("tq2", "tq3", "pq4"): + bits_info = {"tq2": "2+1-bit", "tq3": "3+1-bit", "pq4": "4-bit"} + logger.info( + "Using TurboQuant (%s, %s) KV cache quantization. " + "This uses PolarQuant polar coordinate decomposition%s " + "for overhead-free KV compression. No calibration data " + "or fine-tuning required.", + cache_dtype, + bits_info[cache_dtype], + " with QJL residual correction" + if cache_dtype.startswith("tq") else "", + ) return cache_dtype def __post_init__(self): diff --git a/vllm/model_executor/layers/quantization/turboquant.py b/vllm/model_executor/layers/quantization/turboquant.py new file mode 100644 index 000000000000..e6ba7a50456e --- /dev/null +++ b/vllm/model_executor/layers/quantization/turboquant.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +TurboQuant KV cache quantization. + +Implements PolarQuant polar coordinate decomposition and QJL (Quantized +Johnson-Lindenstrauss) sign-bit residual correction for overhead-free 2-bit +and 3-bit KV cache quantization. 
+ +Reference: Zandieh et al., "TurboQuant: Redefining AI Efficiency with +Extreme Compression", ICLR 2026 (arXiv:2504.19874) + +Supported KV cache dtypes: + - pq4: PolarQuant 4-bit angles, no QJL (fastest encode, ~4x compression) + - tq3: PolarQuant 3-bit + 1-bit QJL residual (best quality/compression, ~4x) + - tq2: PolarQuant 2-bit + 1-bit QJL residual (maximum compression, ~5.3x) + +Usage: + vllm serve --kv-cache-dtype tq3 +""" + +from dataclasses import dataclass + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# TurboQuant dtypes recognized by the system +TURBOQUANT_DTYPES = ("tq2", "tq3", "pq4") + + +def is_turboquant_kv_cache(kv_cache_dtype: str) -> bool: + """Check if the KV cache dtype is a TurboQuant type.""" + return kv_cache_dtype in TURBOQUANT_DTYPES + + +@dataclass +class TurboQuantConfig: + """Configuration for TurboQuant KV cache quantization. + + This is not a weight quantization config (not a QuantizationConfig subclass) + — TurboQuant operates purely on the KV cache at inference time with no + model weight changes and no calibration data. + """ + + kv_cache_dtype: str + """One of 'tq2', 'tq3', 'pq4'.""" + + qjl_residual: bool + """Whether QJL residual correction is enabled (True for tq2/tq3).""" + + angle_bits: int + """Number of bits for quantized angles (2, 3, or 4).""" + + qjl_projection_dim: int | None + """Number of QJL sign-bit projections. Defaults to head_dim.""" + + rotation_seed: int + """Base seed for per-layer Hadamard rotation. 
Derived per-head.""" + + @classmethod + def from_kv_cache_dtype( + cls, + kv_cache_dtype: str, + rotation_seed: int = 42, + qjl_projection_dim: int | None = None, + ) -> "TurboQuantConfig": + """Create config from a KV cache dtype string.""" + if kv_cache_dtype == "pq4": + return cls( + kv_cache_dtype=kv_cache_dtype, + qjl_residual=False, + angle_bits=4, + qjl_projection_dim=None, + rotation_seed=rotation_seed, + ) + elif kv_cache_dtype == "tq3": + return cls( + kv_cache_dtype=kv_cache_dtype, + qjl_residual=True, + angle_bits=3, + qjl_projection_dim=qjl_projection_dim, + rotation_seed=rotation_seed, + ) + elif kv_cache_dtype == "tq2": + return cls( + kv_cache_dtype=kv_cache_dtype, + qjl_residual=True, + angle_bits=2, + qjl_projection_dim=qjl_projection_dim, + rotation_seed=rotation_seed, + ) + else: + raise ValueError( + f"Unknown TurboQuant dtype: {kv_cache_dtype}. " + f"Supported: {TURBOQUANT_DTYPES}" + ) + + def effective_bits_per_element(self, head_size: int) -> float: + """Calculate effective bits per KV element including all overhead. 
+ + For a head of dimension d: + - angles: (d-1) * angle_bits + - radius: 16 bits (fp16) + - QJL: d bits (1 bit per projection, proj_dim defaults to d) + - Total per head per token: (d-1)*angle_bits + 16 + d*qjl + - Per element: total / d + """ + d = head_size + qjl_dim = self.qjl_projection_dim or d + angle_total = (d - 1) * self.angle_bits + radius_total = 16 # fp16 + qjl_total = qjl_dim if self.qjl_residual else 0 + total_bits = angle_total + radius_total + qjl_total + return total_bits / d + + def _padded_angle_bytes(self, head_size: int) -> int: + """Angle bytes padded to next even number for fp16 alignment.""" + raw = ((head_size - 1) * self.angle_bits + 7) // 8 + return (raw + 1) & ~1 # Round up to even + + def bytes_per_token_per_head(self, head_size: int) -> int: + """Calculate storage bytes per token per KV head.""" + d = head_size + qjl_dim = self.qjl_projection_dim or d + angle_bytes = self._padded_angle_bytes(d) + radius_bytes = 2 # fp16 + qjl_bytes = (qjl_dim + 7) // 8 if self.qjl_residual else 0 + residual_norm_bytes = 2 if self.qjl_residual else 0 # fp16 + return angle_bytes + radius_bytes + qjl_bytes + residual_norm_bytes + + def block_bytes( + self, num_kv_heads: int, head_size: int, block_size: int + ) -> int: + """Calculate total bytes per KV cache block.""" + return ( + num_kv_heads + * block_size + * self.bytes_per_token_per_head(head_size) + ) + + def derive_layer_seed(self, layer_idx: int) -> int: + """Derive a per-layer rotation seed from the base seed.""" + # Use golden ratio hash for good distribution + return self.rotation_seed ^ (layer_idx * 2654435761 & 0xFFFFFFFF) diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index bd9741024f2a..60a61eacc5f9 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -39,6 +39,9 @@ "int8": torch.int8, "fp8_inc": torch.float8_e4m3fn, "fp8_ds_mla": torch.uint8, + "tq2": torch.uint8, + "tq3": torch.uint8, + "pq4": torch.uint8, } TORCH_DTYPE_TO_NUMPY_DTYPE = { diff --git 
a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index ec21b0fe9110..11e7e0fb2d40 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -955,7 +955,9 @@ def do_kv_cache_update( def is_quantized_kv_cache(kv_cache_dtype: str) -> bool: - return kv_cache_dtype.startswith("fp8") + return kv_cache_dtype.startswith("fp8") or kv_cache_dtype in ( + "tq2", "tq3", "pq4" + ) def subclass_attention_backend( diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 48ecf6b9dc85..f08e60422c65 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -187,6 +187,52 @@ def real_page_size_bytes(self) -> int: ) +@dataclass(frozen=True, kw_only=True) +class TurboQuantAttentionSpec(FullAttentionSpec): + """KV cache spec for TurboQuant (PolarQuant + QJL) quantized cache. + + Overrides page size calculation to account for the compressed polar + coordinate representation instead of raw element storage. + """ + + tq_type: str = "tq3" # "tq2", "tq3", or "pq4" + qjl_proj_dim: int | None = None # defaults to head_size + + @property + def real_page_size_bytes(self) -> int: + """Page size for TurboQuant compressed KV cache. + + Each token per KV head stores: + - (head_size - 1) quantized angles (padded to even bytes) + - 1 fp16 radius (2 bytes) + - ceil(qjl_proj_dim / 8) QJL sign bytes (if tq2/tq3) + - 1 fp16 residual norm (2 bytes, if tq2/tq3) + Multiply by 2 for both K and V caches. 
+ """ + angle_bits_map = {"pq4": 4, "tq3": 3, "tq2": 2} + has_qjl = self.tq_type.startswith("tq") + bits = angle_bits_map[self.tq_type] + qjl_dim = self.qjl_proj_dim or self.head_size + + # Pad angle bytes to even for fp16 alignment + raw_angle_bytes = ((self.head_size - 1) * bits + 7) // 8 + angle_bytes = (raw_angle_bytes + 1) & ~1 + radius_bytes = 2 # fp16 + qjl_bytes = (qjl_dim + 7) // 8 if has_qjl else 0 + residual_norm_bytes = 2 if has_qjl else 0 # fp16 + bytes_per_token_per_head = ( + angle_bytes + radius_bytes + qjl_bytes + residual_norm_bytes + ) + + # Factor of 2 for both key and value caches + return ( + 2 + * self.block_size + * self.num_kv_heads + * bytes_per_token_per_head + ) + + @dataclass(frozen=True, kw_only=True) class MLAAttentionSpec(FullAttentionSpec): # TODO(Lucas/Chen): less hacky way to do this