[API change] Allow using torch.Tensor for scales for trtllm-gen attention #2084
Changes from 10 commits
e3ce4cf
1c8202d
3e6bb28
2f39e1f
d8f6387
d2e992f
815da6b
c210044
270087d
997b913
74d153f
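
For orientation before the diff: the decode, context, and ragged entry points change their `bmm1_scale`/`bmm2_scale` parameters from `double` to `Variant<double, ffi::Tensor>`, so a caller can keep passing a host-side scalar or instead supply a float32 device tensor (e.g. a torch.Tensor crossing the FFI boundary) whose value the kernel reads at launch time, so the scale can change without rebuilding host arguments. A minimal sketch of the two call shapes, assuming Variant's converting constructors; names are illustrative:

#include <tvm/ffi/container/variant.h>

using tvm::ffi::Variant;

// Sketch only: both forms satisfy the new parameter type.
// `scale_dev` is a hypothetical float32 device-resident tensor.
void call_shapes(tvm::ffi::Tensor scale_dev) {
  Variant<double, tvm::ffi::Tensor> as_scalar(0.125);      // previous behavior
  Variant<double, tvm::ffi::Tensor> as_tensor(scale_dev);  // new tensor path
  // Per the diff below, a tensor passed as bmm1_scale is bound to
  // scaleSoftmaxLog2Ptr, so it must already hold log2-domain values.
}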
@@ -18,6 +18,7 @@
 #include <flashinfer/trtllm/fmha/decoder_impl_common.h>
 #include <flashinfer/trtllm/fmha/fmhaRunnerParams.h>
 #include <nvrtc.h>
+#include <tvm/ffi/container/variant.h>

 #include <flashinfer/trtllm/fmha/fmhaRunner.cuh>
 #include <flashinfer/utils.cuh>
@@ -28,6 +29,7 @@
 #include "tvm_ffi_utils.h"

 using tvm::ffi::Optional;
+using tvm::ffi::Variant;

 namespace flashinfer {
@@ -78,9 +80,10 @@ void trtllm_paged_attention_launcher(
     int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads,
     int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t kv_stride_keys_values,
     int64_t kv_stride_heads, int64_t kv_stride_batch, int64_t max_num_blocks_per_seq,
-    double bmm1_scale, double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
-    int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q, int64_t sm_count,
-    bool enable_pdl, int64_t workspace_size, cudaStream_t stream) {
+    double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr,
+    const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index,
+    int64_t window_left, int64_t sum_seq_q, int64_t sm_count, bool enable_pdl,
+    int64_t workspace_size, cudaStream_t stream) {
   if (num_qo_heads % num_kv_heads != 0) {
     std::ostringstream err_msg;
     err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
@@ -117,8 +120,12 @@ void trtllm_paged_attention_launcher(
   runner_params.vStrideBatch = kv_stride_batch;
   runner_params.mNumPagesInMemPool = num_pages_in_mem_pool;
   runner_params.stream = stream;
+  // the scaleSoftmaxLog2Ptr and outputScalePtr have higher priority than the scaleSoftmaxLog2 and
+  // outputScale. if they are not nullptr, then scaleSoftmaxLog2 and outputScale will be ignored
   runner_params.outputScale = bmm2_scale;
+  runner_params.outputScalePtr = bmm2_scale_ptr;
   runner_params.scaleSoftmaxLog2 = bmm1_scale * M_LOG2E;
+  runner_params.scaleSoftmaxLog2Ptr = bmm1_scale_log2_ptr;
   runner_params.oSfPtr = out_scale_factor;
   runner_params.mSfStartTokenIdx = o_sf_start_index;
   runner_params.mScaleSfO = o_sf_scale;
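A note on the two bmm1 representations above: the scalar path folds `M_LOG2E` into `scaleSoftmaxLog2` because the kernel evidently evaluates the softmax exponentials with exp2 rather than exp, while the pointer path is forwarded untouched. A tensor-backed scale (hence the parameter name `bmm1_scale_log2_ptr`) must therefore already contain the log2-domain value. A minimal host-side sketch of that conversion, assuming the usual 1/sqrt(head_dim) softmax scale composed with dequant factors; the formula is illustrative, not taken from this diff:

#include <cmath>  // M_LOG2E (may require _USE_MATH_DEFINES on MSVC)

// Illustrative helper: produce the log2-domain value a tensor passed as
// bmm1_scale would need to hold. Mirrors the scalar path's
// `scaleSoftmaxLog2 = bmm1_scale * M_LOG2E` shown above.
float bmm1_scale_to_log2(float q_scale, float k_scale, int head_dim_qk) {
  // Hypothetical composition of dequant scales with the softmax scale.
  float bmm1_scale = q_scale * k_scale / std::sqrt(static_cast<float>(head_dim_qk));
  // exp(s * x) == exp2(s * M_LOG2E * x), so premultiplying by log2(e)
  // lets the kernel use the cheaper exp2 instruction.
  return bmm1_scale * static_cast<float>(M_LOG2E);
}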
@@ -197,11 +204,12 @@ inline bool is_4bit(Data_type data_type) { return data_type == Data_type::DATA_T
 void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scale_factor,
                                    TensorView query, TensorView key_cache, TensorView value_cache,
                                    TensorView workspace_buffer, TensorView block_tables,
-                                   TensorView seq_lens, int64_t max_kv_len, double bmm1_scale,
-                                   double bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size,
-                                   int64_t o_sf_start_index, int64_t window_left, int64_t sm_count,
-                                   bool enable_pdl, int64_t workspace_size,
-                                   Optional<TensorView> attention_sinks) {
+                                   TensorView seq_lens, int64_t max_kv_len,
+                                   Variant<double, ffi::Tensor> bmm1_scale,
+                                   Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale,
+                                   int64_t o_sf_vec_size, int64_t o_sf_start_index,
+                                   int64_t window_left, int64_t sm_count, bool enable_pdl,
+                                   int64_t workspace_size, Optional<TensorView> attention_sinks) {
   auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
   auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
   TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());

Collaborator (on `Variant<double, ffi::Tensor> bmm1_scale`):
cc @tqchen, I suppose Variant is a legit ABI across languages, right?
@@ -250,7 +258,25 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
         << "attention_sinks must be a float tensor";
     attention_sinks_ptr = static_cast<float*>(attention_sinks.value().data_ptr());
   }

+  auto maybe_bmm1_scale_value = bmm1_scale.as<double>();
+  auto maybe_bmm2_scale_value = bmm2_scale.as<double>();
+  auto maybe_bmm1_scale_log2_tensor = bmm1_scale.as<ffi::Tensor>();
+  auto maybe_bmm2_scale_tensor = bmm2_scale.as<ffi::Tensor>();
+  TVM_FFI_CHECK(maybe_bmm1_scale_value.has_value() || maybe_bmm1_scale_log2_tensor.has_value(),
+                "bmm1_scale must be either a double or a tensor");
+  TVM_FFI_CHECK(maybe_bmm2_scale_value.has_value() || maybe_bmm2_scale_tensor.has_value(),
+                "bmm2_scale must be either a double or a tensor");
+  double bmm1_scale_value =
+      maybe_bmm1_scale_value.has_value() ? maybe_bmm1_scale_value.value() : 1.0;
+  double bmm2_scale_value =
+      maybe_bmm2_scale_value.has_value() ? maybe_bmm2_scale_value.value() : 1.0;
+  float* bmm1_scale_log2_ptr =
+      maybe_bmm1_scale_log2_tensor.has_value()
+          ? static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr())
+          : nullptr;
+  float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
+                              ? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
+                              : nullptr;
Contributor (on lines +262 to +279):
Guard tensor-based scales with dtype checks. When either scale arrives as a tensor, its data_ptr() is cast straight to float* with no dtype validation, so a non-float32 tensor would be silently reinterpreted. Suggested change:

@@
-  float* bmm1_scale_log2_ptr =
-      maybe_bmm1_scale_log2_tensor.has_value()
-          ? static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr())
-          : nullptr;
-  float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
-                              ? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
-                              : nullptr;
+  float* bmm1_scale_log2_ptr = nullptr;
+  if (maybe_bmm1_scale_log2_tensor.has_value()) {
+    TVM_FFI_ICHECK_EQ(maybe_bmm1_scale_log2_tensor.value().dtype(), dl_float32)
+        << "bmm1_scale tensor must be float32";
+    bmm1_scale_log2_ptr =
+        static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr());
+  }
+  float* bmm2_scale_ptr = nullptr;
+  if (maybe_bmm2_scale_tensor.has_value()) {
+    TVM_FFI_ICHECK_EQ(maybe_bmm2_scale_tensor.value().dtype(), dl_float32)
+        << "bmm2_scale tensor must be float32";
+    bmm2_scale_ptr =
+        static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr());
+  }

Please mirror this guard in the context and ragged entry points as well (also applies to lines 338-356 and 503-521).
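
Since this unpack-and-guard block is repeated in all three entry points, one way to apply the reviewer's suggestion without triplication would be a shared helper; a sketch only, the helper name and struct are hypothetical:

// Illustrative refactor (not part of this PR): fold the dtype guard into the
// Variant unpacking duplicated across the decode, context, and ragged paths.
struct UnpackedScale {
  double value;  // meaningful when ptr == nullptr
  float* ptr;    // non-null when the caller passed a float32 tensor
};

inline UnpackedScale unpack_scale(Variant<double, ffi::Tensor> scale, const char* name) {
  auto maybe_value = scale.as<double>();
  if (maybe_value.has_value()) {
    return {maybe_value.value(), nullptr};
  }
  auto maybe_tensor = scale.as<ffi::Tensor>();
  TVM_FFI_CHECK(maybe_tensor.has_value(), "scale must be either a double or a tensor");
  TVM_FFI_ICHECK_EQ(maybe_tensor.value().dtype(), dl_float32)
      << name << " tensor must be float32";
  return {1.0, static_cast<float*>(maybe_tensor.value().data_ptr())};
}

Each entry point would then reduce to something like `auto bmm1 = unpack_scale(bmm1_scale, "bmm1_scale");`.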
   trtllm_paged_attention_launcher(
       out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
       workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
@@ -259,21 +285,20 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
       /*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
       TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len,
       num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
-      kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale,
-      bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count,
-      enable_pdl, workspace_size, stream);
+      kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq,
+      bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,
+      o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count, enable_pdl, workspace_size,
+      stream);
 }

-void trtllm_paged_attention_context(TensorView out, Optional<TensorView> out_scale_factor,
-                                    TensorView query, TensorView key_cache, TensorView value_cache,
-                                    TensorView workspace_buffer, TensorView block_tables,
-                                    TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len,
-                                    double bmm1_scale, double bmm2_scale, double o_sf_scale,
-                                    int64_t o_sf_vec_size, int64_t o_sf_start_index,
-                                    int64_t batch_size, int64_t window_left,
-                                    TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv,
-                                    int64_t sm_count, bool enable_pdl, int64_t workspace_size,
-                                    Optional<TensorView> attention_sinks) {
+void trtllm_paged_attention_context(
+    TensorView out, Optional<TensorView> out_scale_factor, TensorView query, TensorView key_cache,
+    TensorView value_cache, TensorView workspace_buffer, TensorView block_tables,
+    TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len,
+    Variant<double, ffi::Tensor> bmm1_scale, Variant<double, ffi::Tensor> bmm2_scale,
+    double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size,
+    int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count,
+    bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
   auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
   auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
   auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
@@ -312,6 +337,26 @@ void trtllm_paged_attention_context(TensorView out, Optional<TensorView> out_sca
     attention_sinks_ptr = static_cast<float*>(attention_sinks.value().data_ptr());
   }

+  auto maybe_bmm1_scale_value = bmm1_scale.as<double>();
+  auto maybe_bmm2_scale_value = bmm2_scale.as<double>();
+  auto maybe_bmm1_scale_log2_tensor = bmm1_scale.as<ffi::Tensor>();
+  auto maybe_bmm2_scale_tensor = bmm2_scale.as<ffi::Tensor>();
+  TVM_FFI_CHECK(maybe_bmm1_scale_value.has_value() || maybe_bmm1_scale_log2_tensor.has_value(),
+                "bmm1_scale must be either a double or a tensor");
+  TVM_FFI_CHECK(maybe_bmm2_scale_value.has_value() || maybe_bmm2_scale_tensor.has_value(),
+                "bmm2_scale must be either a double or a tensor");
+  double bmm1_scale_value =
+      maybe_bmm1_scale_value.has_value() ? maybe_bmm1_scale_value.value() : 1.0;
+  double bmm2_scale_value =
+      maybe_bmm2_scale_value.has_value() ? maybe_bmm2_scale_value.value() : 1.0;
+  float* bmm1_scale_log2_ptr =
+      maybe_bmm1_scale_log2_tensor.has_value()
+          ? static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr())
+          : nullptr;
+  float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
+                              ? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
+                              : nullptr;
+
   trtllm_paged_attention_launcher(
       out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
       workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
@@ -321,8 +366,9 @@ void trtllm_paged_attention_context(TensorView out, Optional<TensorView> out_sca
       q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
       max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
       head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
-      max_num_blocks_per_seq, bmm1_scale, bmm2_scale, o_sf_scale, o_sf_vec_size, o_sf_start_index,
-      window_left, sum_seq_q, sm_count, enable_pdl, workspace_size, stream);
+      max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
+      bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sm_count,
+      enable_pdl, workspace_size, stream);
 }

 void trtllm_ragged_attention_launcher(
@@ -331,8 +377,9 @@ void trtllm_ragged_attention_launcher(
     Data_type q_data_type, Data_type kv_data_type, Data_type o_data_type, int64_t max_q_len,
     int64_t max_kv_len, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk,
     int64_t head_dim_v, int64_t sum_seq_q, int64_t sum_seq_kv, double bmm1_scale, double bmm2_scale,
-    double o_sf_scale, int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl,
-    bool is_causal, int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch,
+    const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale,
+    int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl, bool is_causal,
+    int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch,
     int64_t v_stride_keys_values, int64_t v_stride_heads, int64_t v_stride_batch,
     int64_t workspace_size, cudaStream_t stream) {
   if (num_qo_heads % num_kv_heads != 0) {
@@ -360,8 +407,12 @@ void trtllm_ragged_attention_launcher(
   runner_params.mQkvLayout = QkvLayout::SeparateQkv;
   runner_params.mMultiProcessorCount = sm_count;
   runner_params.stream = stream;
+  // the scaleSoftmaxLog2Ptr and outputScalePtr have higher priority than the scaleSoftmaxLog2 and
+  // outputScale. if they are not nullptr, then scaleSoftmaxLog2 and outputScale will be ignored
   runner_params.outputScale = bmm2_scale;
+  runner_params.outputScalePtr = bmm2_scale_ptr;
   runner_params.scaleSoftmaxLog2 = bmm1_scale * M_LOG2E;
+  runner_params.scaleSoftmaxLog2Ptr = bmm1_scale_log2_ptr;
   runner_params.mScaleSfO = o_sf_scale;
   runner_params.mChunkedAttentionSize = INT_MAX;  // disable chunked attention by INT_MAX
   runner_params.mAttentionWindowSize =
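The same precedence comment reappears here: each scale effectively has a scalar slot and a pointer slot, and the pointer, when non-null, wins. A toy host-side model of that selection rule, purely illustrative since the real read happens inside the device kernel:

// Toy model of the precedence rule stated in the comment above: the pointer
// slot, when set, overrides the scalar slot.
float effective_scale(double scalar_slot, const float* ptr_slot) {
  return ptr_slot != nullptr ? *ptr_slot : static_cast<float>(scalar_slot);
}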
@@ -414,12 +465,12 @@ void trtllm_ragged_attention_launcher(

 void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, TensorView value,
                              TensorView workspace_buffer, TensorView seq_lens, int64_t max_q_len,
-                             int64_t max_kv_len, double bmm1_scale, double bmm2_scale,
-                             double o_sf_scale, int64_t batch_size, int64_t window_left,
-                             TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv,
-                             int64_t sm_count, bool enable_pdl, bool is_causal,
-                             int64_t workspace_size, Optional<TensorView> attention_sinks,
-                             Optional<TensorView> lse) {
+                             int64_t max_kv_len, Variant<double, ffi::Tensor> bmm1_scale,
+                             Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale,
+                             int64_t batch_size, int64_t window_left, TensorView cum_seq_lens_q,
+                             TensorView cum_seq_lens_kv, int64_t sm_count, bool enable_pdl,
+                             bool is_causal, int64_t workspace_size,
+                             Optional<TensorView> attention_sinks, Optional<TensorView> lse) {
   float* attention_sinks_ptr = nullptr;
   if (attention_sinks.has_value()) {
     TVM_FFI_ICHECK_EQ(attention_sinks.value().dtype(), dl_float32)
@@ -453,15 +504,34 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
   int v_stride_heads = value.stride(1);
   int v_stride_batch = value.numel();

+  auto maybe_bmm1_scale_value = bmm1_scale.as<double>();
+  auto maybe_bmm2_scale_value = bmm2_scale.as<double>();
+  auto maybe_bmm1_scale_log2_tensor = bmm1_scale.as<ffi::Tensor>();
+  auto maybe_bmm2_scale_tensor = bmm2_scale.as<ffi::Tensor>();
+  TVM_FFI_CHECK(maybe_bmm1_scale_value.has_value() || maybe_bmm1_scale_log2_tensor.has_value(),
+                "bmm1_scale must be either a double or a tensor");
+  TVM_FFI_CHECK(maybe_bmm2_scale_value.has_value() || maybe_bmm2_scale_tensor.has_value(),
+                "bmm2_scale must be either a double or a tensor");
+  double bmm1_scale_value =
+      maybe_bmm1_scale_value.has_value() ? maybe_bmm1_scale_value.value() : 1.0;
+  double bmm2_scale_value =
+      maybe_bmm2_scale_value.has_value() ? maybe_bmm2_scale_value.value() : 1.0;
+  float* bmm1_scale_log2_ptr =
+      maybe_bmm1_scale_log2_tensor.has_value()
+          ? static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr())
+          : nullptr;
+  float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
+                              ? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
+                              : nullptr;
   trtllm_ragged_attention_launcher(
       out.data_ptr(), query.data_ptr(), key.data_ptr(), value.data_ptr(),
       workspace_buffer.data_ptr(), static_cast<int*>(seq_lens.data_ptr()),
       static_cast<int*>(cum_seq_lens_q.data_ptr()), static_cast<int*>(cum_seq_lens_kv.data_ptr()),
       attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type, o_data_type, max_q_len, max_kv_len,
-      num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale,
-      bmm2_scale, o_sf_scale, batch_size, window_left, sm_count, enable_pdl, is_causal,
-      k_stride_keys_values, k_stride_heads, k_stride_batch, v_stride_keys_values, v_stride_heads,
-      v_stride_batch, workspace_size, stream);
+      num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale_value,
+      bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, batch_size, window_left,
+      sm_count, enable_pdl, is_causal, k_stride_keys_values, k_stride_heads, k_stride_batch,
+      v_stride_keys_values, v_stride_heads, v_stride_batch, workspace_size, stream);
 }

 namespace trtllm_cubin_loader {