Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
123 commits
Select commit Hold shift + click to select a range
009a4ed
Initial stub for prefill.cuh
diptorupd Aug 2, 2025
4712039
Add new utils.cuh
diptorupd Aug 2, 2025
15dcca9
Add dispatch.cuh
diptorupd Aug 2, 2025
632c1b6
Unit test cases for data access patterns
diptorupd Aug 2, 2025
7bb43d1
Updated test_produce_kv.cpp
diptorupd Aug 3, 2025
4667724
Add mma ops
rtmadduri Aug 3, 2025
2c8e6d6
Various initial changes to fix build issues for generic/prefill.cuh
diptorupd Aug 4, 2025
1d5bf96
A standalone driver for singleprefill
diptorupd Aug 4, 2025
ea3791e
Completed Mechanical HIPification changes.
diptorupd Aug 4, 2025
3e76b2c
Updated standalone example
diptorupd Aug 4, 2025
87a588b
Merge remote-tracking branch 'refs/remotes/origin/feature/hipified_pr…
diptorupd Aug 4, 2025
0ed875a
Updated load_q_global_smem.
diptorupd Aug 5, 2025
b7621c6
Port produce_kv to HIP
diptorupd Aug 5, 2025
a29fe25
Update KernelTraits
diptorupd Aug 5, 2025
dc05fe2
Ported query rope transformation to MI300.
diptorupd Aug 5, 2025
de13cd0
WIP changes to compute_qk
diptorupd Aug 6, 2025
cf8bd74
WIP
diptorupd Aug 7, 2025
28aa718
WIP2
diptorupd Aug 7, 2025
6dd2063
Working transpose test
diptorupd Aug 7, 2025
2ffa2b1
Optimizations for transposed loads
diptorupd Aug 7, 2025
b8e617a
transpose 4x4 test case.
diptorupd Aug 7, 2025
5de3797
Merge branch 'amd-integration' into feature/hipified_prefill_v3
diptorupd Aug 7, 2025
2c00391
various compilation error fixes
diptorupd Aug 8, 2025
c418cdc
Compilation fixes
diptorupd Aug 8, 2025
fbe0a77
Various compiler error fixes.
diptorupd Aug 8, 2025
cffa9dd
Fix wrong header guard
diptorupd Aug 8, 2025
3c9a5de
port llama rotary transforms to HIP.
diptorupd Aug 10, 2025
548faa2
Port ancillary kernels to CDNA3 thread layout.
diptorupd Aug 10, 2025
83fd9c9
wip
diptorupd Aug 10, 2025
5531105
Merge branch 'amd-integration' into feature/hipified_prefill_v3
diptorupd Aug 12, 2025
eb40006
Merge branch 'amd-integration' into feature/hipified_prefill_v3
diptorupd Aug 12, 2025
06b25c0
Fix merge issue
diptorupd Aug 12, 2025
a9a4df1
Upadet compute_qk to use mma ops
diptorupd Aug 12, 2025
7ef584b
Update all kernel launch to use WARP_SIZe for thread count.
diptorupd Aug 12, 2025
f51c30e
Update CUDA path in compute_qk.
diptorupd Aug 12, 2025
d3377ce
WIP...
diptorupd Aug 12, 2025
42b3965
Implementation of update_mdo_states.
diptorupd Aug 13, 2025
16a9d15
WIP compute_sfm_v
diptorupd Aug 13, 2025
96ead84
WIP broken...
diptorupd Aug 13, 2025
08f5ed1
Ported threadblock_sync_mdo_states
diptorupd Aug 14, 2025
7893ddd
write_o_reg_smem ported
diptorupd Aug 14, 2025
8524d22
Prefill compiles...
diptorupd Aug 14, 2025
931e43b
Fixed logits functions and load_q_global_smem bug.
diptorupd Aug 15, 2025
05a3472
Indexing fixes to normalize_d
diptorupd Aug 16, 2025
f21615b
Properly fix upcats_size and silence warnings
diptorupd Aug 16, 2025
db59b65
Add default_prefill_params.cuh to generic
diptorupd Aug 17, 2025
1c161e5
Initial unit test harness
diptorupd Aug 17, 2025
9e2e3f5
Updated load_q_global_smem_kernel test
diptorupd Aug 17, 2025
78b559f
Fixes to test_single_prefill.cpp
diptorupd Aug 18, 2025
86a9780
Initial stubs for compute_qk
diptorupd Aug 18, 2025
6003a3f
Standalone tester for compute_qk
diptorupd Aug 18, 2025
f37d407
Testing Q read logic
diptorupd Aug 19, 2025
fb714ee
Updated produce_kv
diptorupd Aug 22, 2025
1d60d18
Fix compiler warnings.
diptorupd Aug 22, 2025
ead9a21
Fix k_smem_offset_rcalc
diptorupd Aug 22, 2025
89cfce4
Fix init_rope_freq.
diptorupd Aug 26, 2025
712b8ed
Debug
diptorupd Aug 27, 2025
07a7e64
Debug llama
diptorupd Aug 27, 2025
e337e28
utils
rtmadduri Aug 28, 2025
5b9daa5
testing harness for compute_sfm
rtmadduri Aug 28, 2025
e3b770c
Debug
diptorupd Aug 28, 2025
4d849ee
Merge remote-tracking branch 'origin/feature/hipified_prefill_v3' int…
diptorupd Aug 28, 2025
ee67970
Debugging changes
diptorupd Aug 28, 2025
7137d9c
Debugging
diptorupd Aug 28, 2025
a0a57ab
verified q, k logic
rtmadduri Aug 28, 2025
ca902d3
Debugging produce_kv
diptorupd Sep 1, 2025
ea74445
Debug....
diptorupd Sep 2, 2025
e67ea64
Debug produce_kv_impl_cdna3_
diptorupd Sep 3, 2025
a6633f0
Fixed produce_kv_impl_cdna3_
diptorupd Sep 3, 2025
9ef0a4b
Off-by-one error
diptorupd Sep 3, 2025
b9b34b2
Changes to logits functions
diptorupd Sep 5, 2025
37bfad7
Fix produce_kv
diptorupd Sep 6, 2025
8abae3d
Test transforms debug
diptorupd Sep 8, 2025
34e99f8
wip fixes
diptorupd Sep 9, 2025
e9aec3d
sfrag debug writer
diptorupd Sep 10, 2025
7da105d
sfrag more debugging
diptorupd Sep 10, 2025
1b79426
Revert update_mod_changes
diptorupd Sep 11, 2025
2ae831c
Debugger for sfrag using pandas
diptorupd Sep 12, 2025
20d7e2a
Tester scripts
diptorupd Sep 12, 2025
3b32ec0
Add warp-level debug prints for sfrag.
diptorupd Sep 12, 2025
6041be1
More debugging
diptorupd Sep 13, 2025
06faaab
Fixed compute_qk
diptorupd Sep 13, 2025
496aaaf
wip debugging of softmax
diptorupd Sep 15, 2025
4a06c32
Formatting
diptorupd Sep 15, 2025
7d33094
Merge branch 'amd-integration' into feature/hipified_prefill_v3
diptorupd Sep 15, 2025
9082889
Update clang-format to match upstream.
diptorupd Sep 16, 2025
2c99536
Reformat libflashinfer
diptorupd Sep 16, 2025
41f340a
Reformat flashinfer/csrc
diptorupd Sep 16, 2025
1877e0e
merge amd-integration
diptorupd Sep 16, 2025
5c1e0b9
Merge amd-integration
diptorupd Sep 16, 2025
28327bb
Remove leftover file
diptorupd Sep 16, 2025
4b9fc6b
merge amd-ntegration
diptorupd Sep 16, 2025
ef5f6a1
rever frag_layout_swizzle.cuh
diptorupd Sep 16, 2025
6ff963a
Reformat prefill.cuh
diptorupd Sep 16, 2025
3e0fa2a
Silence warnings
diptorupd Sep 16, 2025
4c8e574
Remove redundant files
diptorupd Sep 17, 2025
3ce1fa6
Fixes
diptorupd Sep 17, 2025
e0c8dc0
Formatting
diptorupd Sep 17, 2025
57dac70
Fix fragment loading
diptorupd Sep 26, 2025
c9b2d83
Fixes
diptorupd Sep 26, 2025
25f40d9
WIP
diptorupd Sep 28, 2025
98bdf4b
Remove redundant fuction
diptorupd Sep 28, 2025
55e0481
Precommit fixes
diptorupd Oct 1, 2025
f0264e4
Merge branch 'amd-integration' into feature/hipified_prefill_v3
diptorupd Oct 1, 2025
68d1bc2
WIP
diptorupd Oct 2, 2025
2386341
Revert changes to wrong file
diptorupd Oct 2, 2025
bd598c1
Update from amd-integration
diptorupd Oct 2, 2025
f0e8e72
Update from amd-integration
diptorupd Oct 2, 2025
4476fd0
Merge branch 'amd-integration' into feature/hipified_prefill_v3
diptorupd Oct 6, 2025
0b9db15
Compilation fixes
diptorupd Oct 6, 2025
87e6f55
Improved debugging
diptorupd Oct 9, 2025
28a0355
Improved s_frag debug printer
diptorupd Oct 9, 2025
66dbca4
Merge branch 'amd-integration' into feature/hipified_prefill_v3
diptorupd Oct 10, 2025
b77748d
Remove unused debug function
diptorupd Oct 10, 2025
3b52620
Clean up tests
diptorupd Oct 10, 2025
3e4e12b
Remove more unused stuff
diptorupd Oct 10, 2025
3836381
Remove unused file
diptorupd Oct 10, 2025
f7744cb
Merge branch 'amd-integration' into feature/hipified_prefill_v3
diptorupd Oct 15, 2025
893e438
Debug prints
diptorupd Oct 20, 2025
ed30bf6
Validated s_frag and m value calcs in online softmax
diptorupd Oct 21, 2025
75bed47
Remove temorary scripts
diptorupd Oct 21, 2025
d4b3e13
Add instrumentation to validate compute_sfm_v
diptorupd Oct 22, 2025
6be1fb7
Remove old test cases
diptorupd Oct 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,395 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_PREFILL_PARAMS_CUH_
#define FLASHINFER_PREFILL_PARAMS_CUH_

