@@ -1581,6 +1581,7 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) {
   }
 }
 
+// clang-format off
 template <typename scalar_t, typename cache_t,
           vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
           int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
@@ -1594,6 +1595,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
     const int num_kv_heads, const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
@@ -1604,6 +1606,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
                                    // head_size]
     OUTT* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
     int max_ctx_blocks, const float* k_scale, const float* v_scale) {
+  // clang-format on
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;  // 8 warps on gfx11
   const int warpid = threadIdx.x / WARP_SIZE;
   const int laneid = threadIdx.x % WARP_SIZE;
@@ -1613,6 +1616,13 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
   const int rowid = laneid / 16;
 
   const int seq_idx = blockIdx.x;
+  // NOTE: queries with sequence length > 1 are prefills and are handled by
+  // another kernel.
+  if (query_start_loc_ptr != nullptr &&
+      (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx]) != 1) {
+    return;
+  }
+
   const int partition_idx = blockIdx.y;
 
   constexpr int T_PAR_SIZE = 256;  // token partition size set to 256
@@ -1671,12 +1681,14 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
   // output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens
   // across 2 rows x 8 tokens per lane
 
+  const int64_t query_start_off = static_cast<int64_t>(
+      query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
+
   if (GQA_RATIO == 1) {
     const int local_qhead_idx = lane16id % GQA_RATIO;
     const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
-    const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
     const scalar_t* q_ptr =
-        q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE;
+        q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;
     if (lane16id < GQA_RATIO) {
 #pragma unroll
       for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) {
@@ -1690,9 +1702,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
     // fetch Q in shared across warps and then write to registers
     const int local_qhead_idx = 2 * warpid + rowid;
     const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
-    const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
     const scalar_t* q_ptr =
-        q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE;
+        q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;
 
     const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B;
     if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) {
@@ -2024,6 +2035,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     const int num_kv_heads, const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
@@ -2050,15 +2062,24 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
     const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                            // max_num_partitions, head_size]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
     const int max_num_partitions) {
-  const int num_heads = gridDim.x;
-  const int head_idx = blockIdx.x;
-  const int seq_idx = blockIdx.y;
+  const auto num_heads = gridDim.x;
+  const auto head_idx = blockIdx.x;
+  const auto seq_idx = blockIdx.y;
+
+  // NOTE: queries with sequence length > 1 are prefills and are handled by
+  // another kernel.
+  if (query_start_loc_ptr != nullptr &&
+      (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
+    return;
+  }
+
   const int context_len = context_lens[seq_idx];
   const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int warpid = threadIdx.x / WARP_SIZE;
-  const int laneid = threadIdx.x % WARP_SIZE;
+  [[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE;
 
   __shared__ float shared_global_exp_sum;
   // max num partitions supported is warp_size * NPAR_LOOPS
@@ -2221,7 +2242,11 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
   const float inv_global_exp_sum =
       __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
   acc *= inv_global_exp_sum;
-  OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+
+  const int64_t query_start_off = static_cast<int64_t>(
+      query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
+  OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
+                  static_cast<int64_t>(head_idx) * HEAD_SIZE;
   out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
 }
 
@@ -2328,6 +2353,7 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) {
   }
 }
 
+// clang-format off
 template <typename scalar_t, typename cache_t,
           vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
           int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
@@ -2341,6 +2367,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
     const int num_kv_heads, const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
@@ -2351,6 +2378,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
                                    // head_size]
     OUTT* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
     int max_ctx_blocks, const float* k_scale, const float* v_scale) {
+  // clang-format on
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;  // 8 warps on gfx11
   const int warpid = threadIdx.x / WARP_SIZE;
   const int laneid = threadIdx.x % WARP_SIZE;
@@ -2360,6 +2388,12 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
   const int rowid = laneid / 16;
 
   const int seq_idx = blockIdx.x;
+  // NOTE: queries with sequence length > 1 are prefills and are handled by
+  // another kernel.
+  if (query_start_loc_ptr != nullptr &&
+      (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
+    return;
+  }
   const int partition_idx = blockIdx.y;
 
   constexpr int T_PAR_SIZE = 256;  // token partition size set to 256
@@ -2419,11 +2453,13 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
   // output layout from QKwmma : QH16xT8x2 16 qheads across 16 lanes, 16 tokens
   // across 2 rows x 8 tokens per lane
 
+  const int64_t query_start_off = static_cast<int64_t>(
+      query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
+
   if (GQA_RATIO == 1) {
     const int local_qhead_idx = lane16id % GQA_RATIO;
     const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
-    const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
-    const scalar_t* q_ptr = q + seq_idx64 * q_stride +
+    const scalar_t* q_ptr = q + query_start_off * q_stride +
                             global_qhead_idx * HEAD_SIZE +
                             rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD;
     if (lane16id < GQA_RATIO) {
@@ -2439,9 +2475,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
     // fetch Q in shared across warps and then write to registers
     const int local_qhead_idx = 2 * warpid + rowid;
     const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
-    const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
     const scalar_t* q_ptr =
-        q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE;
+        q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;
 
     const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B;
     if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) {
@@ -2736,6 +2771,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     const int num_kv_heads, const float scale,
     const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
@@ -2762,15 +2798,24 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
     const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                            // max_num_partitions, head_size]
     const int* __restrict__ context_lens,  // [num_seqs]
+    const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
     const int max_num_partitions) {
-  const int num_heads = gridDim.x;
-  const int head_idx = blockIdx.x;
-  const int seq_idx = blockIdx.y;
+  const auto num_heads = gridDim.x;
+  const auto head_idx = blockIdx.x;
+  const auto seq_idx = blockIdx.y;
+
+  // NOTE: queries with sequence length > 1 are prefills and are handled by
+  // another kernel.
+  if (query_start_loc_ptr != nullptr &&
+      (query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
+    return;
+  }
+
   const int context_len = context_lens[seq_idx];
   const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int warpid = threadIdx.x / WARP_SIZE;
-  const int laneid = threadIdx.x % WARP_SIZE;
+  [[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE;
 
   __shared__ float shared_global_exp_sum;
   // max num partitions supported is warp_size * NPAR_LOOPS
@@ -2933,7 +2978,11 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
   const float inv_global_exp_sum =
       __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
   acc *= inv_global_exp_sum;
-  OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+
+  const int64_t query_start_off = static_cast<int64_t>(
+      query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
+  OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
+                  static_cast<int64_t>(head_idx) * HEAD_SIZE;
   out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
 }
 
@@ -3201,16 +3250,24 @@ void paged_attention_custom_launcher_navi(
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, const int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& context_lens,
-    int max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
-    torch::Tensor& k_scale, torch::Tensor& v_scale) {
-  int num_seqs = query.size(0);
+    const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
+    const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
+    torch::Tensor& v_scale) {
+  int num_seqs = block_tables.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int max_num_blocks_per_seq = block_tables.size(1);
   int q_stride = query.stride(0);
   int kv_block_stride = key_cache.stride(0);
   int kv_head_stride = key_cache.stride(1);
 
+  // NOTE: query_start_loc is optional and is not used for V0 decode. If the
+  // batch contains a mix of prefills and decodes, prefills should be skipped.
+  const int* query_start_loc_ptr =
+      query_start_loc
+          ? reinterpret_cast<const int*>(query_start_loc.value().data_ptr())
+          : nullptr;
+
   // NOTE: Navi does not support alibi_slopes.
   const float* alibi_slopes_ptr = nullptr;
 
@@ -3363,14 +3420,14 @@ void paged_attention_custom_launcher_navi(
     paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
                                     PSIZE, ALIBI_ENABLED>(                    \
         out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,    \
-        num_kv_heads, scale, block_tables, context_lens, max_context_len,     \
-        alibi_slopes, k_scale, v_scale);                                      \
+        num_kv_heads, scale, block_tables, context_lens, query_start_loc,     \
+        max_context_len, alibi_slopes, k_scale, v_scale);                     \
   } else {                                                                    \
     paged_attention_custom_launcher_navi<T, KVT, KV_DTYPE, BLK_SIZE,          \
                                          HEAD_SIZE, T, PSIZE, ALIBI_ENABLED>( \
         out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,    \
-        num_kv_heads, scale, block_tables, context_lens, max_context_len,     \
-        alibi_slopes, k_scale, v_scale);                                      \
+        num_kv_heads, scale, block_tables, context_lens, query_start_loc,     \
+        max_context_len, alibi_slopes, k_scale, v_scale);                     \
   }
 
 #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
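Taken together, these changes thread an optional `query_start_loc_ptr` through the decode kernels: each kernel computes the per-sequence query length as `query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx]`, returns early when it is not 1 (prefills are handled by another kernel), and otherwise uses `query_start_loc_ptr[seq_idx]` as the row offset into `q` and `out`, falling back to `seq_idx` when no pointer is passed. The host-side sketch below illustrates that convention under the assumption that `query_start_loc` is the usual cumulative sum of per-sequence query lengths with `num_seqs + 1` entries (the `seq_idx + 1` indexing in the diff implies this); the helper `build_query_start_loc` and the printed output are illustrative only and not part of the patch.

```cpp
// Minimal host-side sketch of the query_start_loc convention assumed above.
// build_query_start_loc is a hypothetical helper, not part of the patch.
#include <cstdint>
#include <cstdio>
#include <vector>

// query_start_loc has num_seqs + 1 entries: the cumulative sum of per-sequence
// query lengths, so sequence i owns query rows
// [query_start_loc[i], query_start_loc[i + 1]).
std::vector<int> build_query_start_loc(const std::vector<int>& query_lens) {
  std::vector<int> starts(query_lens.size() + 1, 0);
  for (size_t i = 0; i < query_lens.size(); ++i) {
    starts[i + 1] = starts[i] + query_lens[i];
  }
  return starts;
}

int main() {
  // A mixed batch: two decodes (1 query token each) and one prefill (5 tokens).
  const std::vector<int> query_lens = {1, 5, 1};
  const std::vector<int> starts = build_query_start_loc(query_lens);  // {0, 1, 6, 7}

  for (size_t seq_idx = 0; seq_idx < query_lens.size(); ++seq_idx) {
    const int qlen = starts[seq_idx + 1] - starts[seq_idx];
    if (qlen != 1) {
      // Same test the kernels perform: sequences with more than one query
      // token are prefills and are skipped by the decode path.
      std::printf("seq %zu: prefill (qlen=%d), skipped\n", seq_idx, qlen);
      continue;
    }
    // Decode path: starts[seq_idx] is the row offset used for q (scaled by
    // q_stride) and for out (scaled by num_heads * HEAD_SIZE). When no
    // query_start_loc is given, the kernels fall back to seq_idx itself.
    const std::int64_t query_start_off =
        static_cast<std::int64_t>(starts[seq_idx]);
    std::printf("seq %zu: decode, query row %lld\n", seq_idx,
                static_cast<long long>(query_start_off));
  }
  return 0;
}
```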