diff --git a/flashinfer/page.py b/flashinfer/page.py index 22b6c17666..12ea36137f 100644 --- a/flashinfer/page.py +++ b/flashinfer/page.py @@ -178,12 +178,13 @@ def get_batch_indices_positions( dtype = torch.int32 if batch_indices is None: - batch_indices = torch.empty((nnz,), device=device, dtype=dtype) + batch_indices = torch.full((nnz,), -1, device=device, dtype=dtype) else: check_shape_dtype_device(batch_indices, (nnz,), dtype, device, "batch_indices") + batch_indices.fill_(-1) if positions is None: - positions = torch.empty((nnz,), device=device, dtype=dtype) + positions = torch.zeros((nnz,), device=device, dtype=dtype) else: check_shape_dtype_device(positions, (nnz,), dtype, device, "positions") diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 4fdd75e0a3..614a34a4d8 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -859,169 +859,175 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel( const uint32_t idx = bx * bdy + ty; const RoPEIdType pos = pos_ids[idx]; - // Compute page location for this token - uint32_t page_iter, entry_idx; - paged_kv_like.page_size.divmod( - paged_kv_like.indptr[batch_indices[idx]] * paged_kv_like.page_size + positions[idx], - page_iter, entry_idx); - - const int half_rope_dim = rope_dim / 2; - // Load cos/sin for RoPE processing blocks only - if ((tx * vec_size < rope_dim) && (by < k_rope_end)) { - int sin_offset = rope_dim / 2; - int vec_idx; - if constexpr (interleave) { - vec_idx = (tx * vec_size) / 2; // Force integer division - } else { - vec_idx = (tx * vec_size) % half_rope_dim; + // skip padding tokens with batch_indices < 0 + if (batch_indices[idx] >= 0) { + // Compute page location for this token + uint32_t page_iter, entry_idx; + paged_kv_like.page_size.divmod( + paged_kv_like.indptr[batch_indices[idx]] * paged_kv_like.page_size + positions[idx], + page_iter, entry_idx); + + const int half_rope_dim = rope_dim / 2; + // Load cos/sin for RoPE processing blocks only + if ((tx * vec_size < rope_dim) && (by < k_rope_end)) { + int sin_offset = rope_dim / 2; + int vec_idx; + if constexpr (interleave) { + vec_idx = (tx * vec_size) / 2; // Force integer division + } else { + vec_idx = (tx * vec_size) % half_rope_dim; + } + cos.load(cos_sin_cache + (pos * rope_dim) + vec_idx); + sin.load(cos_sin_cache + (pos * rope_dim) + (sin_offset + vec_idx)); } - cos.load(cos_sin_cache + (pos * rope_dim) + vec_idx); - sin.load(cos_sin_cache + (pos * rope_dim) + (sin_offset + vec_idx)); - } - if (by < q_rope_end) { - // ============ Q RoPE processing ============ - uint32_t q_head_idx = by / rope_chunks; - uint32_t rope_chunk_idx = by % rope_chunks; - uint32_t elem_offset = rope_chunk_idx * rope_chunk_size; + if (by < q_rope_end) { + // ============ Q RoPE processing ============ + uint32_t q_head_idx = by / rope_chunks; + uint32_t rope_chunk_idx = by % rope_chunks; + uint32_t elem_offset = rope_chunk_idx * rope_chunk_size; - DType* q_rope_in_ptr = - q_rope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_in_stride_n, - q_rope_in_stride_h); - QuantType* q_rope_out_ptr = - q_rope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_out_stride_n, - q_rope_out_stride_h); + DType* q_rope_in_ptr = + q_rope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_in_stride_n, + q_rope_in_stride_h); + QuantType* q_rope_out_ptr = + q_rope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_out_stride_n, + q_rope_out_stride_h); - vec_t q_rope_vec; - if constexpr (interleave) { - q_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half( - q_rope_in_ptr, cos, sin, rope_dim); - } else { - q_rope_vec = vec_apply_llama_rope_cos_sin(q_rope_in_ptr, cos, sin, rope_dim); - } + vec_t q_rope_vec; + if constexpr (interleave) { + q_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half( + q_rope_in_ptr, cos, sin, rope_dim); + } else { + q_rope_vec = + vec_apply_llama_rope_cos_sin(q_rope_in_ptr, cos, sin, rope_dim); + } #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - q_rope_vec[i] = q_rope_vec[i] * quant_scale_q; - } - q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size); - - } else if (by < k_rope_end) { - // ============ K RoPE processing & Cache Append ============ - uint32_t k_head_idx = (by - q_rope_end) / rope_chunks; - uint32_t rope_chunk_idx = (by - q_rope_end) % rope_chunks; - uint32_t elem_offset = rope_chunk_idx * rope_chunk_size; - - DType* k_rope_in_ptr; - if constexpr (IS_MLA) { - // MLA: 2D K - k_rope_in_ptr = k_rope_in + idx * k_rope_in_stride + elem_offset; - } else { - // GQA/MHA: 3D K - k_rope_in_ptr = k_rope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset, - k_rope_in_stride, k_rope_in_stride_h); - } + for (uint32_t i = 0; i < vec_size; ++i) { + q_rope_vec[i] = q_rope_vec[i] * quant_scale_q; + } + q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size); + + } else if (by < k_rope_end) { + // ============ K RoPE processing & Cache Append ============ + uint32_t k_head_idx = (by - q_rope_end) / rope_chunks; + uint32_t rope_chunk_idx = (by - q_rope_end) % rope_chunks; + uint32_t elem_offset = rope_chunk_idx * rope_chunk_size; + + DType* k_rope_in_ptr; + if constexpr (IS_MLA) { + // MLA: 2D K + k_rope_in_ptr = k_rope_in + idx * k_rope_in_stride + elem_offset; + } else { + // GQA/MHA: 3D K + k_rope_in_ptr = k_rope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset, + k_rope_in_stride, k_rope_in_stride_h); + } - vec_t k_rope_vec; - if constexpr (interleave) { - k_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half( - k_rope_in_ptr, cos, sin, rope_dim); - } else { - k_rope_vec = vec_apply_llama_rope_cos_sin(k_rope_in_ptr, cos, sin, rope_dim); - } + vec_t k_rope_vec; + if constexpr (interleave) { + k_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half( + k_rope_in_ptr, cos, sin, rope_dim); + } else { + k_rope_vec = + vec_apply_llama_rope_cos_sin(k_rope_in_ptr, cos, sin, rope_dim); + } #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv; - } + for (uint32_t i = 0; i < vec_size; ++i) { + k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv; + } - if constexpr (IS_MLA) { - QuantType* kpe_ptr = - paged_kv_like.get_kpe_ptr(page_iter, entry_idx, elem_offset + tx * vec_size); - k_rope_vec.cast_store(kpe_ptr); - } else { - QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, tx * vec_size); - k_rope_vec.cast_store(k_ptr); - } + if constexpr (IS_MLA) { + QuantType* kpe_ptr = + paged_kv_like.get_kpe_ptr(page_iter, entry_idx, elem_offset + tx * vec_size); + k_rope_vec.cast_store(kpe_ptr); + } else { + QuantType* k_ptr = + paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, tx * vec_size); + k_rope_vec.cast_store(k_ptr); + } - } else if (by < k_nope_end) { - // ============ K Non-RoPE processing & Cache Append ============ - uint32_t k_head_idx = (by - k_rope_end) / no_rope_chunks; - uint32_t nope_chunk_idx = (by - k_rope_end) % no_rope_chunks; - uint32_t elem_offset = nope_chunk_idx * rope_chunk_size; + } else if (by < k_nope_end) { + // ============ K Non-RoPE processing & Cache Append ============ + uint32_t k_head_idx = (by - k_rope_end) / no_rope_chunks; + uint32_t nope_chunk_idx = (by - k_rope_end) % no_rope_chunks; + uint32_t elem_offset = nope_chunk_idx * rope_chunk_size; - DType* k_nope_in_ptr; - if constexpr (IS_MLA) { - k_nope_in_ptr = k_nope_in + idx * k_nope_in_stride + elem_offset; - } else { - k_nope_in_ptr = k_nope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset, - k_nope_in_stride, k_nope_in_stride_h); - } + DType* k_nope_in_ptr; + if constexpr (IS_MLA) { + k_nope_in_ptr = k_nope_in + idx * k_nope_in_stride + elem_offset; + } else { + k_nope_in_ptr = k_nope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset, + k_nope_in_stride, k_nope_in_stride_h); + } - vec_t k_nope_vec; - k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size); + vec_t k_nope_vec; + k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size); #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv; - } + for (uint32_t i = 0; i < vec_size; ++i) { + k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv; + } - if constexpr (IS_MLA) { - QuantType* ckv_ptr = - paged_kv_like.get_ckv_ptr(page_iter, entry_idx, elem_offset + tx * vec_size); - k_nope_vec.cast_store(ckv_ptr); - } else { - QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, - rope_dim + elem_offset + tx * vec_size); - k_nope_vec.cast_store(k_ptr); - } + if constexpr (IS_MLA) { + QuantType* ckv_ptr = + paged_kv_like.get_ckv_ptr(page_iter, entry_idx, elem_offset + tx * vec_size); + k_nope_vec.cast_store(ckv_ptr); + } else { + QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, + rope_dim + elem_offset + tx * vec_size); + k_nope_vec.cast_store(k_ptr); + } - } else if (by < k_nope_end + (IS_MLA ? 0u : num_kv_heads)) { - // ============ V processing & Cache Append (GQA/MHA only) ============ - if constexpr (!IS_MLA) { - uint32_t kv_head_idx = by - k_nope_end; - DType* v_in_ptr = - v_in + get_elem_offset_impl(idx, kv_head_idx, 0, v_in_stride, v_in_stride_h); - // Cover the full head dimension (rope_dim + no_rope_dim) in chunks of rope_chunk_size - uint32_t head_dim_total = rope_dim + no_rope_dim; - uint32_t v_chunks = (head_dim_total + rope_chunk_size - 1) / rope_chunk_size; + } else if (by < k_nope_end + (IS_MLA ? 0u : num_kv_heads)) { + // ============ V processing & Cache Append (GQA/MHA only) ============ + if constexpr (!IS_MLA) { + uint32_t kv_head_idx = by - k_nope_end; + DType* v_in_ptr = + v_in + get_elem_offset_impl(idx, kv_head_idx, 0, v_in_stride, v_in_stride_h); + // Cover the full head dimension (rope_dim + no_rope_dim) in chunks of rope_chunk_size + uint32_t head_dim_total = rope_dim + no_rope_dim; + uint32_t v_chunks = (head_dim_total + rope_chunk_size - 1) / rope_chunk_size; #pragma unroll 1 - for (uint32_t j = 0; j < v_chunks; ++j) { - uint32_t v_elem_offset = j * rope_chunk_size; - if (v_elem_offset + tx * vec_size < head_dim_total) { - vec_t v_vec; - v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size); + for (uint32_t j = 0; j < v_chunks; ++j) { + uint32_t v_elem_offset = j * rope_chunk_size; + if (v_elem_offset + tx * vec_size < head_dim_total) { + vec_t v_vec; + v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size); #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - v_vec[i] = v_vec[i] * quant_scale_kv; + for (uint32_t i = 0; i < vec_size; ++i) { + v_vec[i] = v_vec[i] * quant_scale_kv; + } + QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx, + v_elem_offset + tx * vec_size); + v_vec.cast_store(v_ptr); } - QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx, - v_elem_offset + tx * vec_size); - v_vec.cast_store(v_ptr); } } - } - } else { - // ============ Q Non-RoPE processing ============ - // MLA has no V section, so Q-nope starts immediately after K-nope. - // GQA/MHA has a V section of length num_kv_heads blocks. - uint32_t q_nope_start = k_nope_end + (IS_MLA ? 0u : num_kv_heads); - uint32_t q_head_idx = (by - q_nope_start) / no_rope_chunks; - uint32_t nope_chunk_idx = (by - q_nope_start) % no_rope_chunks; - uint32_t elem_offset = nope_chunk_idx * rope_chunk_size; - - DType* q_nope_in_ptr = - q_nope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_in_stride_n, - q_nope_in_stride_h); - QuantType* q_nope_out_ptr = - q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n, - q_nope_out_stride_h); - - vec_t q_nope_vec; - q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size); + } else { + // ============ Q Non-RoPE processing ============ + // MLA has no V section, so Q-nope starts immediately after K-nope. + // GQA/MHA has a V section of length num_kv_heads blocks. + uint32_t q_nope_start = k_nope_end + (IS_MLA ? 0u : num_kv_heads); + uint32_t q_head_idx = (by - q_nope_start) / no_rope_chunks; + uint32_t nope_chunk_idx = (by - q_nope_start) % no_rope_chunks; + uint32_t elem_offset = nope_chunk_idx * rope_chunk_size; + + DType* q_nope_in_ptr = + q_nope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_in_stride_n, + q_nope_in_stride_h); + QuantType* q_nope_out_ptr = + q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n, + q_nope_out_stride_h); + + vec_t q_nope_vec; + q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size); #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - q_nope_vec[i] = q_nope_vec[i] * quant_scale_q; + for (uint32_t i = 0; i < vec_size; ++i) { + q_nope_vec[i] = q_nope_vec[i] * quant_scale_q; + } + q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size); } - q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size); } } #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) diff --git a/tests/attention/test_rope.py b/tests/attention/test_rope.py index 75570223ae..05e58ed882 100644 --- a/tests/attention/test_rope.py +++ b/tests/attention/test_rope.py @@ -1380,6 +1380,267 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( ) +@pytest.mark.parametrize( + "attention_type,num_qo_heads,num_kv_heads,rope_dim,no_rope_dim", + [ + ("gqa", 64, 8, 128, 0), + ("mla", 128, 1, 64, 512), + ], +) +@pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) +@pytest.mark.parametrize("page_size", [16]) +def test_rope_quantize_fp8_append_paged_kv_cache_padding( + attention_type, + num_qo_heads, + num_kv_heads, + rope_dim, + no_rope_dim, + kv_layout, + page_size, +): + """Test that CUDA graph padding tokens (batch_indices=-1) do not corrupt + the KV cache. + + Simulates a decode batch with 3 real requests padded to 5 (as happens + with FULL CUDA graphs). get_batch_indices_positions fills padding entries + with batch_indices=-1. The fused kernel must skip those entries so that + previously-written KV cache data is preserved. + """ + device = "cuda:0" + torch.manual_seed(42) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(42) + + head_dim = rope_dim + no_rope_dim + input_dtype = torch.bfloat16 + quant_dtype = torch.float8_e4m3fn + + num_real_reqs = 3 + num_padded_reqs = 5 + num_real_tokens = num_real_reqs # 1 decode token per request + num_padded_tokens = num_padded_reqs + real_seq_lens = [30, 50, 70] + + max_seq_len = 256 + rope_ref = FlashInferRotaryEmbedding( + head_dim, rope_dim, max_seq_len, 10000, False, input_dtype, device + ) + + # --- Build paged KV metadata --- + pages_per_req = [(s + page_size - 1) // page_size for s in real_seq_lens] + total_real_pages = sum(pages_per_req) + max_pages = total_real_pages + 10 + + kv_indptr_list = [0] + for p in pages_per_req: + kv_indptr_list.append(kv_indptr_list[-1] + p) + for _ in range(num_padded_reqs - num_real_reqs): + kv_indptr_list.append(kv_indptr_list[-1]) + + kv_indptr = torch.tensor(kv_indptr_list, dtype=torch.int32, device=device) + kv_indices = torch.arange(max_pages, dtype=torch.int32, device=device) + + seq_lens = torch.zeros(num_padded_reqs, dtype=torch.int32, device=device) + for i, s in enumerate(real_seq_lens): + seq_lens[i] = s + + # qo_indptr: [0,1,2,3,3,3] — 3 real tokens, 2 padded (0 tokens each) + qo_indptr = torch.zeros(num_padded_reqs + 1, dtype=torch.int32, device=device) + for i in range(num_real_reqs): + qo_indptr[i + 1] = qo_indptr[i] + 1 + for i in range(num_real_reqs, num_padded_reqs): + qo_indptr[i + 1] = qo_indptr[i] + + # --- Allocate KV cache and pre-fill with "prior prefill" data --- + if attention_type == "mla": + ckv_cache = torch.randn( + max_pages, + page_size, + no_rope_dim, + dtype=input_dtype, + device=device, + ).to(quant_dtype) + kpe_cache = torch.randn( + max_pages, + page_size, + rope_dim, + dtype=input_dtype, + device=device, + ).to(quant_dtype) + ckv_cache_snapshot = ckv_cache.clone() + kpe_cache_snapshot = kpe_cache.clone() + paged_kv_cache = (ckv_cache, kpe_cache) + else: + if kv_layout == "NHD": + cache_shape = (max_pages, page_size, num_kv_heads, head_dim) + else: + cache_shape = (max_pages, num_kv_heads, page_size, head_dim) + k_cache = torch.randn(*cache_shape, dtype=input_dtype, device=device).to( + quant_dtype + ) + v_cache = torch.randn(*cache_shape, dtype=input_dtype, device=device).to( + quant_dtype + ) + k_cache_snapshot = k_cache.clone() + v_cache_snapshot = v_cache.clone() + paged_kv_cache = (k_cache, v_cache) + + # --- Build model tensors (padded size) --- + if attention_type == "mla": + q_rope = torch.randn( + num_padded_tokens, + num_qo_heads, + rope_dim, + dtype=input_dtype, + device=device, + ) + q_nope = ( + torch.randn( + num_padded_tokens, + num_qo_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) + if no_rope_dim > 0 + else None + ) + k_rope = torch.randn( + num_padded_tokens, rope_dim, dtype=input_dtype, device=device + ) + k_nope = ( + torch.randn( + num_padded_tokens, no_rope_dim, dtype=input_dtype, device=device + ) + if no_rope_dim > 0 + else None + ) + v_in = None + else: + q_rope = torch.randn( + num_padded_tokens, + num_qo_heads, + rope_dim, + dtype=input_dtype, + device=device, + ) + q_nope = None + k_rope = torch.randn( + num_padded_tokens, + num_kv_heads, + rope_dim, + dtype=input_dtype, + device=device, + ) + k_nope = None + v_in = torch.randn( + num_padded_tokens, + num_kv_heads, + head_dim, + dtype=input_dtype, + device=device, + ) + + pos_ids = torch.zeros(num_padded_tokens, dtype=torch.int32, device=device) + for i in range(num_real_reqs): + pos_ids[i] = real_seq_lens[i] - 1 + + # --- Run get_batch_indices_positions (should fill padding with -1) --- + batch_indices = torch.empty(num_padded_tokens, dtype=torch.int32, device=device) + paged_positions = torch.empty(num_padded_tokens, dtype=torch.int32, device=device) + # Pre-fill with stale data to simulate persistent CUDA graph buffers + batch_indices.fill_(0) + paged_positions.fill_(0) + + flashinfer.get_batch_indices_positions( + qo_indptr, + seq_lens, + num_padded_tokens, + batch_indices, + paged_positions, + ) + + # Verify padding entries are filled with -1 + assert batch_indices[num_real_tokens:].eq(-1).all(), ( + f"Expected batch_indices[{num_real_tokens}:] == -1, " + f"got {batch_indices[num_real_tokens:]}" + ) + + # --- Run the fused kernel --- + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope=q_rope, + k_rope=k_rope, + q_nope=q_nope, + k_nope=k_nope, + v=v_in, + cos_sin_cache=rope_ref.cos_sin_cache, + pos_ids=pos_ids, + paged_kv_cache=paged_kv_cache, + kv_indices=kv_indices, + kv_indptr=kv_indptr, + batch_indices=batch_indices, + positions=paged_positions, + page_size=page_size, + kv_layout=kv_layout, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + ) + + # --- Verify: prefill data at non-decode positions is preserved --- + # The decode writes to position seq_lens[i]-1 for each real request. + # All other prefill positions must remain unchanged. + for req_idx in range(num_real_reqs): + sl = real_seq_lens[req_idx] + for pos in [0, sl // 2]: # check position 0 and a mid-position + page_offset = pos // page_size + entry_offset = pos % page_size + page_iter = kv_indptr[req_idx].item() + page_offset + page_idx = kv_indices[page_iter].item() + + if attention_type == "mla": + torch.testing.assert_close( + ckv_cache[page_idx, entry_offset].float(), + ckv_cache_snapshot[page_idx, entry_offset].float(), + rtol=0, + atol=0, + msg=f"Req {req_idx} pos {pos}: CKV cache corrupted by padding", + ) + torch.testing.assert_close( + kpe_cache[page_idx, entry_offset].float(), + kpe_cache_snapshot[page_idx, entry_offset].float(), + rtol=0, + atol=0, + msg=f"Req {req_idx} pos {pos}: KPE cache corrupted by padding", + ) + else: + if kv_layout == "NHD": + k_entry = k_cache[page_idx, entry_offset] + k_snap = k_cache_snapshot[page_idx, entry_offset] + v_entry = v_cache[page_idx, entry_offset] + v_snap = v_cache_snapshot[page_idx, entry_offset] + else: + k_entry = k_cache[page_idx, :, entry_offset] + k_snap = k_cache_snapshot[page_idx, :, entry_offset] + v_entry = v_cache[page_idx, :, entry_offset] + v_snap = v_cache_snapshot[page_idx, :, entry_offset] + torch.testing.assert_close( + k_entry.float(), + k_snap.float(), + rtol=0, + atol=0, + msg=f"Req {req_idx} pos {pos}: K cache corrupted by padding", + ) + torch.testing.assert_close( + v_entry.float(), + v_snap.float(), + rtol=0, + atol=0, + msg=f"Req {req_idx} pos {pos}: V cache corrupted by padding", + ) + + @pytest.mark.parametrize("num_tokens", [1, 19, 128, 199, 899, 2047]) @pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])