#include <cmath>
#include <cstdint>

#include "gpu_iface/gpu_runtime_compat.hpp"
#include "page.cuh"

namespace flashinfer {

// Kernel parameters for single-request (non-batched) prefill attention.
// Q/K/V are dense tensors; K and V share the same strides.
template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_>
struct SinglePrefillParams {
  using DTypeQ = DTypeQ_;
  using DTypeKV = DTypeKV_;
  using DTypeO = DTypeO_;
  using IdType = int32_t;
  DTypeQ* q;                  // query tensor
  DTypeKV* k;                 // key tensor
  DTypeKV* v;                 // value tensor
  uint8_t* maybe_custom_mask; // optional custom attention mask; nullptr when unused
  DTypeO* o;                  // output tensor
  float* lse;                 // optional log-sum-exp output; nullptr when unused
  float* maybe_alibi_slopes;  // optional ALiBi slopes; nullptr when unused
  uint_fastdiv group_size;    // num_qo_heads / num_kv_heads (fast integer division)
  uint32_t num_qo_heads;
  uint32_t num_kv_heads;
  uint32_t qo_len;
  uint32_t kv_len;
  uint32_t q_stride_n;
  uint32_t q_stride_h;
  uint32_t k_stride_n;
  uint32_t k_stride_h;
  uint32_t v_stride_n;
  uint32_t v_stride_h;
  uint32_t head_dim;
  int32_t window_left;
  float logits_soft_cap;
  float sm_scale;
  float rope_rcp_scale;  // 1 / rope_scale, precomputed on host
  float rope_rcp_theta;  // 1 / rope_theta, precomputed on host
  uint32_t debug_thread_id;
  uint32_t debug_warp_id;

