142 changes: 106 additions & 36 deletions csrc/trtllm_fmha_kernel_launcher.cu
@@ -18,6 +18,7 @@
#include <flashinfer/trtllm/fmha/decoder_impl_common.h>
#include <flashinfer/trtllm/fmha/fmhaRunnerParams.h>
#include <nvrtc.h>
#include <tvm/ffi/container/variant.h>

#include <flashinfer/trtllm/fmha/fmhaRunner.cuh>
#include <flashinfer/utils.cuh>
@@ -28,6 +29,7 @@
#include "tvm_ffi_utils.h"

using tvm::ffi::Optional;
using tvm::ffi::Variant;

namespace flashinfer {

@@ -78,9 +80,10 @@ void trtllm_paged_attention_launcher(
int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t kv_stride_keys_values,
int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq,
double bmm1_scale, double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, cudaStream_t stream) {
double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr,
const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index,
int64_t window_left, int64_t sum_seq_q, int64_t sm_count, bool enable_pdl,
int64_t workspace_size, cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
@@ -117,8 +120,12 @@ void trtllm_paged_attention_launcher(
runner_params.vStrideBatch = kv_stride_batch;
runner_params.mNumPagesInMemPool = num_pages_in_mem_pool;
runner_params.stream = stream;
// scaleSoftmaxLog2Ptr and outputScalePtr take priority over scaleSoftmaxLog2 and outputScale:
// if the pointers are not nullptr, the scalar values are ignored.
runner_params.outputScale = bmm2_scale;
runner_params.outputScalePtr = bmm2_scale_ptr;
runner_params.scaleSoftmaxLog2 = bmm1_scale * M_LOG2E;
runner_params.scaleSoftmaxLog2Ptr = bmm1_scale_log2_ptr;
runner_params.oSfPtr = out_scale_factor;
runner_params.mSfStartTokenIdx = o_sf_start_index;
runner_params.mScaleSfO = o_sf_scale;
@@ -197,11 +204,12 @@ inline bool is_4bit(Data_type data_type) { return data_type == Data_type::DATA_T
void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scale_factor,
TensorView query, TensorView key_cache, TensorView value_cache,
TensorView workspace_buffer, TensorView block_tables,
TensorView seq_lens, int64_t max_kv_len, double bmm1_scale,
double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
int64_t o_sf_start_index, int64_t window_left, int64_t sm_count,
bool enable_pdl, int64_t workspace_size,
Optional<TensorView> attention_sinks) {
TensorView seq_lens, int64_t max_kv_len,
Variant<double, ffi::Tensor> bmm1_scale,
Collaborator comment on the bmm1_scale parameter:
cc @tqchen , I suppose Variant is a legit ABI across languages right?

Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale,
int64_t o_sf_vec_size, int64_t o_sf_start_index,
int64_t window_left, int64_t sm_count, bool enable_pdl,
int64_t workspace_size, Optional<TensorView> attention_sinks) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
@@ -250,7 +258,25 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
<< "attention_sinks must be a float tensor";
attention_sinks_ptr = static_cast<float*>(attention_sinks.value().data_ptr());
}

auto maybe_bmm1_scale_value = bmm1_scale.as<double>();
auto maybe_bmm2_scale_value = bmm2_scale.as<double>();
auto maybe_bmm1_scale_log2_tensor = bmm1_scale.as<ffi::Tensor>();
auto maybe_bmm2_scale_tensor = bmm2_scale.as<ffi::Tensor>();
TVM_FFI_CHECK(maybe_bmm1_scale_value.has_value() || maybe_bmm1_scale_log2_tensor.has_value(),
"bmm1_scale must be either a double or a tensor");
TVM_FFI_CHECK(maybe_bmm2_scale_value.has_value() || maybe_bmm2_scale_tensor.has_value(),
"bmm2_scale must be either a double or a tensor");
double bmm1_scale_value =
maybe_bmm1_scale_value.has_value() ? maybe_bmm1_scale_value.value() : 1.0;
double bmm2_scale_value =
maybe_bmm2_scale_value.has_value() ? maybe_bmm2_scale_value.value() : 1.0;
float* bmm1_scale_log2_ptr =
maybe_bmm1_scale_log2_tensor.has_value()
? static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr())
: nullptr;
float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
: nullptr;
Contributor comment on lines +262 to +279:

⚠️ Potential issue | 🟠 Major

Guard tensor-based scales with dtype checks

When bmm*_scale comes in as a tensor, we immediately reinterpret the storage as float*. Callers can legally hand us torch.float16/torch.bfloat16 today, so this reinterpret cast will read garbage and corrupt the softmax/output scales. Please gate the tensor branch with a dtype == dl_float32 check (and emit a clear error otherwise) before taking the pointer, and apply the same fix in the context and ragged code paths.

@@
-  float* bmm1_scale_log2_ptr =
-      maybe_bmm1_scale_log2_tensor.has_value()
-          ? static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr())
-          : nullptr;
-  float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
-                              ? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
-                              : nullptr;
+  float* bmm1_scale_log2_ptr = nullptr;
+  if (maybe_bmm1_scale_log2_tensor.has_value()) {
+    TVM_FFI_ICHECK_EQ(maybe_bmm1_scale_log2_tensor.value().dtype(), dl_float32)
+        << "bmm1_scale tensor must be float32";
+    bmm1_scale_log2_ptr =
+        static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr());
+  }
+  float* bmm2_scale_ptr = nullptr;
+  if (maybe_bmm2_scale_tensor.has_value()) {
+    TVM_FFI_ICHECK_EQ(maybe_bmm2_scale_tensor.value().dtype(), dl_float32)
+        << "bmm2_scale tensor must be float32";
+    bmm2_scale_ptr =
+        static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr());
+  }

Please mirror this guard in trtllm_paged_attention_context and trtllm_ragged_attention.

Also applies to: 338-356, 503-521

πŸ€– Prompt for AI Agents
csrc/trtllm_fmha_kernel_launcher.cu, lines 260-277: when bmm1_scale or bmm2_scale is a tensor, the code currently reinterpret-casts data_ptr() to float* without checking the dtype, which will misread half/bfloat16 tensors. Modify the tensor branch to first check that the tensor dtype is float32 (dl_float32) and raise a clear error via TVM_FFI_CHECK if it is not, then take data_ptr() as float*. Apply the identical dtype guard and error message to the similar blocks at lines 338-356 and 503-521, and mirror these guards in the corresponding trtllm_paged_attention_context and trtllm_ragged_attention code paths.
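One way to address the "mirror this guard" follow-up without repeating the unpack-and-check logic in all three entry points is a small shared helper. The sketch below is an illustration only: the helper name, its placement, and the call sites are assumptions rather than code from this PR, and it simply reuses the TVM_FFI_ICHECK_EQ / dl_float32 pattern already applied to attention_sinks.

// Hypothetical helper (not part of this PR): unpack a tensor-valued scale, if present,
// and verify it is float32 before reinterpreting its storage as float*.
inline float* scale_tensor_ptr_or_null(Variant<double, ffi::Tensor> scale, const char* arg_name) {
  auto maybe_tensor = scale.as<ffi::Tensor>();  // empty when the caller passed a double
  if (!maybe_tensor.has_value()) return nullptr;
  TVM_FFI_ICHECK_EQ(maybe_tensor.value().dtype(), dl_float32)
      << arg_name << " tensor must be float32";
  return static_cast<float*>(maybe_tensor.value().data_ptr());
}

// Hypothetical call sites in trtllm_paged_attention_decode (and likewise in the context
// and ragged paths):
//   float* bmm1_scale_log2_ptr = scale_tensor_ptr_or_null(bmm1_scale, "bmm1_scale");
//   float* bmm2_scale_ptr = scale_tensor_ptr_or_null(bmm2_scale, "bmm2_scale");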

trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
@@ -259,21 +285,20 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
/*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len,
num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale,
bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count,
enable_pdl, workspace_size, stream);
kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq,
bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,
o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, enable_pdl, workspace_size,
stream);
}

void trtllm_paged_attention_context(TensorView out, Optional<TensorView> out_scale_factor,
TensorView query, TensorView key_cache, TensorView value_cache,
TensorView workspace_buffer, TensorView block_tables,
TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len,
double bmm1_scale, double bmm2_scale, double o_sf_scale,
int64_t o_sf_vec_size, int64_t o_sf_start_index,
int64_t batch_size, int64_t window_left,
TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv,
int64_t sm_count, bool enable_pdl, int64_t workspace_size,
Optional<TensorView> attention_sinks) {
void trtllm_paged_attention_context(
TensorView out, Optional<TensorView> out_scale_factor, TensorView query, TensorView key_cache,
TensorView value_cache, TensorView workspace_buffer, TensorView block_tables,
TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len,
Variant<double, ffi::Tensor> bmm1_scale, Variant<double, ffi::Tensor> bmm2_scale,
double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size,
int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
@@ -312,6 +337,26 @@ void trtllm_paged_attention_context(TensorView out, Optional<TensorView> out_sca
attention_sinks_ptr = static_cast<float*>(attention_sinks.value().data_ptr());
}

auto maybe_bmm1_scale_value = bmm1_scale.as<double>();
auto maybe_bmm2_scale_value = bmm2_scale.as<double>();
auto maybe_bmm1_scale_log2_tensor = bmm1_scale.as<ffi::Tensor>();
auto maybe_bmm2_scale_tensor = bmm2_scale.as<ffi::Tensor>();
TVM_FFI_CHECK(maybe_bmm1_scale_value.has_value() || maybe_bmm1_scale_log2_tensor.has_value(),
"bmm1_scale must be either a double or a tensor");
TVM_FFI_CHECK(maybe_bmm2_scale_value.has_value() || maybe_bmm2_scale_tensor.has_value(),
"bmm2_scale must be either a double or a tensor");
double bmm1_scale_value =
maybe_bmm1_scale_value.has_value() ? maybe_bmm1_scale_value.value() : 1.0;
double bmm2_scale_value =
maybe_bmm2_scale_value.has_value() ? maybe_bmm2_scale_value.value() : 1.0;
float* bmm1_scale_log2_ptr =
maybe_bmm1_scale_log2_tensor.has_value()
? static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr())
: nullptr;
float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
: nullptr;

trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
@@ -321,8 +366,9 @@ void trtllm_paged_attention_context(TensorView out, Optional<TensorView> out_sca
q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
max_num_blocks_per_seq, bmm1_scale, bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index,
window_left, sum_seq_q, sm_count, enable_pdl, workspace_size, stream);
max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count,
enable_pdl, workspace_size, stream);
}

void trtllm_ragged_attention_launcher(
Expand All @@ -331,8 +377,9 @@ void trtllm_ragged_attention_launcher(
Data_type q_data_type, Data_type kv_data_type, Data_type o_data_type, int64_t max_q_len,
int64_t max_kv_len, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk,
int64_t head_dim_v, int64_t sum_seq_q, int64_t sum_seq_kv, double bmm1_scale, double bmm2_scale,
double o_sf_scale, int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl,
bool is_causal, int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch,
const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale,
int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl, bool is_causal,
int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch,
int64_t v_stride_keys_values, int64_t v_stride_heads, int64_t v_stride_batch,
int64_t workspace_size, cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
@@ -360,8 +407,12 @@ void trtllm_ragged_attention_launcher(
runner_params.mQkvLayout = QkvLayout::SeparateQkv;
runner_params.mMultiProcessorCount = sm_count;
runner_params.stream = stream;
// scaleSoftmaxLog2Ptr and outputScalePtr take priority over scaleSoftmaxLog2 and outputScale:
// if the pointers are not nullptr, the scalar values are ignored.
runner_params.outputScale = bmm2_scale;
runner_params.outputScalePtr = bmm2_scale_ptr;
runner_params.scaleSoftmaxLog2 = bmm1_scale * M_LOG2E;
runner_params.scaleSoftmaxLog2Ptr = bmm1_scale_log2_ptr;
runner_params.mScaleSfO = o_sf_scale;
runner_params.mChunkedAttentionSize = INT_MAX; // disable chunked attention by INT_MAX
runner_params.mAttentionWindowSize =
@@ -414,12 +465,12 @@ void trtllm_ragged_attention_launcher(

void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, TensorView value,
TensorView workspace_buffer, TensorView seq_lens, int64_t max_q_len,
int64_t max_kv_len, double bmm1_scale, double bmm2_scale,
double o_sf_scale, int64_t batch_size, int64_t window_left,
TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv,
int64_t sm_count, bool enable_pdl, bool is_causal,
int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> lse) {
int64_t max_kv_len, Variant<double, ffi::Tensor> bmm1_scale,
Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale,
int64_t batch_size, int64_t window_left, TensorView cum_seq_lens_q,
TensorView cum_seq_lens_kv, int64_t sm_count, bool enable_pdl,
bool is_causal, int64_t workspace_size,
Optional<TensorView> attention_sinks, Optional<TensorView> lse) {
float* attention_sinks_ptr = nullptr;
if (attention_sinks.has_value()) {
TVM_FFI_ICHECK_EQ(attention_sinks.value().dtype(), dl_float32)
@@ -453,15 +504,34 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
int v_stride_heads = value.stride(1);
int v_stride_batch = value.numel();

auto maybe_bmm1_scale_value = bmm1_scale.as<double>();
auto maybe_bmm2_scale_value = bmm2_scale.as<double>();
auto maybe_bmm1_scale_log2_tensor = bmm1_scale.as<ffi::Tensor>();
auto maybe_bmm2_scale_tensor = bmm2_scale.as<ffi::Tensor>();
TVM_FFI_CHECK(maybe_bmm1_scale_value.has_value() || maybe_bmm1_scale_log2_tensor.has_value(),
"bmm1_scale must be either a double or a tensor");
TVM_FFI_CHECK(maybe_bmm2_scale_value.has_value() || maybe_bmm2_scale_tensor.has_value(),
"bmm2_scale must be either a double or a tensor");
double bmm1_scale_value =
maybe_bmm1_scale_value.has_value() ? maybe_bmm1_scale_value.value() : 1.0;
double bmm2_scale_value =
maybe_bmm2_scale_value.has_value() ? maybe_bmm2_scale_value.value() : 1.0;
float* bmm1_scale_log2_ptr =
maybe_bmm1_scale_log2_tensor.has_value()
? static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr())
: nullptr;
float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
: nullptr;
trtllm_ragged_attention_launcher(
out.data_ptr(), query.data_ptr(), key.data_ptr(), value.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(seq_lens.data_ptr()),
static_cast<int*>(cum_seq_lens_q.data_ptr()), static_cast<int*>(cum_seq_lens_kv.data_ptr()),
attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type, o_data_type, max_q_len, max_kv_len,
num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale,
bmm2_scale, o_sf_scale, batch_size, window_left, sm_count, enable_pdl, is_causal,
k_stride_keys_values, k_stride_heads, k_stride_batch, v_stride_keys_values, v_stride_heads,
v_stride_batch, workspace_size, stream);
num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale_value,
bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, batch_size, window_left,
sm_count, enable_pdl, is_causal, k_stride_keys_values, k_stride_heads, k_stride_batch,
v_stride_keys_values, v_stride_heads, v_stride_batch, workspace_size, stream);
}

namespace trtllm_cubin_loader {
4 changes: 2 additions & 2 deletions flashinfer/artifacts.py
@@ -87,7 +87,7 @@ class ArtifactPath:
When compiling new cubins for backend directories, update the corresponding path.
"""

TRTLLM_GEN_FMHA: str = "b793e1b2cf7c419f070372ba55bbe53ca6fb9016/fmha/trtllm-gen/"
TRTLLM_GEN_FMHA: str = "1e49deb33ec20018ae0acf1d956a579578069da1/fmha/trtllm-gen/"
TRTLLM_GEN_BMM: str = (
"c108f5cc46420e11805467898186533fb48d6a6f/batched_gemm-0d28130-7b26988"
)
@@ -107,7 +107,7 @@ class CheckSumHash:
"""

TRTLLM_GEN_FMHA: str = (
"20c017db0761a30130f05080ed2078f6c8044c0c2b3be7c4353ec740034b4432"
"66757498f573430583d63b04c02bf9e38306eefe2ce31df9b5d923d99bd15d84"
)
TRTLLM_GEN_BMM: str = (
"85a4516b7ab25b1a6495398ae934a00e30ccd6662b9ec27be1330d7bba5e1ddf"