feat: add sink to flashinfer decode #2087
Base: main
```diff
@@ -355,6 +355,14 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par
   // sync local state of all warps inside a threadblock
   sync_state<vec_size, bdx, bdy, bdz>(variant, st_local, reinterpret_cast<float*>(smem), smem_md,
                                       tx, ty, tz);
+  // Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
+  if constexpr (variant.use_softmax) {
+    if (params.maybe_s_aux != nullptr) {
+      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
+      float s_aux_val = params.maybe_s_aux[qo_head_idx];
+      st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
+    }
+  }
```
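In streaming-softmax terms, the added block folds one extra term into the softmax denominator after the per-tile reduction has finalized the running max `m` and running sum `d`. A sketch of the intended math (our restatement; `s_aux` is the per-head learnable sink logit):

$$
d \;=\; \sum_j e^{s_j - m} \;+\; e^{s_{\mathrm{aux}} - m},
\qquad
o \;=\; \frac{1}{d} \sum_j e^{s_j - m}\, v_j
$$

The sink absorbs probability mass in the normalization but contributes no value vector, which is why only `st.d` is updated.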
Comment on lines +358 to +365 (Contributor):

This logic for adding the sink contribution is duplicated in `BatchDecodeWithPagedKVCacheDevice`. Also, the constant `LOG2_E` is redefined locally in both kernels. For example, you could add at the top of the file:

```cpp
namespace flashinfer {
namespace {  // anonymous namespace

static constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)

template <typename AttentionVariant, typename State, typename Params>
__device__ __forceinline__ void AddSinkContribution(AttentionVariant variant, State& st,
                                                    const Params& params,
                                                    uint32_t qo_head_idx) {
  if constexpr (variant.use_softmax) {
    if (params.maybe_s_aux != nullptr) {
      float s_aux_val = params.maybe_s_aux[qo_head_idx];
      st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
    }
  }
}

}  // anonymous namespace
// ... rest of the file
```

Then you could replace this block and the one in `BatchDecodeWithPagedKVCacheDevice` with:

```cpp
AddSinkContribution(variant, st_local, params, qo_head_idx);
```
Comment on lines +358 to +365 (Contributor):

Fix `s_aux` scaling to match the logits path. `s[j]` is scaled by `variant.sm_scale_log2` before softmax; `s_aux` currently uses `(s_aux - m) * LOG2_E`, which mismatches and yields incorrect normalization. Scale `s_aux` with `variant.sm_scale_log2` and drop `LOG2_E`. Apply:

```diff
-  if constexpr (variant.use_softmax) {
-    if (params.maybe_s_aux != nullptr) {
-      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
-      float s_aux_val = params.maybe_s_aux[qo_head_idx];
-      st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
-    }
-  }
+  if constexpr (variant.use_softmax) {
+    if (params.maybe_s_aux != nullptr) {
+      float s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2;
+      st_local.d += math::ptx_exp2(s_aux_scaled - st_local.m);
+    }
+  }
```
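For reference, these kernels keep `m` and `d` in the base-2 domain via the identity $e^{x} = 2^{x \log_2 e}$, with query-key logits stored as (our restatement of the variant's convention, assuming $\texttt{sm\_scale\_log2} = \texttt{sm\_scale} \cdot \log_2 e$):

$$
s_j \;=\; (q \cdot k_j)\,\texttt{sm\_scale\_log2}
$$

Under that reading, the original patch evaluates $2^{(s_{\mathrm{aux}} - m)\log_2 e} = e^{s_{\mathrm{aux}} - m}$ even though $m$ is in log2 units, while the suggestion evaluates $2^{s_{\mathrm{aux}}\,\texttt{sm\_scale\_log2} - m}$, keeping $s_{\mathrm{aux}}$ in the same units as $m$. Note the suggestion also applies the softmax scale to the sink logit itself.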
```diff
   #pragma unroll
   for (size_t i = 0; i < vec_size; ++i) {
     st_local.o[i] = variant.OutputTransform(params, st_local.o[i], /*batch_idx=*/0, /*qo_idx=*/0,
```
```diff
@@ -589,6 +597,14 @@ __device__ __inline__ void BatchDecodeWithPagedKVCacheDevice(const Params& param
   // sync local state of all warps inside a threadblock
   sync_state<vec_size, bdx, bdy, bdz>(variant, st, reinterpret_cast<float*>(smem), smem_md, tx, ty,
                                       tz);
+  // Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
+  if constexpr (variant.use_softmax) {
+    if (params.maybe_s_aux != nullptr) {
+      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
+      float s_aux_val = params.maybe_s_aux[qo_head_idx];
+      st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
+    }
+  }
```
Comment on lines +601 to +607 (Contributor):

Same scaling fix for the batch kernel. Mirror the `s_aux` scaling correction here to keep behavior consistent across kernels:

```diff
-  if constexpr (variant.use_softmax) {
-    if (params.maybe_s_aux != nullptr) {
-      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
-      float s_aux_val = params.maybe_s_aux[qo_head_idx];
-      st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
-    }
-  }
+  if constexpr (variant.use_softmax) {
+    if (params.maybe_s_aux != nullptr) {
+      float s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2;
+      st.d += math::ptx_exp2(s_aux_scaled - st.m);
+    }
+  }
```
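Both kernel comments concern only the units of the sink term; the intended semantics can be pinned down against a host-side reference. The sketch below is ours (not part of the PR) and assumes the original patch's convention in which the sink logit enters the softmax unscaled; multiply `s_aux` by `sm_scale` to model the reviewer's suggested convention instead:

```python
import numpy as np

def decode_with_sink(q, k, v, s_aux, sm_scale):
    """Single-head decode reference. q: (d,), k/v: (n, d), s_aux: scalar sink logit."""
    logits = (k @ q) * sm_scale           # scaled query-key logits
    m = max(logits.max(), s_aux)          # max over real logits and the sink
    p = np.exp(logits - m)
    d = p.sum() + np.exp(s_aux - m)       # sink enters the denominator only
    return (p @ v) / d                    # the sink contributes no value vector

rng = np.random.default_rng(0)
q = rng.standard_normal(64, dtype=np.float32)
k = rng.standard_normal((128, 64), dtype=np.float32)
v = rng.standard_normal((128, 64), dtype=np.float32)
out = decode_with_sink(q, k, v, s_aux=0.5, sm_scale=64 ** -0.5)
```

Whether the reference takes the max over the sink as well (as above) or adds the sink term after the fact, as the kernels do, the normalized output is mathematically identical.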
```diff
   #pragma unroll
   for (size_t i = 0; i < vec_size; ++i) {
     st.o[i] = variant.OutputTransform(params, st.o[i], bx, /*qo_idx=*/0, qo_head_idx, st.m, st.d,
```
Comment (🧩 Analysis chain):

Verify that single prefill Python wrappers accept and pass the `maybe_s_aux` parameter.

The single prefill module (fa2 backend) now includes `maybe_s_aux` as the third additional tensor parameter. Ensure that Python wrappers for single prefill operations (e.g., `flashinfer.single_prefill_with_kv_cache`, `flashinfer/prefill.py:get_single_prefill_module`) have been updated to accept and forward the sinks parameter.
Add the `maybe_s_aux` parameter to the single prefill Python wrapper and pass it to `module.run()`.

The `single_prefill_with_kv_cache` function does not accept or pass the `maybe_s_aux` parameter that the underlying generated C++ module expects. The custom operation `run_single_prefill` (line 278) accepts `maybe_s_aux` and passes it to the fa2 backend (line 329), but the Python wrapper never provides it. This creates a parameter mismatch.

Update `single_prefill_with_kv_cache`:
- Add a `maybe_s_aux: Optional[torch.Tensor] = None` parameter (matching the batch prefill pattern).
- Pass it to `module.run()` after `maybe_alibi_slopes` and before `logits_soft_cap` (position 11), as sketched below.
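The fix is plumbing only. A minimal sketch of the argument threading (our illustration: the module class is a stub, and the real generated `run()` takes many more positional arguments, with `maybe_s_aux` at position 11):

```python
from typing import Optional
import torch

class _StubPrefillModule:
    """Stand-in for the generated C++ module; only the relevant arguments are shown."""
    def run(self, q, k, v, maybe_alibi_slopes, maybe_s_aux, logits_soft_cap):
        return q  # placeholder output

def single_prefill_with_kv_cache(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    maybe_alibi_slopes: Optional[torch.Tensor] = None,
    maybe_s_aux: Optional[torch.Tensor] = None,  # new, matching the batch prefill pattern
    logits_soft_cap: float = 0.0,
) -> torch.Tensor:
    module = _StubPrefillModule()
    # Thread maybe_s_aux through after maybe_alibi_slopes and before
    # logits_soft_cap, mirroring what run_single_prefill already expects.
    return module.run(q, k, v, maybe_alibi_slopes, maybe_s_aux, logits_soft_cap)
```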