  // NOTE(review): declared uint32_t here but initialized with a bool and
  // declared `bool partition_kv` in the batch params structs — kept as
  // uint32_t to preserve the struct layout; consider unifying the type.
  uint32_t partition_kv;

  // Default constructor: zero/null-initializes every member.
  // Fixes vs. previous version: members are initialized in declaration
  // order (silences -Wreorder), and debug_thread_id / debug_warp_id are
  // no longer left uninitialized.
  __host__ SinglePrefillParams()
      : q(nullptr),
        k(nullptr),
        v(nullptr),
        maybe_custom_mask(nullptr),
        o(nullptr),
        lse(nullptr),
        maybe_alibi_slopes(nullptr),
        group_size(),
        num_qo_heads(0),
        num_kv_heads(0),
        qo_len(0),
        kv_len(0),
        q_stride_n(0),
        q_stride_h(0),
        k_stride_n(0),
        k_stride_h(0),
        v_stride_n(0),
        v_stride_h(0),
        head_dim(0),
        window_left(0),
        logits_soft_cap(0.0f),
        sm_scale(0.0f),
        rope_rcp_scale(0.0f),
        rope_rcp_theta(0.0f),
        debug_thread_id(0),
        debug_warp_id(0),
        partition_kv(false) {}

  // Main constructor. K and V share the kv_stride_n / kv_stride_h strides.
  // rope_scale / rope_theta are stored as reciprocals so device code can
  // multiply instead of divide.
  __host__ SinglePrefillParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* maybe_custom_mask,
                               DTypeO* o, float* lse, float* maybe_alibi_slopes,
                               uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len,
                               uint32_t kv_len, uint32_t q_stride_n, uint32_t q_stride_h,
                               uint32_t kv_stride_n, uint32_t kv_stride_h, uint32_t head_dim,
                               int32_t window_left, float logits_soft_cap, float sm_scale,
                               float rope_scale, float rope_theta, uint32_t debug_thread_id,
                               uint32_t debug_warp_id)
      : q(q),
        k(k),
        v(v),
        maybe_custom_mask(maybe_custom_mask),
        o(o),
        lse(lse),
        maybe_alibi_slopes(maybe_alibi_slopes),
        group_size(num_qo_heads / num_kv_heads),
        num_qo_heads(num_qo_heads),
        num_kv_heads(num_kv_heads),
        qo_len(qo_len),
        kv_len(kv_len),
        q_stride_n(q_stride_n),
        q_stride_h(q_stride_h),
        k_stride_n(kv_stride_n),
        k_stride_h(kv_stride_h),
        v_stride_n(kv_stride_n),
        v_stride_h(kv_stride_h),
        head_dim(head_dim),
        window_left(window_left),
        logits_soft_cap(logits_soft_cap),
        sm_scale(sm_scale),
        // float literals (was `1. /`): keeps the math in float, consistent
        // with BatchPrefillRaggedParams / BatchPrefillPagedParams.
        rope_rcp_scale(1.0f / rope_scale),
        rope_rcp_theta(1.0f / rope_theta),
        debug_thread_id(debug_thread_id),
        debug_warp_id(debug_warp_id),
        partition_kv(false) {}

