diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000000..0dd9a2b0c5a5 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "csrc/flash_attn_v2/third_party/cutlass"] + path = csrc/flash_attn_v2/third_party/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/csrc/flash_attn_v2/paged_flash/LICENSE b/csrc/flash_attn_v2/paged_flash/LICENSE new file mode 100644 index 000000000000..5860e4b33f3d --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/csrc/flash_attn_v2/paged_flash/Makefile b/csrc/flash_attn_v2/paged_flash/Makefile new file mode 100644 index 000000000000..ab58a6bd2f10 --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/Makefile @@ -0,0 +1,181 @@ +# CMAKE generated file: DO NOT EDIT! +# Generated by "Unix Makefiles" Generator, CMake Version 3.25 + +# Default target executed when no arguments are given to make. +default_target: all +.PHONY : default_target + +# Allow only one "make -f Makefile2" at a time, but pass parallelism. +.NOTPARALLEL: + +#============================================================================= +# Special targets provided by cmake. + +# Disable implicit rules so canonical targets will work. +.SUFFIXES: + +# Disable VCS-based implicit rules. +% : %,v + +# Disable VCS-based implicit rules. +% : RCS/% + +# Disable VCS-based implicit rules. +% : RCS/%,v + +# Disable VCS-based implicit rules. +% : SCCS/s.% + +# Disable VCS-based implicit rules. +% : s.% + +.SUFFIXES: .hpux_make_needs_suffix_list + +# Command-line flag to silence nested $(MAKE). +$(VERBOSE)MAKESILENT = -s + +#Suppress display of executed commands. +$(VERBOSE).SILENT: + +# A target that is always out of date. +cmake_force: +.PHONY : cmake_force + +#============================================================================= +# Set environment variables for the build. + +# The shell in which to execute make rules. +SHELL = /bin/sh + +# The CMake executable. 
+CMAKE_COMMAND = /home/deepspeed/.local/lib/python3.8/site-packages/cmake/data/bin/cmake + +# The command to remove a file. +RM = /home/deepspeed/.local/lib/python3.8/site-packages/cmake/data/bin/cmake -E rm -f + +# Escaping for special characters. +EQUALS = = + +# The top-level source directory on which CMake was run. +CMAKE_SOURCE_DIR = /data/private_dev/DeepSpeed-Kernels/inf_flash_attn/blocked_flash + +# The top-level build directory on which CMake was run. +CMAKE_BINARY_DIR = /data/private_dev/DeepSpeed-Kernels/inf_flash_attn/blocked_flash + +#============================================================================= +# Targets provided globally by CMake. + +# Special rule for the target edit_cache +edit_cache: + @$(CMAKE_COMMAND) -E cmake_echo_color --switch=$(COLOR) --cyan "No interactive CMake dialog available..." + /home/deepspeed/.local/lib/python3.8/site-packages/cmake/data/bin/cmake -E echo No\ interactive\ CMake\ dialog\ available. +.PHONY : edit_cache + +# Special rule for the target edit_cache +edit_cache/fast: edit_cache +.PHONY : edit_cache/fast + +# Special rule for the target rebuild_cache +rebuild_cache: + @$(CMAKE_COMMAND) -E cmake_echo_color --switch=$(COLOR) --cyan "Running CMake to regenerate build system..." 
+ /home/deepspeed/.local/lib/python3.8/site-packages/cmake/data/bin/cmake --regenerate-during-build -S$(CMAKE_SOURCE_DIR) -B$(CMAKE_BINARY_DIR) +.PHONY : rebuild_cache + +# Special rule for the target rebuild_cache +rebuild_cache/fast: rebuild_cache +.PHONY : rebuild_cache/fast + +# The main all target +all: cmake_check_build_system + $(CMAKE_COMMAND) -E cmake_progress_start /data/private_dev/DeepSpeed-Kernels/inf_flash_attn/blocked_flash/CMakeFiles /data/private_dev/DeepSpeed-Kernels/inf_flash_attn/blocked_flash//CMakeFiles/progress.marks + $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 all + $(CMAKE_COMMAND) -E cmake_progress_start /data/private_dev/DeepSpeed-Kernels/inf_flash_attn/blocked_flash/CMakeFiles 0 +.PHONY : all + +# The main clean target +clean: + $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 clean +.PHONY : clean + +# The main clean target +clean/fast: clean +.PHONY : clean/fast + +# Prepare targets for installation. +preinstall: all + $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 preinstall +.PHONY : preinstall + +# Prepare targets for installation. +preinstall/fast: + $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 preinstall +.PHONY : preinstall/fast + +# clear depends +depend: + $(CMAKE_COMMAND) -S$(CMAKE_SOURCE_DIR) -B$(CMAKE_BINARY_DIR) --check-build-system CMakeFiles/Makefile.cmake 1 +.PHONY : depend + +#============================================================================= +# Target rules for targets named gemm + +# Build rule for target. +gemm: cmake_check_build_system + $(MAKE) $(MAKESILENT) -f CMakeFiles/Makefile2 gemm +.PHONY : gemm + +# fast build rule for target. 
+gemm/fast: + $(MAKE) $(MAKESILENT) -f CMakeFiles/gemm.dir/build.make CMakeFiles/gemm.dir/build +.PHONY : gemm/fast + +flash_fwd_hdim32_bf16_sm80.o: flash_fwd_hdim32_bf16_sm80.cu.o +.PHONY : flash_fwd_hdim32_bf16_sm80.o + +# target to build an object file +flash_fwd_hdim32_bf16_sm80.cu.o: + $(MAKE) $(MAKESILENT) -f CMakeFiles/gemm.dir/build.make CMakeFiles/gemm.dir/flash_fwd_hdim32_bf16_sm80.cu.o +.PHONY : flash_fwd_hdim32_bf16_sm80.cu.o + +flash_fwd_hdim32_bf16_sm80.i: flash_fwd_hdim32_bf16_sm80.cu.i +.PHONY : flash_fwd_hdim32_bf16_sm80.i + +# target to preprocess a source file +flash_fwd_hdim32_bf16_sm80.cu.i: + $(MAKE) $(MAKESILENT) -f CMakeFiles/gemm.dir/build.make CMakeFiles/gemm.dir/flash_fwd_hdim32_bf16_sm80.cu.i +.PHONY : flash_fwd_hdim32_bf16_sm80.cu.i + +flash_fwd_hdim32_bf16_sm80.s: flash_fwd_hdim32_bf16_sm80.cu.s +.PHONY : flash_fwd_hdim32_bf16_sm80.s + +# target to generate assembly for a file +flash_fwd_hdim32_bf16_sm80.cu.s: + $(MAKE) $(MAKESILENT) -f CMakeFiles/gemm.dir/build.make CMakeFiles/gemm.dir/flash_fwd_hdim32_bf16_sm80.cu.s +.PHONY : flash_fwd_hdim32_bf16_sm80.cu.s + +# Help Target +help: + @echo "The following are some of the valid targets for this Makefile:" + @echo "... all (the default if no target is provided)" + @echo "... clean" + @echo "... depend" + @echo "... edit_cache" + @echo "... rebuild_cache" + @echo "... gemm" + @echo "... flash_fwd_hdim32_bf16_sm80.o" + @echo "... flash_fwd_hdim32_bf16_sm80.i" + @echo "... flash_fwd_hdim32_bf16_sm80.s" +.PHONY : help + + + +#============================================================================= +# Special targets to cleanup operation of make. + +# Special rule to run CMake to check the build system integrity. +# No rule that depends on this can have commands that come from listfiles +# because they might be regenerated. 
+cmake_check_build_system: + $(CMAKE_COMMAND) -S$(CMAKE_SOURCE_DIR) -B$(CMAKE_BINARY_DIR) --check-build-system CMakeFiles/Makefile.cmake 0 +.PHONY : cmake_check_build_system + diff --git a/csrc/flash_attn_v2/paged_flash/attention_atom.h b/csrc/flash_attn_v2/paged_flash/attention_atom.h new file mode 100644 index 000000000000..4263433d79f0 --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/attention_atom.h @@ -0,0 +1,27 @@ + +#pragma once + +#include +#include "cuda.h" +#include "cute/pointer.hpp" + +struct __align__(32) AttentionAtom { + using index_t = uint32_t; + + index_t* block_idx_list; + + index_t q_start_idx; + index_t q_len; + index_t kv_blocks; + index_t total_extent; + index_t global_q_idx; + index_t unused; + + template + __device__ void load_kv_block_idxs(cute::smem_ptr block_idx_list_shr, int tidx) const + { + for (int i = tidx; i < kv_blocks; i += threads) { block_idx_list_shr[i] = block_idx_list[i]; } + // Aggressive (but safe) sync + __syncthreads(); + } +}; diff --git a/csrc/flash_attn_v2/paged_flash/flash.h b/csrc/flash_attn_v2/paged_flash/flash.h new file mode 100644 index 000000000000..85582b92e141 --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/flash.h @@ -0,0 +1,83 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +#include "attention_atom.h" + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = uint32_t; + // The QKV matrices. + void* __restrict__ q_ptr; + void* __restrict__ k_ptr; + void* __restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. 
+ index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + // The O matrix (output). + void* __restrict__ o_ptr; + + // The attention metadata + // AttentionAtom* __restrict__ atoms; + + // Total attention atoms + // int num_atoms; + + // PagedAttention metadata + int num_seqs; + int max_num_query; + int max_context_len; + int block_size; + int max_num_blocks_per_seq; + + index_t* __restrict__ block_tables; + index_t* __restrict__ context_lens; + index_t* __restrict__ draft_lens; + + // The stride between rows of O. + index_t o_row_stride; + index_t o_head_stride; + + // The dimensions + int d, d_rounded; + + // The scaling factors for the kernel. 
+ float scale_softmax; + float scale_softmax_log2; + + bool is_bf16; + bool is_causal; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream); diff --git a/csrc/flash_attn_v2/paged_flash/flash_api.cu b/csrc/flash_attn_v2/paged_flash/flash_api.cu new file mode 100644 index 000000000000..4ff348f5a76f --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/flash_api.cu @@ -0,0 +1,99 @@ +#include "flash.h" +#include "static_switch.h" +#include +#include +#include +#include "kernel_traits.h" +#include "flash_fwd_launch_template.h" + +// void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream) +// { +// FP16_SWITCH(!params.is_bf16, [&] { +// FWD_HEADDIM_SWITCH(params.d, [&] { run_mha_fwd_(params, stream); }); +// }); +// } + +#define M_LOG2E 1.4426950408889634074 // log_2 e + +// for now, assume that +// head_dim = 64 +// block_size = 16 +// num_heads = 12 +void paged_flash_attention( + torch::Tensor& out, // [num_seqs, max_num_query, num_heads, head_size] + torch::Tensor& query, // [num_seqs, max_num_query, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + int num_kv_heads, + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + torch::Tensor& draft_lens, // [num_seqs] + int block_size, + int max_context_len, + int max_num_query) { + size_t num_heads = query.size(2); + size_t head_size = query.size(3); + size_t num_seqs = query.size(0); + size_t max_num_blocks_per_seq = DIVIDE_ROUND_UP(max_context_len, block_size); + + TORCH_CHECK(num_heads == 12, "only 12 heads are supported"); + TORCH_CHECK(head_size == 64, "only head size of 64 is supported"); + TORCH_CHECK(block_size == 16, "only block size of 16 is supported"); + 
TORCH_CHECK(num_kv_heads == num_heads, "MQA is not supported"); + TORCH_CHECK(query.dtype() == at::ScalarType::Half, "only half is supported"); + + // create params + Flash_fwd_params params; + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + // Set the pointers and strides. + params.q_ptr = reinterpret_cast(query.data_ptr()); + params.k_ptr = reinterpret_cast(key_cache.data_ptr()); + params.v_ptr = reinterpret_cast(value_cache.data_ptr()); + + // Calculate batch_stride using cu_seq + params.q_row_stride = num_heads * head_size; + params.k_row_stride = num_kv_heads * head_size; + params.v_row_stride = num_kv_heads * head_size; + params.q_head_stride = head_size; + params.k_head_stride = head_size; + params.v_head_stride = head_size; + + params.h = num_heads; + params.h_k = num_kv_heads; + params.h_h_k_ratio = params.h / params.h_k; + + params.o_ptr = reinterpret_cast(out.data_ptr()); + + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + + // Set the dimensions. 
+ params.d = params.d_rounded = head_size; + + params.scale_softmax = 1.0 / std::sqrt(head_size); + params.scale_softmax_log2 = params.scale_softmax * M_LOG2E; + + params.is_bf16 = false; + params.is_causal = true; + + params.num_seqs = num_seqs; + params.max_num_query = max_num_query; + params.max_context_len = max_context_len; + params.block_size = block_size; + params.max_num_blocks_per_seq = max_num_blocks_per_seq; + + params.block_tables = reinterpret_cast(block_tables.data_ptr()); + params.context_lens = reinterpret_cast(context_lens.data_ptr()); + params.draft_lens = reinterpret_cast(draft_lens.data_ptr()); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + run_flash_fwd, /*Is_causal=*/true>( + params, stream); + + return; +} diff --git a/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim128_bf16_sm80.cu b/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim128_bf16_sm80.cu new file mode 100644 index 000000000000..878935480e07 --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim128_bf16_sm80.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// if (params.p_dropout == 1.f) { +// run_flash_fwd, +// false>(params, stream); +// } else { +// run_flash_fwd, +// true>(params, stream); +// } +// } +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) +{ + run_mha_fwd_hdim128(params, stream); +} diff --git a/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim128_fp16_sm80.cu b/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim128_fp16_sm80.cu new file mode 100644 index 000000000000..f3b9b12f529d --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim128_fp16_sm80.cu @@ -0,0 +1,39 @@ +// Copyright (c) 2023, Tri Dao. 
+ +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// if (params.p_dropout == 1.f) { +// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k +// run_flash_fwd, +// false>(params, stream); +// // run_flash_fwd, +// false>(params, stream); +// // run_flash_fwd, +// false>(params, stream); +// // run_flash_fwd, +// false>(params, stream); run_flash_fwd, false>(params, stream); run_flash_fwd, false>(params, stream); +// run_flash_fwd, +// false>(params, stream); +// // 1st ones are good for H100, A100 +// // 2nd one is good for A6000 bc we get slightly better occupancy +// } else { +// run_flash_fwd, +// true>(params, stream); run_flash_fwd, true>(params, stream); run_flash_fwd, true>(params, stream); +// // 1st one is good for H100, A100, A6000 +// } +// } + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) +{ + run_mha_fwd_hdim128(params, stream); +} diff --git a/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim256_bf16_sm80.cu b/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim256_bf16_sm80.cu new file mode 100644 index 000000000000..60a241e3f905 --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim256_bf16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) +{ + run_mha_fwd_hdim256(params, stream); +} diff --git a/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim256_fp16_sm80.cu b/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim256_fp16_sm80.cu new file mode 100644 index 000000000000..d6c45c1d525c --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim256_fp16_sm80.cu @@ -0,0 +1,11 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) +{ + run_mha_fwd_hdim256(params, stream); +} diff --git a/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim64_bf16_sm80.cu b/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim64_bf16_sm80.cu new file mode 100644 index 000000000000..f9bab5f26644 --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim64_bf16_sm80.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// if (params.p_dropout == 1.f) { +// run_flash_fwd, +// false>(params, stream); +// } else { +// run_flash_fwd, +// true>(params, stream); +// } +// } +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) +{ + run_mha_fwd_hdim64(params, stream); +} diff --git a/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim64_fp16_sm80.cu b/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim64_fp16_sm80.cu new file mode 100644 index 000000000000..098b5284fd94 --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim64_fp16_sm80.cu @@ -0,0 +1,29 @@ +// Copyright (c) 2023, Tri Dao. 
+ +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// if (params.p_dropout == 1.f) { +// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower +// // Using block size (64 x 256) is 27% slower for seqlen=2k +// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling +// run_flash_fwd, +// false>(params, stream); run_flash_fwd, false>(params, stream); run_flash_fwd, false>(params, stream); +// } else { +// run_flash_fwd, +// true>(params, stream); run_flash_fwd, true>(params, stream); run_flash_fwd, true>(params, stream); +// } +// } +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) +{ + run_mha_fwd_hdim64(params, stream); +} diff --git a/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim96_bf16_sm80.cu b/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim96_bf16_sm80.cu new file mode 100644 index 000000000000..df1543edadba --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim96_bf16_sm80.cu @@ -0,0 +1,19 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) { +// using elem_type = cutlass::bfloat16_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, +// Is_dropout>(params, stream); +// }); +// } +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) +{ + run_mha_fwd_hdim96(params, stream); +} diff --git a/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim96_fp16_sm80.cu b/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim96_fp16_sm80.cu new file mode 100644 index 000000000000..90ecacf140a6 --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/flash_fwd_hdim96_fp16_sm80.cu @@ -0,0 +1,29 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +// template<> +// void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) { +// using elem_type = cutlass::half_t; +// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { +// run_flash_fwd, +// Is_dropout>(params, stream); run_flash_fwd, Is_dropout>(params, stream); +// // This 3rd one is good for H100, and A100, A6000 +// run_flash_fwd, +// Is_dropout>(params, stream); run_flash_fwd, Is_dropout>(params, stream); +// // These two are always slower +// // run_flash_fwd>(params, +// stream); +// // run_flash_fwd>(params, +// stream); +// }); +// } +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) +{ + run_mha_fwd_hdim96(params, stream); +} diff --git a/csrc/flash_attn_v2/paged_flash/flash_fwd_kernel.h b/csrc/flash_attn_v2/paged_flash/flash_fwd_kernel.h new file mode 100644 index 000000000000..624bbdec0db4 --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/flash_fwd_kernel.h @@ -0,0 +1,476 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +#include "attention_atom.h" +#include "kernel_traits.h" +#include "softmax.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) { + using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; + constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; + constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; + constexpr int MMAStride_M = MMA_M * AtomShape_M; + auto t = make_tile(Layout, Int>, + Stride<_1, Int> >{}, + make_layout(size<2>(TileShape_MNK{}))); + // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); } + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) { + using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; + constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value; + constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M; + constexpr int MMAStride_M = MMA_M * AtomShape_M; + auto t = make_tile(Layout, Int>, + Stride<_1, Int> >{}, + // TODO: Shouldn't this be size<1>? 
+ make_layout(size<2>(TileShape_MNK{}))); + // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum, + Tensor2 &acc_o, float softmax_scale_log2) { + if (Is_first) { + flash::template reduce_max(scores, scores_max); + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + flash::reduce_sum(scores, scores_sum); + } else { + Tensor scores_max_prev = make_fragment_like(scores_max); + cute::copy(scores_max, scores_max_prev); + flash::template reduce_max(scores, scores_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + #pragma unroll + for (int mi = 0; mi < size(scores_max); ++mi) { + float scores_max_cur = !Check_inf + ? scores_max(mi) + : (scores_max(mi) == -INFINITY ? 
0.0f : scores_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scores_sum(mi) *= scores_scale; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } + } + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + Tensor scores_sum_cur = make_fragment_like(scores_sum); + flash::reduce_sum(scores, scores_sum_cur); + #pragma unroll + for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_1rowblock(const Params& params, + const AttentionAtom& atom_info, + const int head_idx) +{ + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_M = + kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; + + // May have a padded launch of attention atoms, this will handle that for us. + if (atom_info.q_len == 0) return; + + // stream the kv block idxs into shared memory + smem_ptr block_idx_list = make_smem_ptr(smem_ + Kernel_traits::kSmemFlashSize); + atom_info.load_kv_block_idxs(block_idx_list, tidx); + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). 
+ + const index_t row_offset_q = + atom_info.q_start_idx * params.q_row_stride + head_idx * params.q_head_stride; + + // We move K and V to the last block. + const index_t row_offset_k = + block_idx_list[atom_info.kv_blocks - 1] * kBlockN * params.k_row_stride + + (head_idx / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = + block_idx_list[atom_info.kv_blocks - 1] * kBlockN * params.v_row_stride + + (head_idx / params.h_h_k_ratio) * params.v_head_stride; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 
0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = + make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = + partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change 
the mma instruction in SM70 + Tensor scores_max = make_tensor(Shape(acc_o)>>{}); + Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor( + make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor( + make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } +#pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + Tensor tQrQ = make_fragment_like(tQgQ); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy( + gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, atom_info.q_len); + if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = atom_info.kv_blocks - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. 
+ flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, atom_info.total_extent - n_block * kBlockN); + cute::cp_async_fence(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + } + + clear(acc_o); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + constexpr int n_masking_steps = !Is_causal + ? 1 + : cute::ceil_div(kBlockM, kBlockN) + 1; +#pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C( + tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (block_idx_list[n_block] - block_idx_list[n_block + 1]) * + int(kBlockN * params.v_row_stride); + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, atom_info.total_extent - n_block * kBlockN); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K); + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + // We don't put the masking before the 
matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. + if (!Is_causal) { + flash::apply_mask(scores, atom_info.total_extent); + } else { + flash::apply_mask_causal(scores, + n_block * kBlockN, + atom_info.total_extent, + atom_info.global_q_idx + (tidx / 32) * 16 + (tidx % 32) / 4, + atom_info.total_extent, + kNWarps * 16); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > 0) { + // Advance gK + tKgK.data() = tKgK.data() + (block_idx_list[n_block - 1] - block_idx_list[n_block]) * + int(kBlockN * params.k_row_stride); + flash::copy( + gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 ? softmax_rescale_o( + scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o( + scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= 0) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= 0; --n_block) { + Tensor acc_s = partition_fragment_C( + tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (block_idx_list[n_block] - block_idx_list[n_block + 1]) * + int(kBlockN * params.v_row_stride); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > 0) { + // Advance gK + tKgK.data() = tKgK.data() + (block_idx_list[n_block - 1] - block_idx_list[n_block]) * + int(kBlockN * params.k_row_stride); + flash::copy( + gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + softmax_rescale_o( + scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor( + rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + float scale = inv_sum; +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = flash::convert_type(acc_o); + + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename + // Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. 
+ if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } + + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + const index_t row_offset_o = + atom_info.q_start_idx * params.o_row_stride + head_idx * params.o_head_stride; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + Tensor tOrO = make_tensor(shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + // Construct identity layout for sO + Tensor cO = make_identity_tensor( + make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, atom_info.q_len); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn(const Params& params) +{ + const int atom_idx = blockIdx.y; + const int head_idx = blockIdx.x; + + if (atom_idx >= params.num_seqs) return; + + AttentionAtom atom; + + int draft_len = params.draft_lens[atom_idx]; + int context_len = params.context_lens[atom_idx]; + + atom.block_idx_list = reinterpret_cast(params.block_tables) + atom_idx * params.max_num_blocks_per_seq; + atom.q_start_idx = atom_idx * params.max_num_query; + atom.q_len = draft_len; + atom.kv_blocks = 
DIVIDE_ROUND_UP(context_len, params.block_size); + atom.total_extent = context_len; + atom.global_q_idx = context_len - draft_len; + + flash::compute_attn_1rowblock(params, atom, head_idx); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/csrc/flash_attn_v2/paged_flash/flash_fwd_launch_template.h b/csrc/flash_attn_v2/paged_flash/flash_fwd_launch_template.h new file mode 100644 index 000000000000..9644b4d22b6d --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/flash_fwd_launch_template.h @@ -0,0 +1,154 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "flash.h" +#include "flash_fwd_kernel.h" +#include "static_switch.h" +#include "kernel_traits.h" + +template +__global__ void flash_fwd_kernel(Flash_fwd_params params) +{ + flash::compute_attn(params); +} + +template +void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) +{ + constexpr size_t smem_size = Kernel_traits::kSmemSize; + // printf("smem_size = %d\n", smem_size); + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. 
+ // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + + const int num_block_atoms = params.num_seqs; + const int num_block_heads = params.h; + + dim3 grid(num_block_heads, num_block_atoms); + + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + auto kernel = &flash_fwd_kernel; + + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); +} + +template +void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) +{ + constexpr int Headdim = 64; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_causal>( + params, stream); + }); +} + +template +void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) +{ + constexpr int Headdim = 96; + int cc_major, cc_minor; + cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, 0); + cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, 0); + bool is_sm8x = cc_major == 8 && cc_minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, + Is_causal>(params, stream); + } else { + run_flash_fwd, + Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>( + params, stream); + } + // run_flash_fwd, Is_dropout, + // Is_causal>(params, stream); run_flash_fwd, Is_dropout, Is_causal>(params, stream); These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim128(Flash_fwd_params& 
params, cudaStream_t stream) +{ + constexpr int Headdim = 128; + int cc_major, cc_minor; + cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, 0); + cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, 0); + bool is_sm8x = cc_major == 8 && cc_minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. + if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, + Is_causal>(params, stream); + } else { + run_flash_fwd, + Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>( + params, stream); + } + // run_flash_fwd, Is_dropout, + // Is_causal>(params, stream); run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, + // Is_causal>(params, stream); Using 8 warps (128 x 128 and 256 x 64) is 28% slower for + // seqlen=2k run_flash_fwd, + // Is_dropout, Is_causal>(params, stream); run_flash_fwd, Is_dropout, Is_causal>(params, stream); 1st ones are good + // for H100, A100 2nd one is good for A6000 bc we get slightly better occupancy + }); +} + +template +void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) +{ + constexpr int Headdim = 256; + int device; + cudaGetDevice(&device); + int max_smem_per_sm, max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); + status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, + // max_smem_per_block); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, we want to run with 128 x 64 (128KB smem). + // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. 
+ if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && + max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { + run_flash_fwd, Is_causal>( + params, stream); + } else { + run_flash_fwd, Is_causal>( + params, stream); + } + // 64 KB + // run_flash_fwd, Is_dropout, + // Is_causal>(params, stream); 96 KB run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); +} diff --git a/csrc/flash_attn_v2/paged_flash/kernel_traits.h b/csrc/flash_attn_v2/paged_flash/kernel_traits.h new file mode 100644 index 000000000000..d65880b3a86e --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/kernel_traits.h @@ -0,0 +1,175 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" + +#include +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" + +namespace flash { + +using namespace cute; + +template +struct Flash_kernel_traits { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = uint32_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t, + MMA_Atom, + MMA_Atom>; + using ValLayoutMNK = Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename 
Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA, _1, _1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for + // 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype(composition( + Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, Stride, _1>>{})); + using SmemLayoutQ = + decltype(tile_to_shape(SmemLayoutAtomQ{}, Shape, Int>{})); + + using SmemLayoutKV = + decltype(tile_to_shape(SmemLayoutAtomQ{}, Shape, Int>{})); + + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + using SmemLayoutAtomVtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomVtransposed = decltype( + composition(Swizzle{}, SmemLayoutAtomVtransposedNoSwizzle{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? 
+ using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomVtransposedNoSwizzle{}, + Shape, Int>{})); + // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype(composition( + Swizzle{}, + Layout, Int>, Stride, _1>>{})); + using SmemLayoutO = + decltype(tile_to_shape(SmemLayoutAtomO{}, Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + + static constexpr int kMaxBlocks = 256; + + static constexpr int kSmemQCount = size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemBlockSize = kMaxBlocks * sizeof(int32_t); + static constexpr int kSmemFlashSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) + : kSmemQSize + kSmemKVSize; + static constexpr int kSmemSize = kSmemFlashSize + kSmemBlockSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, + "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, + "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = + Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. 
This is slightly faster. + using Gmem_copy_struct = + std::conditional_t, DefaultCopy>; + using GmemTiledCopyQKV = + decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = + decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, + "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = + Layout, Int>, + Stride, _1>>; + + using GmemTiledCopyP = + decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + Layout>{})); // Val layout, 8 vals per store +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/csrc/flash_attn_v2/paged_flash/softmax.h b/csrc/flash_attn_v2/paged_flash/softmax.h new file mode 100644 index 000000000000..7b7d461a753a --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/softmax.h @@ -0,0 +1,215 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include + +#include +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? 
tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ inline void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ inline void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ inline void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ inline void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +inline __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +// Apply the exp to all the elements. 
+template +inline __device__ void max_scale_exp2_sum(Tensor& tensor, + Tensor& max, + Tensor& sum, + const float scale) +{ + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { max(mi) = max_op(max(mi), tensor(mi, ni)); } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; // The row sum always restarts from zero here, whatever zero_init is. +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)). This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. 
+ tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +template +inline __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { + // Sets every column whose index is >= max_seqlen_k to -inf, for all rows. tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= max_seqlen_k) { + // Without the "make_coord" we get wrong results + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} + +template +inline __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride) { + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + // row_idx_offset_ is used as-is: the per-lane row term (lane_id / 4) is not added here, so the caller must already include it. + const int row_idx_offset = row_idx_offset_; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q); // Columns at or beyond this limit are masked for this row. + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < 
size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + +template +inline __device__ void apply_mask_causal_w_idx(Tensor& tensor, + Tensor const& idx_rowcol, + const int32_t col_idx_offset_, + const int32_t max_seqlen_k, + const int32_t row_idx_offset_) +{ + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)); idx_rowcol supplies the (row, col) index of each element. + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const int32_t col_idx_limit = + std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); +#pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { // NOTE(review): the bound is size<1, 1>, not size<1> as in the asserts above, so not every column is visited - confirm this is intended. + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, + // max_seqlen_k); print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +} // namespace flash diff --git a/csrc/flash_attn_v2/paged_flash/static_switch.h b/csrc/flash_attn_v2/paged_flash/static_switch.h new file mode 100644 index 000000000000..9f6f1b3a1e94 --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/static_switch.h @@ -0,0 +1,54 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a 
boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() +/// Like BOOL_SWITCH, but binds the type alias `elem_type`: cutlass::half_t when COND is true, cutlass::bfloat16_t otherwise. +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } \ + }() +/// Binds constexpr `kHeadDim` to the smallest of 64, 96, 128, 256 that is >= HEADDIM and invokes the callable; nothing is invoked when HEADDIM > 256. +#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 96) { \ + constexpr static int kHeadDim = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 256) { \ + constexpr static int kHeadDim = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/flash_attn_v2/paged_flash/utils.h b/csrc/flash_attn_v2/paged_flash/utils.h new file mode 100644 index 000000000000..6e656c1f6fd2 --- /dev/null +++ b/csrc/flash_attn_v2/paged_flash/utils.h @@ -0,0 +1,223 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { + __device__ inline T operator()(T const& x, T const& y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { + // This is slightly faster + __device__ inline float operator()(float const& x, float const& y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { + __device__ inline T operator()(T const& x, T const& y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator& op) + { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Allreduce<2> { + template + static __device__ inline T run(T x, Operator& op) + { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, 
+ ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, 
i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +template +inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting + // "int_tuple.hpp(74): error: conversion to inaccessible base class" + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + // return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. +template +inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { + using X = Underscore; + static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + constexpr int MMA_N_divisor = mma_shape_K == 8 ? 
1 : 2; + auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) + // TD [2023-08-13]: Same error as above on Cutlass 3.2 + return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + get<0, 1>(l), + get<1, 1, 1>(l)); + // return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))), + // get<1>(get<0>(l)), + // get<1>(get<1>(get<1>(l)))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. 
+// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE void cp_async_wait() +{ +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy(TiledCopy tiled_copy, Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, const int max_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/csrc/flash_attn_v2/third_party/cutlass b/csrc/flash_attn_v2/third_party/cutlass new file mode 160000 index 000000000000..6f47420213f7 --- /dev/null +++ b/csrc/flash_attn_v2/third_party/cutlass @@ -0,0 +1 @@ +Subproject commit 6f47420213f757831fae65c686aa471749fa8d60 diff --git a/csrc/ops.h b/csrc/ops.h index 9340a60da141..3b5b734614a3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -31,6 +31,20 @@ void paged_attention_v2( int max_context_len, const c10::optional& alibi_slopes); +void paged_flash_attention( + torch::Tensor& out, + torch::Tensor& query, + 
torch::Tensor& key_cache, + torch::Tensor& value_cache, + int num_kv_heads, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + torch::Tensor& draft_lens, + int block_size, + int max_context_len, + int max_num_query); + void rms_norm( torch::Tensor& out, torch::Tensor& input, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 95f557686f33..e709e27ed1c3 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -16,6 +16,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); + ops.def( + "paged_flash_attention", + &paged_flash_attention, + "compute attention between input queries and the cached keys/values using Flash Attention v2."); // Activation ops ops.def( diff --git a/setup.py b/setup.py index 811d494e7a01..9a09c1ec8ab2 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,9 @@ ROOT_DIR = os.path.dirname(__file__) +def get_path(*filepath) -> str: + return os.path.join(ROOT_DIR, *filepath) + MAIN_CUDA_VERSION = "12.1" # Supported NVIDIA GPU architectures. @@ -48,6 +51,10 @@ def _is_cuda() -> bool: CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] +cutlass_include_path = get_path("csrc/flash_attn_v2/third_party/cutlass/include") +if (not os.path.exists(cutlass_include_path)) and (not ROOT_DIR == ""): + raise RuntimeError(f"Cannot find {cutlass_include_path}. Did you run git submodule update?") +NVCC_FLAGS += [f"-I{cutlass_include_path}"] def get_amdgpu_offload_arch(): command = "/opt/rocm/llvm/bin/amdgpu-offload-arch" @@ -85,7 +92,7 @@ def get_hipcc_rocm_version(): else: print("Could not find HIP version in the output") return None - + def get_nvcc_cuda_version(cuda_dir: str) -> Version: """Get the CUDA version from nvcc. 
@@ -221,6 +228,7 @@ def get_torch_arch_list() -> Set[str]: "csrc/quantization/squeezellm/quant_cuda_kernel.cu", "csrc/quantization/gptq/q_gemm.cu", "csrc/cuda_utils_kernels.cu", + "csrc/flash_attn_v2/paged_flash/flash_api.cu", "csrc/pybind.cpp", ] @@ -238,10 +246,6 @@ def get_torch_arch_list() -> Set[str]: ext_modules.append(vllm_extension) -def get_path(*filepath) -> str: - return os.path.join(ROOT_DIR, *filepath) - - def find_version(filepath: str) -> str: """Extract version information from the given filepath. diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 814d40f56def..795c723161ca 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -7,7 +7,11 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from vllm._C import ops +from vllm import SamplingParams +from vllm.sequence import SequenceData from vllm.utils import get_max_shared_memory_bytes +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.input_metadata import InputMetadata FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. 
@@ -18,6 +22,7 @@ DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_GEN_SEQS = [7] # Arbitrary values for testing +NUM_QUERY = [5] NUM_PREFILL_SEQS = [3] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing HEAD_SIZES = [64, 80, 96, 112, 128, 256] @@ -233,6 +238,177 @@ def test_paged_attention( assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) +def ref_multi_query_cached_kv_attention( + output: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + num_query: torch.Tensor, + scale: float, +) -> None: + num_heads = value_cache.shape[1] + head_size = value_cache.shape[2] + block_size = value_cache.shape[3] + num_seqs = query.shape[0] + + block_tables = block_tables.cpu().tolist() + context_lens = context_lens.cpu().tolist() + query_lens = num_query + for i in range(num_seqs): + block_table = block_tables[i] + context_len = int(context_lens[i]) + + for query_num in range(query_lens[i]): + q = query[i, query_num].reshape(1, num_heads, head_size) + + keys = [] + values = [] + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_heads, head_size) + keys.append(k) + + v = value_cache[block_number, :, :, block_offset] + values.append(v) + for j in range(query_num): + k = key[i, j] + k = k.reshape(num_heads, head_size) + keys.append(k) + + v = value[i, j] + v = v.reshape(num_heads, head_size) + values.append(v) + + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + + out = ref_masked_attention(q, keys, values, scale) + out = out.view(-1) + output[i, query_num].copy_(out, non_blocking=True) + + +@pytest.mark.parametrize("num_seqs", [2]) +@pytest.mark.parametrize("max_num_query", [8]) +@pytest.mark.parametrize("num_heads", 
[(12, )]) +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("dtype", [torch.half]) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_multi_query_cached_kv_attention( + kv_cache_factory, + num_seqs: int, + max_num_query: int, + num_heads: Tuple[int, int], + head_size: int, + block_size: int, + dtype: torch.dtype, + seed: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + num_heads = num_heads[0] + + scale = float(1.0 / (head_size**0.5)) + + qkv = torch.empty((num_seqs, max_num_query, 3 * num_heads * head_size), + dtype=dtype, + device="cuda") + qkv.uniform_(-scale, scale) + + # The maximum number of draft tokens is always included, even if a sequence does not need all of them. + # Slice along the last dimension, as this is the expected format of the output of the KQV projection. + query = qkv[:, :, :num_heads * head_size] + key = qkv[:, :, num_heads * head_size:2 * num_heads * head_size] + value = qkv[:, :, 2 * num_heads * head_size:] + + # generate random context lens + context_lens = [random.randint(1, 100) for _ in range(num_seqs)] + context_lens[-1] = 100 + max_context_len = max(context_lens) + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device="cuda") + + # generate random query lens + query_lens = [random.randint(1, max_num_query) for _ in range(num_seqs)] + query_lens[-1] = max_num_query + query_lens_tensor = torch.tensor(query_lens, dtype=torch.int, device="cuda") + + # Create the block tables, following single_query_attention test + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables_tensor = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables_tensor.append(block_table) + block_tables_tensor = torch.tensor(block_tables_tensor, + dtype=torch.int, + device="cuda") + + # create slot mapping + slot_mapping 
= [] + for i in range(num_seqs): + # mappings < 0 are ignored by the reshape_and_cache kernel + slot_mapping.append([-1] * max_num_query) + for j in range(query_lens[i]): + abs_position = context_lens[i] + j + logical_block_idx = abs_position // block_size + logical_block_offset = abs_position % block_size + phys_block_idx = block_tables_tensor[i][logical_block_idx] + slot_mapping[i][j] = phys_block_idx * block_size + logical_block_offset + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.int64, + device="cuda") + + key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, + num_heads, head_size, dtype, + seed, "cuda") + key_cache, value_cache = key_caches[0], value_caches[0] + + # need input_metadata to pass in block_tables, slot_mapping + input_metadata = InputMetadata( + is_prompt=False, + slot_mapping=slot_mapping_tensor, + context_lens=context_lens_tensor, + max_context_len=max_context_len, + block_tables=block_tables_tensor, + use_cuda_graph=False, + draft_lens=query_lens_tensor) + + attn = PagedAttention(num_heads, head_size, scale) + output = attn.forward(query, key, value, key_cache, value_cache, + input_metadata) + assert output.shape == query.shape + + ref_output = torch.zeros_like(query) + ref_multi_query_cached_kv_attention( + ref_output, + query, + key, + value, + key_cache, + value_cache, + block_tables_tensor, + context_lens_tensor, + query_lens, + scale, + ) + + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + def ref_multi_query_kv_attention( cu_seq_lens: List[int], query: torch.Tensor, @@ -265,7 +441,6 @@ def ref_multi_query_kv_attention( ref_output = torch.cat(ref_outputs, dim=0) return ref_output - # TODO(woosuk): Add tests for USE_ALIBI=True. 
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -332,3 +507,4 @@ def test_multi_query_kv_attention( dtype, ) assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index da615ecccf99..96fef1ca2926 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -12,6 +12,7 @@ class InputMetadata: max_context_len: The maximum context length. context_lens: the length of attention context for each sequence. block_tables: The block tables. (Seq id -> list of physical block) + draft_lens: the length of the draft to verify for each sequence. Shape = [num_generation_tokens] """ def __init__( @@ -22,12 +23,14 @@ def __init__( context_lens: Optional[torch.Tensor], block_tables: Optional[torch.Tensor], use_cuda_graph: bool, + draft_lens: Optional[torch.Tensor] = None, ) -> None: self.is_prompt = is_prompt self.max_context_len = max_context_len self.slot_mapping = slot_mapping self.context_lens = context_lens self.block_tables = block_tables + self.draft_lens = draft_lens self.use_cuda_graph = use_cuda_graph # Set during the execution of the first attention op. diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index f1008ec8159f..3151ba00d361 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -156,16 +156,27 @@ def forward( output = out.view_as(query) else: # Decoding run. 
- output = _paged_attention( - query, - key_cache, - value_cache, - input_metadata, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - ) - + if input_metadata.draft_lens is not None: + output = _multi_query_paged_attention( + query, + key_cache, + value_cache, + input_metadata, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + ) + else: + output = _paged_attention( + query, + key_cache, + value_cache, + input_metadata, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + ) + # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) @@ -275,3 +286,55 @@ def _paged_attention( alibi_slopes, ) return output + +def _multi_query_paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + input_metadata: InputMetadata, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], +) -> torch.Tensor: + """PagedAttention for the generation tokens assuming multiple draft tokens. + Assumes that the key and value have already been cached. + + Args: + output: shape = [num_generation_tokens * max_num_query, num_heads, head_size] + query: shape = [num_generation_tokens * max_num_query, num_heads, head_size] + key_cache: shape = [num_blocks, num_kv_heads, head_size/x, + block_size, x] + value_cache: shape = [num_blocks, num_kv_heads, head_size, + block_size] + input_metadata: metadata for paged attention. 
+ alibi_slopes: shape = [num_heads] + """ + max_num_query = max(input_metadata.draft_lens) + + block_size = value_cache.shape[3] + num_seqs = input_metadata.context_lens.shape[0] + num_heads = query.shape[1] + assert num_seqs * max_num_query == query.shape[0] + + query = query.reshape( + num_seqs, + max_num_query, + num_heads, + query.shape[2]) + output = torch.empty_like(query) + + ops.paged_flash_attention( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + input_metadata.block_tables, + input_metadata.context_lens, + input_metadata.draft_lens, + block_size, + input_metadata.max_context_len, + max_num_query) + + return output