Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/engine/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

static constexpr int64_t GB = int64_t(1024) * 1024 * 1024;

DEFINE_int32(block_size, 256, "slots per block, value must be multiple of 256");
DEFINE_int32(block_size, 16, "slots per block, value must be multiple of 16");
DEFINE_int64(max_cache_size, 10 * GB, "max cache size in bytes, default 10GB");
DEFINE_double(max_memory_utilization,
0.9,
Expand Down Expand Up @@ -294,8 +294,6 @@ bool Engine::init_kv_cache(int64_t cache_size_in_bytes) {
CHECK_GT(cache_size_in_bytes, 0);
LOG(INFO) << "Initializing kv cache with size: "
<< readable_size(cache_size_in_bytes);
CHECK(FLAGS_block_size % 256 == 0)
<< "cache block size must be divisible by 256";

const int64_t block_size = FLAGS_block_size;

Expand Down
2 changes: 1 addition & 1 deletion src/kernels/flash_attn/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ mha_varlen_fwd(at::Tensor& out, // [n_tokens, n_heads, head_dim]
const int n_blocks = !paged_KV ? 0 : k.size(0);
const int block_size = !paged_KV ? 1 : k.size(1);
// TODO: support smaller block sizes
TORCH_CHECK(!paged_KV || block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
TORCH_CHECK(!paged_KV || block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");

const int batch_size = cu_seqlens_q.numel() - 1;
// [n_tokens, n_heads, head_dim]
Expand Down
136 changes: 76 additions & 60 deletions src/kernels/flash_attn/src/flash_fwd_kernel.h

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions src/kernels/flash_attn/src/kernel_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,18 @@ struct Flash_fwd_kernel_traits : public Base {
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read

// from how many rows does each thread have to fetch
static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow);
// Here we assign a contiguous tile to each thread, rather than a 1x8 row every
// (kNThreads / kGmemThreadsPerRow) rows, ensuring that the elements assigned to each thread
// do not cross a page boundary. This way, each thread need only fetch 1 page index per
// mainloop iteration. Rudimentary testing shows no slowdown.
using GmemTiledCopyQKVPaged = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));

using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{},
Expand All @@ -152,6 +164,14 @@ struct Flash_fwd_kernel_traits : public Base {
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
using GmemTiledCopyRotcossinPaged = decltype(
make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape<Int<kGmemRowsPerThread>, _4>, Stride<_4, _1>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinContPaged = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load
};

////////////////////////////////////////////////////////////////////////////////////////////////////
47 changes: 47 additions & 0 deletions src/kernels/flash_attn/src/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -379,4 +379,51 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S

////////////////////////////////////////////////////////////////////////////////////////////////////

// Resolves the initial base offset (in elements) of this thread's slice of a
// paged KV-cache copy from gmem. Assumes the tensor has already been
// positioned at the correct head, so only the page, row, and column
// components of the offset remain to be computed here.
//
// tidx:            thread index within the thread block
// n_block_max:     number of kBlockN-sized KV tiles; the offset is anchored at
//                  the last tile (n_block_max - 1) -- presumably because the
//                  mainloop walks KV tiles from high to low; TODO confirm
// page_block_size: rows (tokens) per KV-cache page
// block_table:     maps a logical (virtual) page index to a physical page index
// page_stride:     elements between consecutive physical pages
// row_stride:      elements between consecutive rows within a page
template <typename Kernel_traits>
__forceinline__ __device__
int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
const int* block_table, const int page_stride, const int row_stride) {
constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
constexpr int kBlockN = Kernel_traits::kBlockN;

// Column within a row: this thread's lane among the kGmemThreadsPerRow lanes,
// each lane covering kGmemElemsPerLoad contiguous elements.
const int col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
// Each thread owns a contiguous run of kGmemRowsPerThread rows of the tile,
// so a thread's rows never straddle a page boundary.
const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
// Absolute row in the KV sequence, anchored at the last KV tile.
const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN;
// Split the absolute row into (logical page index, row within that page).
const int page_offset = global_row_offset % page_block_size;
const int virtual_page_idx = global_row_offset / page_block_size;

// Physical page base (via the block table) + row offset + column offset.
// NOTE(review): this is 32-bit arithmetic; for very large caches the product
// block_table[...] * page_stride could overflow int -- confirm strides fit.
return block_table[virtual_page_idx] * page_stride
+ page_offset * row_stride
+ col_offset;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Reshapes a thread-tile layout with modes ((v1, v2), m, k) into (v1, v2, k),
// where v2 may itself be a tuple (as with swizzled smem-backed thread tiles).
// This keeps paged and non-paged copies producing equivalently shaped -- though
// not necessarily equivalently strided -- tensors.
template <class Shape, class Stride>
__forceinline__ __device__
auto reshape_thread_tile(Layout<Shape, Stride> l) {
auto const new_shape = append(get<0>(l.shape()), get<2>(l.shape()));
auto const new_stride = append(get<0>(l.stride()), get<2>(l.stride()));
return make_layout(new_shape, new_stride);
}

// Reshapes and flattens the thread-tile layout. This variant handles the case
// where one of the modes of l is itself a layout and must be flattened, as
// opposed to being kept intact as reshape_thread_tile does for swizzled layouts.
template <class Shape, class Stride>
__forceinline__ __device__
auto reshape_flatten_thread_tile(Layout<Shape, Stride> l) {
auto flat_mode0 = filter(flatten(get<0>(l)));
auto const new_shape = append(flat_mode0.shape(), get<2>(l.shape()));
auto const new_stride = append(flat_mode0.stride(), get<2>(l.stride()));
return make_layout(new_shape, new_stride);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace flash
11 changes: 4 additions & 7 deletions src/layers/attention/attention_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,7 @@ class AttentionDecodeTest
int64_t /*n_kv_heads*/,
int64_t /*head_dim*/,
float /*scale*/,
bool /*alibi*/,
int32_t /*num_splits*/>> {};
bool /*alibi*/>> {};

TEST_P(AttentionDecodeTest, KVCache) {
const auto& [device,
Expand All @@ -180,8 +179,7 @@ TEST_P(AttentionDecodeTest, KVCache) {
n_kv_heads,
head_dim,
scale,
alibi,
num_splits] = GetParam();
alibi] = GetParam();
// make sure kv_max_seq_len >= q_max_seq_len
if (kv_max_seq_len < q_max_seq_len) {
GTEST_SKIP() << "kv_max_seq_len < q_max_seq_len";
Expand Down Expand Up @@ -325,15 +323,14 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(torch::kCUDA),
::testing::Values(torch::kHalf, torch::kBFloat16),
::testing::Values(1, 10), // batch_size
::testing::Values(256), // block_size
::testing::Values(16, 80, 256), // block_size
::testing::Values(1, 10), // q_max_seq_len
::testing::Values(100, 200), // kv_max_seq_len
::testing::Values(6), // n_heads
::testing::Values(6 /*mha*/, 3 /*gqa*/, 1 /*mqa*/), // n_kv_heads
::testing::Values(32, 40, 64, 128), // head_dim
::testing::Values(0.9, 1.0), // scale
::testing::Values(false, true), // alibi
::testing::Values(1) // num_splits
::testing::Values(false, true) // alibi
));

} // namespace llm
2 changes: 1 addition & 1 deletion src/request/sequence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ std::vector<int32_t> Sequence::kv_cache_slots(int32_t pos_start,
}

void Sequence::commit_kv_cache(size_t size) {
CHECK(kv_cache_pos_ + size < kv_cache_capacity());
CHECK(kv_cache_pos_ + size <= kv_cache_capacity());
kv_cache_pos_ += size;
}

Expand Down