  // Single-request case: lengths are the same for every batch index.
  __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
    return qo_len;
  }

  __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const {
    return kv_len;
  }
};

// Kernel parameters for batch prefill attention over "ragged" inputs:
// Q, K and V are stored contiguously with per-request extents addressed
// through CSR-style offset arrays (q_indptr / kv_indptr).
template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_>
struct BatchPrefillRaggedParams {
  using DTypeQ = DTypeQ_;
  using DTypeKV = DTypeKV_;
  using DTypeO = DTypeO_;
  using IdType = IdType_;

  DTypeQ* q;                   // query tensor
  DTypeKV* k;                  // key tensor
  DTypeKV* v;                  // value tensor
  uint8_t* maybe_custom_mask;  // optional custom attention mask; nullptr when unused
  IdType* q_indptr;            // per-request offsets into q (used by get_qo_len)
  IdType* kv_indptr;           // per-request offsets into k/v (used by get_kv_len)
  IdType* maybe_mask_indptr;   // optional per-request offsets into maybe_custom_mask
  IdType* maybe_q_rope_offset; // maybe_q_rope_offset is only used for
                               // fused-rope attention
  IdType* maybe_k_rope_offset; // maybe_k_rope_offset is only used for
                               // fused-rope attention
  DTypeO* o;                   // output tensor
  float* lse;                  // optional log-sum-exp output; nullptr when unused
  float* maybe_alibi_slopes;   // optional ALiBi slopes; nullptr when unused
  uint_fastdiv group_size;     // num_qo_heads / num_kv_heads (fast integer division)
  uint32_t num_qo_heads;
  uint32_t num_kv_heads;
  uint32_t q_stride_n;
  uint32_t q_stride_h;
  uint32_t k_stride_n;
  uint32_t k_stride_h;
  uint32_t v_stride_n;
  uint32_t v_stride_h;
  int32_t window_left;
  float logits_soft_cap;
  float sm_scale;
  float rope_rcp_scale;  // 1 / rope_scale, precomputed on host
  float rope_rcp_theta;  // 1 / rope_theta, precomputed on host

