Skip to content
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
23f284f
refactor
titaiwangms Feb 27, 2026
de61cdd
Add opset 24 nonpad_kv_seqlen tests for Attention op
titaiwangms Feb 27, 2026
17cbe1d
Fix test bugs from code review for Attention-24 nonpad_kv_seqlen
titaiwangms Feb 27, 2026
e40249e
Fix CUDA kernel review findings: GPU bounds clamping and pointer safety
titaiwangms Feb 27, 2026
bef5b9a
lint
titaiwangms Feb 27, 2026
db65ef5
resolve conflict
titaiwangms Feb 27, 2026
5b748d0
address review and use tensorscatter op in tests
titaiwangms Feb 27, 2026
e04597f
Update ONNX backend test filters for TensorScatter and nonpad_kv_seqlen
titaiwangms Feb 27, 2026
efa3ac2
Address review findings: CUDA kernel assertions, pointer safety, and …
titaiwangms Feb 27, 2026
3649935
Revert TestCase.cc change — QNN does not yet support nonpad_kv_seqlen
titaiwangms Feb 27, 2026
31aac47
Add causal mode test variants for TensorScatter + nonpad_kv_seqlen
titaiwangms Feb 28, 2026
a6713ad
Address P2 review findings: tolerance, GQA TODO, and parameter docume…
titaiwangms Feb 28, 2026
acbc057
Add batch=1 edge case test for TensorScatter attention
titaiwangms Feb 28, 2026
a221099
Merge branch 'main' into titaiwang/support_nonpad_kv_seqlen
titaiwangms Mar 2, 2026
bd468ca
update docs
titaiwangms Mar 2, 2026
9f9e768
Add GQA decode test cases to test_tensorscatter_attention.py
titaiwangms Mar 2, 2026
1d01361
Add FlashAttentionForExternalKVCache helper for TensorScatter + Atten…
titaiwangms Mar 2, 2026
2d06220
Address code review: guard header declaration and consolidate seqlens…
titaiwangms Mar 2, 2026
5a2475c
Fix critical review: OOB guard and BNSH check for flash KV cache
titaiwangms Mar 2, 2026
cc4c70e
Use ORT_MAKE_STATUS instead of ORT_ENFORCE for OOM guard
titaiwangms Mar 2, 2026
f7c3ae2
Fix GQA prompt + nonpad_kv_seqlen: skip seqlens_k conversion in promp…
titaiwangms Mar 3, 2026
4f00ed2
Fix GQA prompt + nonpad_kv_seqlen: skip seqlens_k conversion in promp…
titaiwangms Mar 3, 2026
61ac657
Restore MHA partial masking test coverage per review feedback
titaiwangms Mar 3, 2026
53bae3d
Add warning for GQA prompt + nonpad_kv_seqlen partial masking limitation
titaiwangms Mar 3, 2026
2678de7
Downgrade GQA prompt nonpad log from WARNING to VERBOSE
titaiwangms Mar 3, 2026
9be02c9
Reject GQA + prompt + nonpad_kv_seqlen with hard error on CUDA
titaiwangms Mar 3, 2026
5d187a4
Replace GQA prompt+nonpad hard error with CUDA_KERNEL_ASSERT
titaiwangms Mar 3, 2026
e4ca928
Add comment on total_seq_lens invariant in GQA prompt mode
titaiwangms Mar 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions onnxruntime/core/providers/cpu/llm/attention_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ inline Status ComputeOutputShapeForAttention(
TensorShape& y_shape,
TensorShape& present_key_shape,
TensorShape& present_value_shape,
TensorShape& output_qk_shape) {
TensorShape& output_qk_shape,
bool skip_nonpad_data_validation = false) {
ORT_ENFORCE(Q != nullptr && K != nullptr && V != nullptr,
"Q, K, and V inputs must not be null");
int q_dims = onnxruntime::narrow<int>(Q->Shape().NumDimensions());
Expand Down Expand Up @@ -113,13 +114,18 @@ inline Status ComputeOutputShapeForAttention(
ORT_ENFORCE(past_key == nullptr && past_value == nullptr,
"nonpad_kv_seqlen should not be used together with past_key and past_value inputs");
parameters.has_nonpad_kv_seqlen = true;
// Warning: On CUDA, this is a device pointer. Do not dereference on host.
// See skip_nonpad_data_validation parameter — CUDA callers must set it to true.
parameters.nonpad_kv_seqlen_data = nonpad_kv_seqlen->Data<int64_t>();
// Validate each value is in [0, total_sequence_length].
for (int i = 0; i < parameters.batch_size; ++i) {
ORT_ENFORCE(parameters.nonpad_kv_seqlen_data[i] >= 0 &&
parameters.nonpad_kv_seqlen_data[i] <= parameters.total_sequence_length,
"nonpad_kv_seqlen[", i, "] = ", parameters.nonpad_kv_seqlen_data[i],
" is out of range [0, ", parameters.total_sequence_length, "]");
// Skip per-element validation when data is on GPU (CUDA provider).
if (!skip_nonpad_data_validation) {
for (int i = 0; i < parameters.batch_size; ++i) {
ORT_ENFORCE(parameters.nonpad_kv_seqlen_data[i] >= 0 &&
parameters.nonpad_kv_seqlen_data[i] <= parameters.total_sequence_length,
"nonpad_kv_seqlen[", i, "] = ", parameters.nonpad_kv_seqlen_data[i],
" is out of range [0, ", parameters.total_sequence_length, "]");
}
}
} else {
parameters.has_nonpad_kv_seqlen = false;
Expand Down
18 changes: 12 additions & 6 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1590,9 +1590,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish);

// Opset 23.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Attention);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, MLFloat16, Attention);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, BFloat16, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_MLFloat16, RMSNormalization);
Expand Down Expand Up @@ -1631,6 +1631,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,

// Opset 24.
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, float, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, MLFloat16, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, BFloat16, Attention);

#endif

Expand Down Expand Up @@ -2669,9 +2672,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish)>,

