Drop RoPE when filling KV cache (#3346)
Summary:
X-link: facebookresearch/FBGEMM#488

This PR provides CUDA kernels to fill the KV cache without applying RoPE.
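
For reference, a minimal host-side sketch of how the new varseq-prefill op might be invoked is included below. It is not part of this commit: the wrapper function, the concrete shapes, and the page_size placeholder are illustrative assumptions, and the authoritative signatures are the declarations and schema strings in the diff.

// Hypothetical usage sketch (assumes fbgemm_gpu is built and linked).
// Shapes follow the kernel accessors: XQ [B_T][N_H][D_H], XK/XV [B_T][N_KVH][D_H],
// cache_K/cache_V [B][MAX_T][N_KVH][D_H]; only BF16 caches are supported.
#include <ATen/ATen.h>
#include <optional>

namespace fbgemm_gpu {
// Declaration copied from kv_cache.cpp in this diff.
at::Tensor nope_qkv_varseq_prefill(
    at::Tensor XQ,
    at::Tensor XK,
    at::Tensor XV,
    at::Tensor cache_K,
    at::Tensor cache_V,
    at::Tensor varseq_batch,
    at::Tensor varseq_seqpos,
    std::optional<at::Tensor> block_tables,
    int64_t page_size,
    std::optional<at::Tensor> varseq_cache_seqpos);
} // namespace fbgemm_gpu

at::Tensor fill_kv_cache_no_rope_example() {
  const int64_t B = 2, T = 8, B_T = B * T, N_H = 4, N_KVH = 1, D_H = 128,
                MAX_T = 64;
  auto bf16 = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
  auto i32 = at::TensorOptions().dtype(at::kInt).device(at::kCUDA);

  auto XQ = at::randn({B_T, N_H, D_H}, bf16);
  auto XK = at::randn({B_T, N_KVH, D_H}, bf16);
  auto XV = at::randn({B_T, N_KVH, D_H}, bf16);
  auto cache_K = at::zeros({B, MAX_T, N_KVH, D_H}, bf16);
  auto cache_V = at::zeros({B, MAX_T, N_KVH, D_H}, bf16);

  // Per-token batch index ([0,...,0,1,...,1]) and position within each sequence.
  auto varseq_batch = at::arange(B, i32).repeat_interleave(T);
  auto varseq_seqpos = at::arange(T, i32).repeat({B});

  // K/V rows are written into cache_K/cache_V in place; a copy of XQ is returned
  // unmodified because no rotary embedding is applied. page_size only matters
  // when block_tables is given (paged attention); 64 is a placeholder here.
  return fbgemm_gpu::nope_qkv_varseq_prefill(
      XQ, XK, XV, cache_K, cache_V, varseq_batch, varseq_seqpos,
      /*block_tables=*/std::nullopt, /*page_size=*/64,
      /*varseq_cache_seqpos=*/std::nullopt);
}

The decoding variant, nope_qkv_decoding, follows the same pattern with a per-batch seqpos tensor and the optional actual_batch_size/batch tensors used under CUDA graph capture.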


Reviewed By: jianyuh

Differential Revision: D66307820

Pulled By: GD06
GD06 authored and facebook-github-bot committed Nov 21, 2024
1 parent 8993811 commit 5cbab5a
Showing 2 changed files with 264 additions and 0 deletions.
32 changes: 32 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp
@@ -26,6 +26,31 @@ namespace fbgemm_gpu {
#define STRING_(s) #s
#define STRING(x) STRING_(x)

at::Tensor nope_qkv_varseq_prefill(
at::Tensor XQ,
at::Tensor XK,
at::Tensor XV,
at::Tensor cache_K,
at::Tensor cache_V,
at::Tensor varseq_batch,
at::Tensor varseq_seqpos,
std::optional<at::Tensor> block_tables,
int64_t page_size,
std::optional<at::Tensor> varseq_cache_seqpos);

at::Tensor nope_qkv_decoding(
at::Tensor XQ,
at::Tensor XK,
at::Tensor XV,
at::Tensor cache_K,
at::Tensor cache_V,
at::Tensor seqpos,
std::optional<at::Tensor> block_tables,
int64_t page_size,
std::optional<at::Tensor> actual_batch_size,
std::optional<at::Tensor> batch,
std::optional<at::Tensor> cache_seqpos);

at::Tensor rope_qkv_varseq_prefill(
at::Tensor XQ,
at::Tensor XK,
@@ -153,6 +178,13 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("rope_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
m.impl("rope_qkv_decoding", rope_qkv_decoding);
m.def(
"nope_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None) -> Tensor");
m.impl("nope_qkv_varseq_prefill", nope_qkv_varseq_prefill);
m.def("nope_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None) -> Tensor");
m.impl("nope_qkv_decoding", nope_qkv_decoding);
m.def("xpos_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, float gamma, float scale_base, float exponent_offset, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
m.impl("xpos_qkv_varseq_prefill", xpos_qkv_varseq_prefill);
232 changes: 232 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu
@@ -213,6 +213,112 @@ enum class QKV { Q, K, V };
DEVICE_INLINE void
quantize_fp8_kv(fx4 dst, uint8_t* dst_row_q, __half2* qparam = nullptr);

__global__ void nope_qkv_varseq_prefill_kernel(
at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
XQ, // [B_T][N_H][D_H]
at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
XK, // [B_T][N_KVH][D_H]
at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
XV, // [B_T][N_KVH][D_H]
at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
cache_K, // [B][MAX_T][N_KVH][D_H] or
// [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged attention
at::PackedTensorAccessor64<at::BFloat16, 4, at::RestrictPtrTraits>
cache_V, // [B][MAX_T][N_KVH][D_H] or
// [1][MAX_PAGES * PAGE_SIZE][N_KVH][D_H] for paged attention
at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
XQ_O, // [B_T][N_H][D]
int32_t* varseq_batch, // in decoding case we have T == 1 and so just pass
// nullptr
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> varseq_seqpos,
int32_t* block_tables, // [B][MAX_PAGES], maps logical pages to physical
// ones for paged attention
int32_t page_size,
int32_t block_tables_b_stride,
at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
varseq_cache_seqpos,
int64_t* actual_batch_size =
nullptr // When running in CUDA graph mode, the actual batch size
// can be smaller than block_tables.size(0). In this case,
// rows of block_tables beyond actual_batch_size are not
// initialized, and using them will cause undefined
// behavior. To prevent this, when actual_batch_size is
// provided, the kernel exits if the current batch index is
// greater than or equal to actual_batch_size.
) {
// Launch b_t_(sum(h)) warps.
auto b_t_hh = blockIdx.x * blockDim.y + threadIdx.y;
auto B_T = XQ.size(0);
auto N_KVH = XK.size(1);
auto N_H = XQ.size(1);
auto D_H = XQ.size(2);
auto HH = 2 * N_KVH + N_H;

auto hh = b_t_hh % HH;
auto b_t = b_t_hh / HH;
if (b_t >= B_T) {
return;
}
auto seqpos_t = varseq_seqpos[b_t];
if (seqpos_t == -1) {
return;
}
auto cache_loc_t = varseq_cache_seqpos[b_t];
auto b = varseq_batch ? varseq_batch[b_t] : b_t;

if (actual_batch_size != nullptr && b_t >= *actual_batch_size) {
return;
}

at::BFloat16* src_row;
at::BFloat16* dst_row;
auto h = 0;
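// hh selects the row: [0, N_H) is a Q head (copied through to XQ_O), the next
// N_KVH values are K heads written into cache_K, and the last N_KVH are V heads
// written into cache_V.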
if (hh < N_H) {
h = hh;
src_row = &XQ[b_t][h][0];
dst_row = &XQ_O[b_t][h][0];
} else if (hh < N_H + N_KVH) {
h = hh - N_H;
src_row = &XK[b_t][h][0];

get_dst_row(
&dst_row,
cache_K,
b,
h,
cache_loc_t,
page_size,
block_tables,
block_tables_b_stride);
} else {
h = hh - N_H - N_KVH;
src_row = &XV[b_t][h][0];
get_dst_row(
&dst_row,
cache_V,
b,
h,
cache_loc_t,
page_size,
block_tables,
block_tables_b_stride);
}

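// Each lane copies 4 contiguous BF16 values (one 8-byte uint2) per iteration;
// rows are written to the cache unmodified, since no rotary embedding is applied.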
for (int32_t head_id = 4 * threadIdx.x; head_id < D_H;
head_id += kThreadsPerWarp * 4) {
// assert D_H % 4 == 0;
// load 4 elements per thread in a warp.
if (head_id >= D_H) {
return;
}
bfx4 src;
*reinterpret_cast<uint2*>(&src) =
*reinterpret_cast<uint2*>(&src_row[head_id]);
*reinterpret_cast<uint2*>(&dst_row[head_id]) =
*reinterpret_cast<uint2*>(&src);
}
}

template <PositionEmbeddingMode Mode>
__global__ void rope_xpos_qkv_varseq_prefill_kernel(
at::PackedTensorAccessor32<at::BFloat16, 3, at::RestrictPtrTraits>
@@ -827,6 +933,132 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_(
}
}

at::Tensor nope_qkv_varseq_prefill(
at::Tensor XQ,
at::Tensor XK,
at::Tensor XV,
at::Tensor cache_K,
at::Tensor cache_V,
at::Tensor varseq_batch,
at::Tensor varseq_seqpos,
std::optional<at::Tensor> block_tables,
int64_t page_size,
std::optional<at::Tensor> varseq_cache_seqpos) {
auto B_T = XQ.size(0);
auto N_H = XQ.size(1);
auto N_KVH = XK.size(1);

TORCH_CHECK(XQ.size(2) % 4 == 0);
TORCH_CHECK(XQ.size(2) <= 512);

int32_t num_warps = B_T * (2 * N_KVH + N_H);
TORCH_CHECK(num_warps > 0);

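// One warp per (token, head) row; each block packs kWarpsPerBlock warps of
// kThreadsPerWarp lanes.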
dim3 threads(kThreadsPerWarp, kWarpsPerBlock);
dim3 blocks(cuda_calc_xblock_count(num_warps, kWarpsPerBlock));

TORCH_CHECK(varseq_batch.is_contiguous());
TORCH_CHECK(varseq_batch.numel() == B_T);
auto XQ_O = at::empty_like(XQ);

auto varseq_cache_seqpos_ = varseq_cache_seqpos.value_or(varseq_seqpos);

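// Paged attention: block_tables maps logical pages to physical cache pages.
// When absent, cache_K/cache_V are addressed directly as [B][MAX_T][N_KVH][D_H].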
int32_t* block_tables_ptr = nullptr;
int32_t block_tables_b_stride = 0;
if (block_tables.has_value()) {
block_tables_ptr = static_cast<int32_t*>(block_tables.value().data_ptr());
block_tables_b_stride = block_tables.value().stride(0);
}

// Current NOPE kernel only supports BF16
TORCH_CHECK(cache_K.dtype() == at::kBFloat16);

nope_qkv_varseq_prefill_kernel<<<
blocks,
threads,
0,
at::cuda::getCurrentCUDAStream()>>>(
XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
XQ_O.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
varseq_batch.data_ptr<int32_t>(),
varseq_seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
block_tables_ptr,
page_size,
block_tables_b_stride,
varseq_cache_seqpos_
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
nullptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return XQ_O;
}
at::Tensor nope_qkv_decoding(
at::Tensor XQ,
at::Tensor XK,
at::Tensor XV,
at::Tensor cache_K,
at::Tensor cache_V,
at::Tensor seqpos,
std::optional<at::Tensor> block_tables,
int64_t page_size,
std::optional<at::Tensor> actual_batch_size,
std::optional<at::Tensor> batch,
std::optional<at::Tensor> cache_seqpos) {
auto B = XQ.size(0);
auto N_H = XQ.size(1);
auto N_KVH = XK.size(1);
TORCH_CHECK(XQ.size(2) % 4 == 0);
int32_t num_warps = B * (2 * N_KVH + N_H);
TORCH_CHECK(num_warps > 0);
dim3 threads(kThreadsPerWarp, kWarpsPerBlock);
dim3 blocks(cuda_calc_xblock_count(num_warps, kWarpsPerBlock));
auto XQ_O = at::empty_like(XQ);
int32_t* block_tables_ptr = nullptr;
int32_t block_tables_b_stride = 0;
if (block_tables.has_value()) {
block_tables_ptr = static_cast<int32_t*>(block_tables.value().data_ptr());
block_tables_b_stride = block_tables.value().stride(0);
}
int64_t* actual_batch_size_ptr = nullptr;
if (actual_batch_size.has_value()) {
actual_batch_size_ptr =
static_cast<int64_t*>(actual_batch_size.value().data_ptr());
}
auto cache_seqpos_ = cache_seqpos.value_or(seqpos);
// Current NOPE kernel only supports BF16
TORCH_CHECK(cache_K.dtype() == at::kBFloat16);
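// Decoding reuses the varseq prefill kernel: each sequence contributes a single
// token, so B_T == B; `batch` (if provided) plays the role of varseq_batch, and
// actual_batch_size lets the kernel skip padded rows under CUDA graph capture.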
nope_qkv_varseq_prefill_kernel<<<
blocks,
threads,
0,
at::cuda::getCurrentCUDAStream()>>>(
XQ.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
XK.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
XV.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
cache_K.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
cache_V.packed_accessor64<at::BFloat16, 4, at::RestrictPtrTraits>(),
XQ_O.packed_accessor32<at::BFloat16, 3, at::RestrictPtrTraits>(),
batch.has_value() ? batch.value().data_ptr<int32_t>() : nullptr,
seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
block_tables_ptr,
page_size,
block_tables_b_stride,
cache_seqpos_.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
actual_batch_size_ptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return XQ_O;
}
at::Tensor rope_qkv_varseq_prefill(
at::Tensor XQ,
at::Tensor XK,