  // Scheduling / tiling metadata. Both constructors leave these null/zero;
  // presumably a planner fills them in before kernel launch — confirm with
  // the scheduler code.
  IdType* request_indices;
  IdType* qo_tile_indices;
  IdType* kv_tile_indices;
  IdType* merge_indptr;
  IdType* o_indptr;
  IdType* kv_chunk_size_ptr;
  bool* block_valid_mask;
  uint32_t max_total_num_rows;
  uint32_t* total_num_rows;
  uint32_t padded_batch_size;
  bool partition_kv;

  // Default constructor: zero/null-initializes every member, in
  // declaration order.
  __host__ BatchPrefillRaggedParams()
      : q(nullptr),
        k(nullptr),
        v(nullptr),
        maybe_custom_mask(nullptr),
        q_indptr(nullptr),
        kv_indptr(nullptr),
        maybe_mask_indptr(nullptr),
        maybe_q_rope_offset(nullptr),
        maybe_k_rope_offset(nullptr),
        o(nullptr),
        lse(nullptr),
        maybe_alibi_slopes(nullptr),
        group_size(),
        num_qo_heads(0),
        num_kv_heads(0),
        q_stride_n(0),
        q_stride_h(0),
        k_stride_n(0),
        k_stride_h(0),
        v_stride_n(0),
        v_stride_h(0),
        window_left(0),
        logits_soft_cap(0.0f),
        sm_scale(0.0f),
        rope_rcp_scale(0.0f),
        rope_rcp_theta(0.0f),
        request_indices(nullptr),
        qo_tile_indices(nullptr),
        kv_tile_indices(nullptr),
        merge_indptr(nullptr),
        o_indptr(nullptr),
        kv_chunk_size_ptr(nullptr),
        block_valid_mask(nullptr),
        max_total_num_rows(0),
        total_num_rows(nullptr),
        padded_batch_size(0),
        partition_kv(false) {}

