Merged
csrc/cache_kernels.cu: 73 changes (42 additions, 31 deletions)
@@ -1,3 +1,4 @@
+#include <curand_mtgp32_kernel.h>
 #include <torch/all.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -209,6 +210,20 @@ void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,

 namespace vllm {
 
+// Used to copy/convert one element
+template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
+struct CopyWithScaleOp {
+  float scale;
+
+  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      dst = static_cast<OutT>(src);
+    } else {
+      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
+    }
+  }
+};
+
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_kernel(
     const scalar_t* __restrict__ key,  // [num_tokens, num_heads, head_size]
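The CopyWithScaleOp functor added above folds the FP8-vs-passthrough decision into a compile-time branch, so a caller applies it one element at a time without re-testing the cache dtype. A minimal usage sketch (illustrative only; op, dst, src, and the E4M3 choice are assumptions, not part of the diff):

// Hypothetical caller, assuming an FP8 E4M3 cache where cache_t == uint8_t.
CopyWithScaleOp<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3> op{scale};
for (int j = 0; j < x; ++j) {
  op(dst[j], src[j]);  // fp8::scaled_convert per element; a kAuto instantiation collapses to a plain cast
}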
@@ -232,51 +247,47 @@ __global__ void reshape_and_cache_kernel(
   const int64_t block_offset = slot_idx % block_size;
 
   const int n = num_heads * head_size;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int64_t src_key_idx = token_idx * key_stride + i;
-    const int64_t src_value_idx = token_idx * value_stride + i;
 
-    const int head_idx = i / head_size;
-    const int head_offset = i % head_size;
+  float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
+  float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
+
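// Annotation (not part of the diff): *k_scale and *v_scale are dereferenced
// once per thread above and captured by value in the functors, instead of
// being re-read from global memory for every element as in the removed
// fp8::scaled_convert calls further down.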
+  for (int i = threadIdx.x; i < (n / x); i += blockDim.x) {
+    const int head_idx = (i * x) / head_size;
+    const int head_offset = (i * x) % head_size;
     const int x_idx = head_offset / x;
-    const int x_offset = head_offset % x;
 
     const int64_t tgt_key_idx =
         block_idx * num_heads * (head_size / x) * block_size * x +
         head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
-        block_offset * x + x_offset;
+        block_offset * x;
+
+    cache_t* __restrict__ key_dst = key_cache + tgt_key_idx;
+    const scalar_t* __restrict__ key_src = key + token_idx * key_stride + i * x;
+
+#pragma unroll
+    for (int j = 0; j < x; j++) {
+      scalar_t src = key_src[j];
+      k_op(key_dst[j], src);
+    }
+  }
+
+  for (int i = threadIdx.x; i < n; i += blockDim.x) {
+    const int64_t src_value_idx = token_idx * value_stride + i;
+
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+
     const int64_t tgt_value_idx =
         block_idx * num_heads * head_size * block_size +
         head_idx * head_size * block_size + head_offset * block_size +
         block_offset;
-    scalar_t tgt_key = key[src_key_idx];
     scalar_t tgt_value = value[src_value_idx];
-    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      key_cache[tgt_key_idx] = tgt_key;
-      value_cache[tgt_value_idx] = tgt_value;
-    } else {
-      key_cache[tgt_key_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
-      value_cache[tgt_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
-    }
+    v_op(value_cache[tgt_value_idx], tgt_value);
   }
 }
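The tgt_key_idx arithmetic in the new key loop linearizes the key cache as [num_blocks, num_heads, head_size / x, block_size, x]. A worked example with assumed sizes (not taken from the diff) of num_heads = 8, head_size = 128, block_size = 16, x = 8:

// Illustrative strides only; the concrete sizes are assumptions.
// head_size / x = 16, so measured in cache_t elements:
//   block stride:        8 * 16 * 16 * 8 = 16384
//   head stride:             16 * 16 * 8 =  2048
//   x_idx stride:                 16 * 8 =   128
//   block_offset stride:                        8
// Each iteration of the unrolled inner loop therefore writes x = 8
// consecutive cache_t values starting at tgt_key_idx.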

-// Used by vectorization_utils to copy/convert one element
-template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
-struct CopyWithScaleOp {
-  float scale;
-
-  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
-    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      dst = static_cast<OutT>(src);
-    } else {
-      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
-    }
-  }
-};
-
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
     const scalar_t* __restrict__ key,  // [num_tokens, num_heads, head_size]