// Opset 23
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, MLFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, BFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_MLFloat16, RMSNormalization)>,
Expand Down Expand Up @@ -2709,6 +2712,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {

// Opset 24
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, float, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, MLFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, BFloat16, Attention)>,
#endif
};

Expand Down
88 changes: 84 additions & 4 deletions onnxruntime/core/providers/cuda/llm/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,26 @@ namespace cuda {
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Attention, \
kOnnxDomain, \
24, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("U", BuildKernelDefConstraints<bool, T>()), \
Attention<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)

#undef REGISTER_KERNEL_TYPED

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
Attention, \
kOnnxDomain, \
23, \
23, \
T, \
kCudaExecutionProvider, \
Expand Down Expand Up @@ -95,10 +115,16 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
y_shape,
present_key_shape,
present_value_shape,
output_qk_shape)
output_qk_shape,
true /* skip_nonpad_data_validation: data is on GPU */)
.IsOK(),
"Output shapes for Attention could not be computed.");

// Note: parameters.nonpad_kv_seqlen_data is set by ComputeOutputShapeForAttention but is a
// device pointer on CUDA — it must not be dereferenced on host. The CUDA path reads the tensor
// directly via nonpad_kv_seqlen->Data<int64_t>() when launching GPU kernels.
// Only the CPU path uses parameters.nonpad_kv_seqlen_data for per-element masking.