  // Main constructor. K and V share the kv_stride_n / kv_stride_h strides;
  // rope_scale / rope_theta are stored as reciprocals so device code can
  // multiply instead of divide. Scheduling metadata is left null/zero.
  __host__ BatchPrefillRaggedParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* maybe_custom_mask,
                                    IdType* q_indptr, IdType* kv_indptr, IdType* maybe_mask_indptr,
                                    IdType* maybe_q_rope_offset, IdType* maybe_k_rope_offset,
                                    DTypeO* o, float* lse, float* maybe_alibi_slopes,
                                    uint32_t num_qo_heads, uint32_t num_kv_heads,
                                    uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n,
                                    uint32_t kv_stride_h, int32_t window_left,
                                    float logits_soft_cap, float sm_scale, float rope_scale,
                                    float rope_theta)
      : q(q),
        k(k),
        v(v),
        maybe_custom_mask(maybe_custom_mask),
        q_indptr(q_indptr),
        kv_indptr(kv_indptr),
        maybe_mask_indptr(maybe_mask_indptr),
        maybe_q_rope_offset(maybe_q_rope_offset),
        maybe_k_rope_offset(maybe_k_rope_offset),
        o(o),
        lse(lse),
        maybe_alibi_slopes(maybe_alibi_slopes),
        group_size(num_qo_heads / num_kv_heads),
        num_qo_heads(num_qo_heads),
        num_kv_heads(num_kv_heads),
        q_stride_n(q_stride_n),
        q_stride_h(q_stride_h),
        k_stride_n(kv_stride_n),
        k_stride_h(kv_stride_h),
        v_stride_n(kv_stride_n),
        v_stride_h(kv_stride_h),
        window_left(window_left),
        logits_soft_cap(logits_soft_cap),
        sm_scale(sm_scale),
        rope_rcp_scale(1.f / rope_scale),
        rope_rcp_theta(1.f / rope_theta),
        request_indices(nullptr),
        qo_tile_indices(nullptr),
        kv_tile_indices(nullptr),
        merge_indptr(nullptr),
        o_indptr(nullptr),
        kv_chunk_size_ptr(nullptr),
        block_valid_mask(nullptr),
        max_total_num_rows(0),
        total_num_rows(nullptr),
        padded_batch_size(0),
        partition_kv(false) {}

  // Number of query rows for a request, from consecutive indptr entries.
  __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
    return q_indptr[batch_idx + 1] - q_indptr[batch_idx];
  }

  // Number of key/value rows for a request, from consecutive indptr entries.
  __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const {
    return kv_indptr[batch_idx + 1] - kv_indptr[batch_idx];
  }
};

// Kernel parameters for batch prefill attention where K/V live in a paged
// KV cache (paged_kv_t) rather than contiguous ragged tensors. Q is still
// addressed through a CSR-style q_indptr array.
template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_>
struct BatchPrefillPagedParams {
  using DTypeQ = DTypeQ_;
  using DTypeKV = DTypeKV_;
  using DTypeO = DTypeO_;
  using IdType = IdType_;

  DTypeQ* q;                          // query tensor
  paged_kv_t<DTypeKV, IdType> paged_kv;  // paged key/value cache (owns KV layout + lengths)
  uint8_t* maybe_custom_mask;         // optional custom attention mask; nullptr when unused
  IdType* q_indptr;                   // per-request offsets into q (used by get_qo_len)
  IdType* maybe_mask_indptr;          // optional per-request offsets into maybe_custom_mask
  IdType* maybe_q_rope_offset;        // maybe_q_rope_offset is only used for
                                      // fused-rope attention
  DTypeO* o;                          // output tensor
  float* lse;                         // optional log-sum-exp output; nullptr when unused
  float* maybe_alibi_slopes;          // optional ALiBi slopes; nullptr when unused
  uint_fastdiv group_size;            // num_qo_heads / paged_kv.num_heads (fast integer division)
  uint32_t num_qo_heads;
  IdType q_stride_n;
  IdType q_stride_h;
  int32_t window_left;
  float logits_soft_cap;
  float sm_scale;
  float rope_rcp_scale;  // 1 / rope_scale, precomputed on host
  float rope_rcp_theta;  // 1 / rope_theta, precomputed on host

