Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
_check_pos_encoding_mode,
check_shape_dtype_device,
_get_cache_alibi_slopes_buf,
_get_sink_buf,
_get_cache_buf,
_get_range_buf,
_unpack_paged_kv_cache,
Expand Down Expand Up @@ -242,6 +243,7 @@ def run_batch_decode(
window_left: int,
enable_pdl: bool,
alibi_slopes: Optional[torch.Tensor],
maybe_s_aux: Optional[torch.Tensor],
logits_soft_cap: float,
sm_scale: float,
rope_scale: float,
Expand All @@ -263,6 +265,7 @@ def run_batch_decode(
window_left,
enable_pdl,
alibi_slopes,
maybe_s_aux,
logits_soft_cap,
sm_scale,
1.0 / rope_scale, # rope_rcp_scale
Expand All @@ -286,6 +289,7 @@ def _fake_run_batch_decode(
window_left: int,
enable_pdl: bool,
alibi_slopes: Optional[torch.Tensor],
maybe_s_aux: Optional[torch.Tensor],
logits_soft_cap: float,
sm_scale: float,
rope_scale: float,
Expand Down Expand Up @@ -384,6 +388,7 @@ def single_decode_with_kv_cache(
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
return_lse: Literal[True] = True,
sinks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ...


Expand All @@ -403,6 +408,7 @@ def single_decode_with_kv_cache(
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
return_lse: bool = False,
sinks: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
r"""Decode attention with KV Cache for single request, return attention output.

Expand Down Expand Up @@ -529,6 +535,7 @@ def single_decode_with_kv_cache(
window_left,
None, # packed_custom_mask
_get_cache_alibi_slopes_buf(num_qo_heads, q.device),
sinks, # maybe_s_aux
logits_soft_cap,
sm_scale,
None, # scale_q, not supported yet
Expand Down Expand Up @@ -1330,7 +1337,7 @@ def run(
self._kv_lens_buffer,
page_size,
self._max_kv_len,
sinks,
_get_sink_buf(sinks),
]

self._cached_module.paged_run(*run_args)
Expand Down Expand Up @@ -1364,6 +1371,7 @@ def run(
else:
run_args += [
_get_cache_alibi_slopes_buf(q.shape[1], q.device),
_get_sink_buf(sinks),
logits_soft_cap,
sm_scale,
rope_scale,
Expand Down
16 changes: 10 additions & 6 deletions flashinfer/jit/attention/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,8 @@ def gen_single_decode_module(
dtype_o,
head_dim_qk,
head_dim_vo,
["maybe_alibi_slopes"], # additional_tensor_names
["float"], # additional_tensor_dtypes
["maybe_alibi_slopes", "maybe_s_aux"], # additional_tensor_names
["float", "float"], # additional_tensor_dtypes
[
"logits_soft_cap",
"sm_scale",
Expand Down Expand Up @@ -516,8 +516,12 @@ def gen_single_prefill_module(

if backend == "fa2":
assert not fp8_enabled, "fp8 tensor core is not supported in fa2 backend"
additional_tensor_names = ["maybe_custom_mask", "maybe_alibi_slopes"]
additional_tensor_dtypes = ["uint8_t", "float"]
additional_tensor_names = [
"maybe_custom_mask",
"maybe_alibi_slopes",
"maybe_s_aux",
]
additional_tensor_dtypes = ["uint8_t", "float", "float"]
Comment on lines +519 to +524
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify that single prefill Python wrappers accept and pass maybe_s_aux parameter.

The single prefill module (fa2 backend) now includes maybe_s_aux as the third additional tensor parameter. Ensure that Python wrappers for single prefill operations (e.g., flashinfer.single_prefill_with_kv_cache, flashinfer/prefill.py:get_single_prefill_module) have been updated to accept and forward the sinks parameter.

Run the following script to verify the prefill wrapper signatures:


🏁 Script executed:

#!/bin/bash
# Description: Check if Python wrappers for single prefill have been updated to handle sinks parameter

echo "=== Checking single_prefill_with_kv_cache function signature ==="
rg -n -A5 'def single_prefill_with_kv_cache' --type=py

echo ""
echo "=== Checking get_single_prefill_module calls ==="
rg -n -B3 -A10 'get_single_prefill_module' --type=py

echo ""
echo "=== Checking for sinks parameter in prefill.py ==="
rg -n -C5 'sinks' python/flashinfer/prefill.py 2>/dev/null || rg -n -C5 'sinks' flashinfer/prefill.py

Length of output: 10216


🏁 Script executed:

#!/bin/bash
# Get full function signatures for single_prefill_with_kv_cache functions

echo "=== Full signature of single_prefill_with_kv_cache at line 879 ==="
sed -n '879,920p' flashinfer/prefill.py

echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 914 ==="
sed -n '914,960p' flashinfer/prefill.py

echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 939 ==="
sed -n '939,985p' flashinfer/prefill.py

echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 963 ==="
sed -n '963,1010p' flashinfer/prefill.py

echo ""
echo "=== Checking how module is called in single prefill (around line 1164) ==="
sed -n '1160,1210p' flashinfer/prefill.py

Length of output: 671


🏁 Script executed:

#!/bin/bash
# Get full function signatures for single_prefill_with_kv_cache functions

echo "=== Full signature of single_prefill_with_kv_cache at line 879 ==="
sed -n '879,920p' flashinfer/prefill.py

echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 914 ==="
sed -n '914,960p' flashinfer/prefill.py

echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 939 ==="
sed -n '939,985p' flashinfer/prefill.py

echo ""
echo "=== Full signature of single_prefill_with_kv_cache at line 963 ==="
sed -n '963,1010p' flashinfer/prefill.py

echo ""
echo "=== Checking how module is called in single prefill (around line 1164) ==="
sed -n '1160,1210p' flashinfer/prefill.py

echo ""
echo "=== Checking run_single_prefill function signature and how it calls module ==="
sed -n '268,310p' flashinfer/prefill.py

Length of output: 9560


🏁 Script executed:

#!/bin/bash
# Check the complete module.run() call in single_prefill_with_kv_cache implementation

echo "=== Complete module.run() call (lines 1176-1195) ==="
sed -n '1176,1195p' flashinfer/prefill.py

echo ""
echo "=== Verify run_single_prefill accepts maybe_s_aux (around line 278) ==="
sed -n '278,290p' flashinfer/prefill.py

echo ""
echo "=== Check if run_single_prefill actually uses maybe_s_aux in fa2 backend ==="
sed -n '310,350p' flashinfer/prefill.py

Length of output: 2182


Add maybe_s_aux parameter to single prefill Python wrapper and pass it to module.run().

The single_prefill_with_kv_cache function does not accept or pass the maybe_s_aux parameter that the underlying generated C++ module expects. The custom operation run_single_prefill (line 278) accepts maybe_s_aux and passes it to the fa2 backend (line 329), but the Python wrapper never provides it. This creates a parameter mismatch.

Update single_prefill_with_kv_cache:

  1. Add maybe_s_aux: Optional[torch.Tensor] = None parameter (matching batch prefill pattern)
  2. Pass it to module.run() after maybe_alibi_slopes and before logits_soft_cap (position 11)
🤖 Prompt for AI Agents
In flashinfer/prefill.py, the Python wrapper single_prefill_with_kv_cache does
not accept or forward the maybe_s_aux tensor that the module generated in
flashinfer/jit/attention/modules.py (around lines 519 to 524) now expects; add a
parameter maybe_s_aux: Optional[torch.Tensor] = None to the function signature
(matching the batch prefill pattern) and include that variable in the
module.run(...) call argument list immediately after maybe_alibi_slopes and
before logits_soft_cap (i.e., as the 11th positional argument) so the wrapper
matches the underlying run_single_prefill/fa2 backend usage.

additional_scalar_names = [
"logits_soft_cap",
"sm_scale",
Expand Down Expand Up @@ -760,8 +764,8 @@ def gen_batch_decode_module(
dtype_idx,
head_dim_qk,
head_dim_vo,
["maybe_alibi_slopes"], # additional_tensor_names
["float"], # additional_tensor_dtypes
["maybe_alibi_slopes", "maybe_s_aux"], # additional_tensor_names
["float", "float"], # additional_tensor_dtypes
[
"logits_soft_cap",
"sm_scale",
Expand Down
3 changes: 3 additions & 0 deletions flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def run_single_prefill(
window_left: int,
maybe_packed_custom_mask: Optional[torch.Tensor],
maybe_alibi_slopes: Optional[torch.Tensor],
maybe_s_aux: Optional[torch.Tensor],
logits_soft_cap: float,
sm_scale: float,
scale_q: Optional[torch.Tensor],
Expand Down Expand Up @@ -330,6 +331,7 @@ def run_single_prefill(
window_left,
maybe_packed_custom_mask,
maybe_alibi_slopes,
maybe_s_aux,
logits_soft_cap,
sm_scale,
1.0 / rope_scale, # rope_rcp_scale
Expand All @@ -350,6 +352,7 @@ def _fake_run_single_prefill(
window_left: int,
maybe_packed_custom_mask: Optional[torch.Tensor],
maybe_alibi_slopes: Optional[torch.Tensor],
maybe_s_aux: Optional[torch.Tensor],
logits_soft_cap: float,
sm_scale: float,
rope_scale: float,
Expand Down
17 changes: 17 additions & 0 deletions flashinfer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,23 @@ def _get_cache_alibi_slopes_buf(
return buf


def _get_sink_buf(
sinks: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
"""Convert sinks tensor to proper format for CUDA kernels.

Args:
sinks: Optional tensor of shape [num_qo_heads] with sink values per head

Returns:
Contiguous float32 tensor or None if sinks is None
"""
if sinks is None:
return None
# Ensure it's float32 and contiguous as expected by CUDA kernels
return sinks.to(torch.float32).contiguous()


def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
if isinstance(dtype, str):
return getattr(torch, dtype)
Expand Down
16 changes: 16 additions & 0 deletions include/flashinfer/attention/decode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,14 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par
// sync local state of all warps inside a threadblock
sync_state<vec_size, bdx, bdy, bdz>(variant, st_local, reinterpret_cast<float*>(smem), smem_md,
tx, ty, tz);
// Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
if constexpr (variant.use_softmax) {
if (params.maybe_s_aux != nullptr) {
constexpr float LOG2_E = 1.4426950408889634f; // log2(e)
float s_aux_val = params.maybe_s_aux[qo_head_idx];
st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
}
}
Comment on lines +358 to +365
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic for adding the sink contribution is duplicated in BatchDecodeWithPagedKVCacheDevice (lines 600-607). To improve maintainability and reduce code duplication, consider extracting this block into a helper function.

Also, the constant LOG2_E is defined inline here and in the other location. It would be better to define it once at the top of the file in an anonymous namespace to avoid magic numbers and ensure consistency.

For example, you could add at the top of the file:

namespace flashinfer {

namespace { // anonymous namespace

static constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)

template <typename AttentionVariant, typename State, typename Params>
__device__ __forceinline__ void AddSinkContribution(AttentionVariant variant, State& st,
                                                    const Params& params,
                                                    uint32_t qo_head_idx) {
  if constexpr (variant.use_softmax) {
    if (params.maybe_s_aux != nullptr) {
      float s_aux_val = params.maybe_s_aux[qo_head_idx];
      st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
    }
  }
}

} // anonymous namespace

// ... rest of the file

Then you could replace this block and the one in BatchDecodeWithPagedKVCacheDevice with a call to this helper function:

AddSinkContribution(variant, st_local, params, qo_head_idx);

Comment on lines +358 to +365
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix s_aux scaling to match logits path.

s[j] is scaled by variant.sm_scale_log2 before softmax. s_aux currently uses (s_aux - m) * LOG2_E, which mismatches and yields incorrect normalization. Scale s_aux with variant.sm_scale_log2 and drop LOG2_E.

Apply:

-  if constexpr (variant.use_softmax) {
-    if (params.maybe_s_aux != nullptr) {
-      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
-      float s_aux_val = params.maybe_s_aux[qo_head_idx];
-      st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
-    }
-  }
+  if constexpr (variant.use_softmax) {
+    if (params.maybe_s_aux != nullptr) {
+      float s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2;
+      st_local.d += math::ptx_exp2(s_aux_scaled - st_local.m);
+    }
+  }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
if constexpr (variant.use_softmax) {
if (params.maybe_s_aux != nullptr) {
constexpr float LOG2_E = 1.4426950408889634f; // log2(e)
float s_aux_val = params.maybe_s_aux[qo_head_idx];
st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
}
}
// Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
if constexpr (variant.use_softmax) {
if (params.maybe_s_aux != nullptr) {
float s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2;
st_local.d += math::ptx_exp2(s_aux_scaled - st_local.m);
}
}
🤖 Prompt for AI Agents
In include/flashinfer/attention/decode.cuh around lines 358 to 365, the s_aux
contribution is being added using (s_aux - m) * LOG2_E which mismatches the
logits path scaling; change the computation to scale only s_aux_val by
variant.sm_scale_log2 before subtracting st_local.m, and remove LOG2_E, so the
call becomes math::ptx_exp2(s_aux_val * variant.sm_scale_log2 - st_local.m);
keep the same null check and use_softmax guard.

#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
st_local.o[i] = variant.OutputTransform(params, st_local.o[i], /*batch_idx=*/0, /*qo_idx=*/0,
Expand Down Expand Up @@ -589,6 +597,14 @@ __device__ __inline__ void BatchDecodeWithPagedKVCacheDevice(const Params& param
// sync local state of all warps inside a threadblock
sync_state<vec_size, bdx, bdy, bdz>(variant, st, reinterpret_cast<float*>(smem), smem_md, tx, ty,
tz);
// Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
if constexpr (variant.use_softmax) {
if (params.maybe_s_aux != nullptr) {
constexpr float LOG2_E = 1.4426950408889634f; // log2(e)
float s_aux_val = params.maybe_s_aux[qo_head_idx];
st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
}
}
Comment on lines +601 to +607
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Same scaling fix for batch kernel.

Mirror the s_aux scaling correction here to keep behavior consistent across kernels.

-  if constexpr (variant.use_softmax) {
-    if (params.maybe_s_aux != nullptr) {
-      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
-      float s_aux_val = params.maybe_s_aux[qo_head_idx];
-      st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
-    }
-  }
+  if constexpr (variant.use_softmax) {
+    if (params.maybe_s_aux != nullptr) {
+      float s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2;
+      st.d += math::ptx_exp2(s_aux_scaled - st.m);
+    }
+  }

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In include/flashinfer/attention/decode.cuh around lines 601-607, the
batch-kernel branch needs the same s_aux scaling fix as the non-batch path:
drop the LOG2_E constant, read s_aux_val = params.maybe_s_aux[qo_head_idx],
scale it by variant.sm_scale_log2, subtract st.m, and pass the result to
math::ptx_exp2, then add it to st.d so the auxiliary scaling matches the other
kernel.

#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
st.o[i] = variant.OutputTransform(params, st.o[i], bx, /*qo_idx=*/0, qo_head_idx, st.m, st.d,
Expand Down
6 changes: 6 additions & 0 deletions include/flashinfer/attention/default_decode_params.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ struct SingleDecodeParams {
DTypeO* o;
float* lse;
float* maybe_alibi_slopes;
float* maybe_s_aux;
uint32_t kv_len;
uint32_t num_qo_heads;
uint32_t num_kv_heads;
Expand All @@ -58,6 +59,7 @@ struct SingleDecodeParams {
o(nullptr),
lse(nullptr),
maybe_alibi_slopes(nullptr),
maybe_s_aux(nullptr),
kv_len(0),
num_qo_heads(0),
num_kv_heads(0),
Expand All @@ -84,6 +86,7 @@ struct SingleDecodeParams {
o(o),
lse(nullptr),
maybe_alibi_slopes(maybe_alibi_slopes),
maybe_s_aux(nullptr),
kv_len(seq_len),
num_qo_heads(num_qo_heads),
num_kv_heads(num_kv_heads),
Expand Down Expand Up @@ -118,6 +121,7 @@ struct BatchDecodeParams {
DTypeO* o;
float* lse;
float* maybe_alibi_slopes;
float* maybe_s_aux;
uint32_t padded_batch_size;
uint32_t num_qo_heads;
IdType q_stride_n;
Expand All @@ -142,6 +146,7 @@ struct BatchDecodeParams {
o(nullptr),
lse(nullptr),
maybe_alibi_slopes(nullptr),
maybe_s_aux(nullptr),
padded_batch_size(0),
num_qo_heads(0),
q_stride_n(0),
Expand Down Expand Up @@ -170,6 +175,7 @@ struct BatchDecodeParams {
o(o),
lse(lse),
maybe_alibi_slopes(maybe_alibi_slopes),
maybe_s_aux(nullptr),
padded_batch_size(0),
num_qo_heads(num_qo_heads),
q_stride_n(q_stride_n),
Expand Down
23 changes: 16 additions & 7 deletions include/flashinfer/attention/default_prefill_params.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct SinglePrefillParams {
DTypeO* o;
float* lse;
float* maybe_alibi_slopes;
float* maybe_s_aux;
uint_fastdiv group_size;
uint32_t qo_len;
uint32_t kv_len;
Expand Down Expand Up @@ -66,6 +67,7 @@ struct SinglePrefillParams {
o(nullptr),
lse(nullptr),
maybe_alibi_slopes(nullptr),
maybe_s_aux(nullptr),
group_size(),
qo_len(0),
kv_len(0),
Expand All @@ -86,7 +88,7 @@ struct SinglePrefillParams {
partition_kv(false) {}

__host__ SinglePrefillParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* maybe_custom_mask,
DTypeO* o, float* lse, float* maybe_alibi_slopes,
DTypeO* o, float* lse, float* maybe_alibi_slopes, float* maybe_s_aux,
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len,
uint32_t kv_len, uint32_t q_stride_n, uint32_t q_stride_h,
uint32_t kv_stride_n, uint32_t kv_stride_h, uint32_t head_dim,
Expand All @@ -99,6 +101,7 @@ struct SinglePrefillParams {
o(o),
lse(lse),
maybe_alibi_slopes(maybe_alibi_slopes),
maybe_s_aux(maybe_s_aux),
group_size(num_qo_heads / num_kv_heads),
num_qo_heads(num_qo_heads),
num_kv_heads(num_kv_heads),
Expand Down Expand Up @@ -146,6 +149,7 @@ struct BatchPrefillRaggedParams {
DTypeO* o;
float* lse;
float* maybe_alibi_slopes;
float* maybe_s_aux;
uint_fastdiv group_size;
uint32_t num_qo_heads;
uint32_t num_kv_heads;
Expand Down Expand Up @@ -190,6 +194,7 @@ struct BatchPrefillRaggedParams {
o(nullptr),
lse(nullptr),
maybe_alibi_slopes(nullptr),
maybe_s_aux(nullptr),
group_size(),
num_qo_heads(0),
num_kv_heads(0),
Expand Down Expand Up @@ -224,9 +229,9 @@ struct BatchPrefillRaggedParams {
IdType* q_indptr, IdType* kv_indptr, IdType* maybe_mask_indptr,
IdType* maybe_q_rope_offset, IdType* maybe_k_rope_offset,
DTypeO* o, float* lse, float* maybe_alibi_slopes,
uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n,
uint32_t kv_stride_h, int32_t window_left,
float* maybe_s_aux, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t q_stride_n, uint32_t q_stride_h,
uint32_t kv_stride_n, uint32_t kv_stride_h, int32_t window_left,
float logits_soft_cap, float sm_scale, float rope_scale,
float rope_theta)
: q(q),
Expand All @@ -241,6 +246,7 @@ struct BatchPrefillRaggedParams {
o(o),
lse(lse),
maybe_alibi_slopes(maybe_alibi_slopes),
maybe_s_aux(maybe_s_aux),
group_size(num_qo_heads / num_kv_heads),
num_qo_heads(num_qo_heads),
num_kv_heads(num_kv_heads),
Expand Down Expand Up @@ -296,6 +302,7 @@ struct BatchPrefillPagedParams {
DTypeO* o;
float* lse;
float* maybe_alibi_slopes;
float* maybe_s_aux;
uint_fastdiv group_size;
uint32_t num_qo_heads;
IdType q_stride_n;
Expand Down Expand Up @@ -332,6 +339,7 @@ struct BatchPrefillPagedParams {
o(nullptr),
lse(nullptr),
maybe_alibi_slopes(nullptr),
maybe_s_aux(nullptr),
group_size(),
num_qo_heads(0),
q_stride_n(0),
Expand Down Expand Up @@ -361,9 +369,9 @@ struct BatchPrefillPagedParams {
uint8_t* maybe_custom_mask, IdType* q_indptr,
IdType* maybe_mask_indptr, IdType* maybe_q_rope_offset,
DTypeO* o, float* lse, float* maybe_alibi_slopes,
uint32_t num_qo_heads, IdType q_stride_n, IdType q_stride_h,
int32_t window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta)
float* maybe_s_aux, uint32_t num_qo_heads, IdType q_stride_n,
IdType q_stride_h, int32_t window_left, float logits_soft_cap,
float sm_scale, float rope_scale, float rope_theta)
: q(q),
paged_kv(paged_kv),
maybe_custom_mask(maybe_custom_mask),
Expand All @@ -373,6 +381,7 @@ struct BatchPrefillPagedParams {
o(o),
lse(lse),
maybe_alibi_slopes(maybe_alibi_slopes),
maybe_s_aux(maybe_s_aux),
group_size(num_qo_heads / paged_kv.num_heads),
num_qo_heads(num_qo_heads),
q_stride_n(q_stride_n),
Expand Down
10 changes: 10 additions & 0 deletions include/flashinfer/attention/variants.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ struct DefaultAttention : AttentionVariantBase {
}
return mask;
})

// Adds the optional per-head sink term to the softmax denominator `d`.
// Does nothing when `use_softmax` is false or `params.maybe_s_aux` is null.
// NOTE(review): the `scale` argument is not referenced anywhere in this body,
// and `m` is subtracted from the raw sink value before the LOG2_E conversion.
// Confirm that `m` is in the same (unscaled, natural-log) domain as
// `s_aux_val` wherever this hook is invoked.
REGISTER_M_D_UPDATE(params, kv_tile_idx, qo_head_idx, m, d, scale, {
if constexpr (use_softmax) {
if (params.maybe_s_aux != nullptr) {
constexpr float LOG2_E = 1.4426950408889634f; // log2(e)
// One sink value is read per query/output head.
float s_aux_val = params.maybe_s_aux[qo_head_idx];
// Presumably equals exp(s_aux_val - m), assuming ptx_exp2 computes 2^x.
d += math::ptx_exp2((s_aux_val - m) * LOG2_E);
}
}
})
};

}; // namespace flashinfer
Expand Down
Loading