Tensor* Y = context->Output(0, y_shape);
Tensor* present_key = context->Output(1, present_key_shape);
Tensor* present_value = context->Output(2, present_value_shape);
Expand Down Expand Up @@ -154,6 +180,11 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
}
// GQA kernel expects K/V input sequence length == Q sequence length (self-attention only)
// Cross-attention (kv_sequence_length != q_sequence_length) is not supported
// TODO(titaiwang): This self-attention constraint prevents the TensorScatter external KV
// cache pattern from using the GQA path on CUDA, since TensorScatter provides full KV
// (kv_seq = total_seq > q_seq). Requests with nonpad_kv_seqlen and kv_seq != q_seq
// must go through the MHA path instead. Relaxing this would enable flash/memory-efficient
// attention for the external KV cache use case.
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
if (parameters.kv_sequence_length != parameters.q_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"Cross-attention (kv_sequence_length != q_sequence_length) is not supported in "
Expand Down Expand Up @@ -381,7 +412,17 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// masks to seqlens_k directly (bypassing ONNX right-aligned broadcasting). This differs from
// the MHA path below, where 2D masks follow ONNX broadcasting: [A, B] → [1, 1, A, B], so
// 2D = (q_seq_len, total_seq_len) with both batch and heads broadcast.
if (attn_mask != nullptr && attn_mask->IsDataType<bool>()) {
if (parameters.has_nonpad_kv_seqlen) {
// Convert nonpad_kv_seqlen (int64, GPU) to seqlens_k (int32, GPU).
// GQA convention: seqlens_k[i] = nonpad_kv_seqlen[i] - 1 (last valid index, not count).
ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToSeqlensK(
nonpad_kv_seqlen->Data<int64_t>(),
seqlens_k_buffer.get(),
parameters.batch_size,
parameters.total_sequence_length,
cuda_stream,
device_prop.maxThreadsPerBlock));
} else if (attn_mask != nullptr && attn_mask->IsDataType<bool>()) {
// Get mask dimensions for broadcasting
// attn_mask can be 2D, 3D, or 4D and broadcasts to (batch_size, num_heads, q_seq_len, total_seq_len)
const auto& mask_shape = attn_mask->Shape();
Expand Down Expand Up @@ -568,7 +609,35 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// Set additional fields
data.bias = nullptr; // New Attention op doesn't have bias
IAllocatorUniquePtr<void> converted_mask_buffer;
if (nullptr != attn_mask) {
IAllocatorUniquePtr<void> nonpad_kv_bias_buffer;
if (parameters.has_nonpad_kv_seqlen) {
if (attn_mask != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"Using both nonpad_kv_seqlen and attn_mask simultaneously is not yet supported "
"in MHA path of Attention op (CUDA).");
}
// Generate attention_bias from nonpad_kv_seqlen: (B, q_seq, T) where
// position t < nonpad_kv_seqlen[b] → 0.0, position t >= nonpad_kv_seqlen[b] → -inf.
// Broadcasts over heads (broadcast_attn_bias_dim_1 = true).
using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
int64_t bias_elements = static_cast<int64_t>(parameters.batch_size) *
parameters.q_sequence_length *
parameters.total_sequence_length;
nonpad_kv_bias_buffer = GetScratchBuffer<void>(bias_elements * sizeof(NativeCudaT), context->GetComputeStream());
auto cuda_stream = static_cast<cudaStream_t>(context->GetComputeStream()->GetHandle());
ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToAttentionBias<NativeCudaT>(
nonpad_kv_seqlen->Data<int64_t>(),
reinterpret_cast<NativeCudaT*>(nonpad_kv_bias_buffer.get()),
parameters.batch_size,
parameters.q_sequence_length,
parameters.total_sequence_length,
contribop_parameters.mask_filter_value,
cuda_stream,
GetDeviceProp().maxThreadsPerBlock));
data.attention_bias = reinterpret_cast<const CudaT*>(nonpad_kv_bias_buffer.get());
contribop_parameters.broadcast_attn_bias_dim_0 = false;
contribop_parameters.broadcast_attn_bias_dim_1 = true;
} else if (nullptr != attn_mask) {
if (attn_mask->IsDataType<bool>()) {
// Convert boolean mask to additive attention bias: true -> 0.0, false -> mask_filter_value.
// The conversion is element-wise and preserves the original shape, so the broadcast flags
Expand All @@ -591,7 +660,18 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
}
data.qkv_format = contribop_parameters.qkv_format;

// For now, set flags to false and let QkvToContext use the unfused path
// TODO(titaiwang): Enable memory-efficient or flash attention for MHA + nonpad_kv_seqlen.
//
// Currently forces unfused O(n²) attention because:
// - nonpad_kv_seqlen is converted to attention_bias (B, q_seq, total_seq)
// - Flash attention requires attention_bias == nullptr (hard API constraint)
// - Memory-efficient attention supports attention_bias (with alignment) but
// is conservatively disabled pending validation
//
// This is safe for decode (q_seq=1) where the unfused path is cheap.
// WARNING: For prefill with large q_seq, unfused attention may OOM.
// Follow-up: enable memory-efficient attention, or add seqlens_k-style masking
// to MHA path (like GQA) to bypass attention_bias and enable flash attention.
data.use_flash_attention = false;
data.use_memory_efficient_attention = false;
data.fused_runner = nullptr;
Expand Down
110 changes: 110 additions & 0 deletions onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -173,5 +173,115 @@
template Status LaunchConvertBoolMaskToAttentionBias<__nv_bfloat16>(
const bool*, __nv_bfloat16*, int64_t, float, cudaStream_t, int);

// CUDA kernel to convert nonpad_kv_seqlen (int64) to seqlens_k (int32) for GQA.
// GQA convention: seqlens_k = nonpad_kv_seqlen - 1 (last valid index, not count).
//
// IMPORTANT: nonpad_kv_seqlen must be >= 1 for every batch element.
// A value of 0 would produce seqlens_k=0, which GQA interprets as "1 valid token at
// position 0" (last-valid-index convention), causing silent attention to garbage data.
// Callers must ensure sequences are non-empty before invoking this kernel.
//
// Validation (via CUDA_KERNEL_ASSERT, reported asynchronously):
// - val must be > 0 (nonpad_kv_seqlen=0 would silently corrupt output)
// - val must be <= total_sequence_length (out of bounds)
__global__ void ConvertNonpadKvSeqlenToSeqlensKKernel(
const int64_t* __restrict__ nonpad_kv_seqlen,
int* __restrict__ seqlens_k,
const int batch_size,
const int total_sequence_length) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < batch_size) {
int64_t val = nonpad_kv_seqlen[idx];
CUDA_KERNEL_ASSERT(val > 0); // nonpad_kv_seqlen=0 → seqlens_k=0 → attends to garbage at pos 0
CUDA_KERNEL_ASSERT(val <= static_cast<int64_t>(total_sequence_length));
// Clamp to [1, total_sequence_length] for safety in release builds where asserts are no-ops.
val = max(static_cast<int64_t>(1), min(val, static_cast<int64_t>(total_sequence_length)));
seqlens_k[idx] = static_cast<int>(val) - 1;
Comment thread
titaiwangms marked this conversation as resolved.
}
}