  // Scheduling / tiling metadata. Both constructors leave these null/zero;
  // presumably a planner fills them in before kernel launch — confirm with
  // the scheduler code.
  IdType* request_indices;
  IdType* qo_tile_indices;
  IdType* kv_tile_indices;
  IdType* merge_indptr;
  IdType* o_indptr;
  bool* block_valid_mask;
  IdType* kv_chunk_size_ptr;
  uint32_t max_total_num_rows;
  uint32_t* total_num_rows;
  uint32_t padded_batch_size;
  bool partition_kv;

  // Default constructor: zero/null-initializes every member, in
  // declaration order.
  __host__ BatchPrefillPagedParams()
      : q(nullptr),
        paged_kv(),
        maybe_custom_mask(nullptr),
        q_indptr(nullptr),
        maybe_mask_indptr(nullptr),
        maybe_q_rope_offset(nullptr),
        o(nullptr),
        lse(nullptr),
        maybe_alibi_slopes(nullptr),
        group_size(),
        num_qo_heads(0),
        q_stride_n(0),
        q_stride_h(0),
        window_left(0),
        logits_soft_cap(0.0f),
        sm_scale(0.0f),
        rope_rcp_scale(0.0f),
        rope_rcp_theta(0.0f),
        request_indices(nullptr),
        qo_tile_indices(nullptr),
        kv_tile_indices(nullptr),
        merge_indptr(nullptr),
        o_indptr(nullptr),
        block_valid_mask(nullptr),
        kv_chunk_size_ptr(nullptr),
        max_total_num_rows(0),
        total_num_rows(nullptr),
        padded_batch_size(0),
        partition_kv(false) {}

  // Main constructor. group_size is derived from num_qo_heads and the KV
  // head count carried by paged_kv; rope_scale / rope_theta are stored as
  // reciprocals so device code can multiply instead of divide. Scheduling
  // metadata is left null/zero.
  __host__ BatchPrefillPagedParams(DTypeQ* q, paged_kv_t<DTypeKV, IdType> paged_kv,
                                   uint8_t* maybe_custom_mask, IdType* q_indptr,
                                   IdType* maybe_mask_indptr, IdType* maybe_q_rope_offset,
                                   DTypeO* o, float* lse, float* maybe_alibi_slopes,
                                   uint32_t num_qo_heads, IdType q_stride_n, IdType q_stride_h,
                                   int32_t window_left, float logits_soft_cap, float sm_scale,
                                   float rope_scale, float rope_theta)
      : q(q),
        paged_kv(paged_kv),
        maybe_custom_mask(maybe_custom_mask),
        q_indptr(q_indptr),
        maybe_mask_indptr(maybe_mask_indptr),
        maybe_q_rope_offset(maybe_q_rope_offset),
        o(o),
        lse(lse),
        maybe_alibi_slopes(maybe_alibi_slopes),
        group_size(num_qo_heads / paged_kv.num_heads),
        num_qo_heads(num_qo_heads),
        q_stride_n(q_stride_n),
        q_stride_h(q_stride_h),
        window_left(window_left),
        logits_soft_cap(logits_soft_cap),
        sm_scale(sm_scale),
        rope_rcp_scale(1.f / rope_scale),
        rope_rcp_theta(1.f / rope_theta),
        request_indices(nullptr),
        qo_tile_indices(nullptr),
        kv_tile_indices(nullptr),
        merge_indptr(nullptr),
        o_indptr(nullptr),
        block_valid_mask(nullptr),
        kv_chunk_size_ptr(nullptr),
        max_total_num_rows(0),
        total_num_rows(nullptr),
        padded_batch_size(0),
        partition_kv(false) {}

  // Number of query rows for a request, from consecutive indptr entries.
  __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
    return q_indptr[batch_idx + 1] - q_indptr[batch_idx];
  }

  // KV length for a request, delegated to the paged cache.
  __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const {
    return paged_kv.get_length(batch_idx);
  }
};

} // namespace flashinfer

#endif // FLASHINFER_PREFILL_PARAMS_CUH_
Loading