Status LaunchConvertNonpadKvSeqlenToSeqlensK(
const int64_t* nonpad_kv_seqlen,
int* seqlens_k,
int batch_size,
int total_sequence_length,
cudaStream_t stream,
int max_threads_per_block) {
if (batch_size == 0) {
return Status::OK();
}

// Note: The kernel uses CUDA_KERNEL_ASSERT for GPU-side validation (debug builds only).
// In release builds, the kernel defensively clamps inputs to valid ranges.
int threads = std::min(batch_size, max_threads_per_block);
int blocks = (batch_size + threads - 1) / threads;

ConvertNonpadKvSeqlenToSeqlensKKernel<<<blocks, threads, 0, stream>>>(
nonpad_kv_seqlen, seqlens_k, batch_size, total_sequence_length);

return CUDA_CALL(cudaGetLastError());
}

// CUDA kernel to convert nonpad_kv_seqlen to an additive attention bias.
// Generates (batch_size, q_seq_len, total_seq_len) output where:
// position t < nonpad_kv_seqlen[b] → 0.0 (attend)
// position t >= nonpad_kv_seqlen[b] → mask_filter_value (mask out)
// The same mask row is repeated for each query position within a batch.
template <typename T>
__global__ void ConvertNonpadKvSeqlenToAttentionBiasKernel(
const int64_t* __restrict__ nonpad_kv_seqlen,
T* __restrict__ attention_bias,
const int batch_size,
const int q_seq_len,
const int total_seq_len,
const float mask_filter_value) {
int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
int64_t total = static_cast<int64_t>(batch_size) * q_seq_len * total_seq_len;
for (; idx < total; idx += static_cast<int64_t>(gridDim.x) * blockDim.x) {
int b = static_cast<int>(idx / (static_cast<int64_t>(q_seq_len) * total_seq_len));
int t = static_cast<int>(idx % total_seq_len);
int64_t valid_len = nonpad_kv_seqlen[b];
// Note: valid_len=0 is allowed here (masks all positions), unlike the seqlens_k kernel
// where 0 would be misinterpreted by GQA's last-valid-index convention.
CUDA_KERNEL_ASSERT(valid_len >= 0 && valid_len <= static_cast<int64_t>(total_seq_len));
// Clamp to [0, total_seq_len] for safety in release builds where asserts are no-ops.
valid_len = max(static_cast<int64_t>(0), min(valid_len, static_cast<int64_t>(total_seq_len)));
attention_bias[idx] = (t < static_cast<int>(valid_len)) ? T(0.0f) : T(mask_filter_value);
}
}

template <typename T>
Status LaunchConvertNonpadKvSeqlenToAttentionBias(
const int64_t* nonpad_kv_seqlen,
T* attention_bias,
int batch_size,
int q_seq_len,
int total_seq_len,
float mask_filter_value,
cudaStream_t stream,
int max_threads_per_block) {
int64_t total = static_cast<int64_t>(batch_size) * q_seq_len * total_seq_len;
if (total == 0) {
return Status::OK();
}

int threads = static_cast<int>(std::min(static_cast<int64_t>(max_threads_per_block), total));
int64_t blocks = (total + threads - 1) / threads;
constexpr int64_t kMaxGridDimX = 65535;
unsigned int grid_size = static_cast<unsigned int>(std::min(blocks, kMaxGridDimX));

Check warning on line 271 in onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for min [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu:271: Add #include <algorithm> for min [build/include_what_you_use] [4]

ConvertNonpadKvSeqlenToAttentionBiasKernel<T><<<grid_size, threads, 0, stream>>>(
nonpad_kv_seqlen, attention_bias, batch_size, q_seq_len, total_seq_len, mask_filter_value);

return CUDA_CALL(cudaGetLastError());
}

template Status LaunchConvertNonpadKvSeqlenToAttentionBias<float>(
const int64_t*, float*, int, int, int, float, cudaStream_t, int);
template Status LaunchConvertNonpadKvSeqlenToAttentionBias<__half>(
const int64_t*, __half*, int, int, int, float, cudaStream_t, int);
template Status LaunchConvertNonpadKvSeqlenToAttentionBias<__nv_bfloat16>(
const int64_t*, __nv_bfloat16*, int, int, int, float, cudaStream_t, int);

} // namespace cuda
} // namespace onnxruntime
Loading
Loading