Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
23f284f
refactor
titaiwangms Feb 27, 2026
de61cdd
Add opset 24 nonpad_kv_seqlen tests for Attention op
titaiwangms Feb 27, 2026
17cbe1d
Fix test bugs from code review for Attention-24 nonpad_kv_seqlen
titaiwangms Feb 27, 2026
e40249e
Fix CUDA kernel review findings: GPU bounds clamping and pointer safety
titaiwangms Feb 27, 2026
bef5b9a
lint
titaiwangms Feb 27, 2026
db65ef5
resolve conflict
titaiwangms Feb 27, 2026
5b748d0
address review and use tensorscatter op in tests
titaiwangms Feb 27, 2026
e04597f
Update ONNX backend test filters for TensorScatter and nonpad_kv_seqlen
titaiwangms Feb 27, 2026
efa3ac2
Address review findings: CUDA kernel assertions, pointer safety, and …
titaiwangms Feb 27, 2026
3649935
Revert TestCase.cc change — QNN does not yet support nonpad_kv_seqlen
titaiwangms Feb 27, 2026
31aac47
Add causal mode test variants for TensorScatter + nonpad_kv_seqlen
titaiwangms Feb 28, 2026
a6713ad
Address P2 review findings: tolerance, GQA TODO, and parameter docume…
titaiwangms Feb 28, 2026
acbc057
Add batch=1 edge case test for TensorScatter attention
titaiwangms Feb 28, 2026
a221099
Merge branch 'main' into titaiwang/support_nonpad_kv_seqlen
titaiwangms Mar 2, 2026
bd468ca
update docs
titaiwangms Mar 2, 2026
9f9e768
Add GQA decode test cases to test_tensorscatter_attention.py
titaiwangms Mar 2, 2026
1d01361
Add FlashAttentionForExternalKVCache helper for TensorScatter + Atten…
titaiwangms Mar 2, 2026
2d06220
Address code review: guard header declaration and consolidate seqlens…
titaiwangms Mar 2, 2026
5a2475c
Fix critical review: OOB guard and BNSH check for flash KV cache
titaiwangms Mar 2, 2026
cc4c70e
Use ORT_MAKE_STATUS instead of ORT_ENFORCE for OOM guard
titaiwangms Mar 2, 2026
f7c3ae2
Fix GQA prompt + nonpad_kv_seqlen: skip seqlens_k conversion in promp…
titaiwangms Mar 3, 2026
4f00ed2
Fix GQA prompt + nonpad_kv_seqlen: skip seqlens_k conversion in promp…
titaiwangms Mar 3, 2026
61ac657
Restore MHA partial masking test coverage per review feedback
titaiwangms Mar 3, 2026
53bae3d
Add warning for GQA prompt + nonpad_kv_seqlen partial masking limitation
titaiwangms Mar 3, 2026
2678de7
Downgrade GQA prompt nonpad log from WARNING to VERBOSE
titaiwangms Mar 3, 2026
9be02c9
Reject GQA + prompt + nonpad_kv_seqlen with hard error on CUDA
titaiwangms Mar 3, 2026
5d187a4
Replace GQA prompt+nonpad hard error with CUDA_KERNEL_ASSERT
titaiwangms Mar 3, 2026
e4ca928
Add comment on total_seq_lens invariant in GQA prompt mode
titaiwangms Mar 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions onnxruntime/core/providers/cpu/llm/attention_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ inline Status ComputeOutputShapeForAttention(
TensorShape& y_shape,
TensorShape& present_key_shape,
TensorShape& present_value_shape,
TensorShape& output_qk_shape) {
TensorShape& output_qk_shape,
bool skip_nonpad_data_validation = false) {
ORT_ENFORCE(Q != nullptr && K != nullptr && V != nullptr,
"Q, K, and V inputs must not be null");
int q_dims = onnxruntime::narrow<int>(Q->Shape().NumDimensions());
Expand Down Expand Up @@ -113,13 +114,17 @@ inline Status ComputeOutputShapeForAttention(
ORT_ENFORCE(past_key == nullptr && past_value == nullptr,
"nonpad_kv_seqlen should not be used together with past_key and past_value inputs");
parameters.has_nonpad_kv_seqlen = true;
// Note: This pointer is CPU-accessible only. CUDA path should not dereference this directly.
parameters.nonpad_kv_seqlen_data = nonpad_kv_seqlen->Data<int64_t>();
// Validate each value is in [0, total_sequence_length].
for (int i = 0; i < parameters.batch_size; ++i) {
ORT_ENFORCE(parameters.nonpad_kv_seqlen_data[i] >= 0 &&
parameters.nonpad_kv_seqlen_data[i] <= parameters.total_sequence_length,
"nonpad_kv_seqlen[", i, "] = ", parameters.nonpad_kv_seqlen_data[i],
" is out of range [0, ", parameters.total_sequence_length, "]");
// Skip per-element validation when data is on GPU (CUDA provider).
if (!skip_nonpad_data_validation) {
for (int i = 0; i < parameters.batch_size; ++i) {
ORT_ENFORCE(parameters.nonpad_kv_seqlen_data[i] >= 0 &&
parameters.nonpad_kv_seqlen_data[i] <= parameters.total_sequence_length,
"nonpad_kv_seqlen[", i, "] = ", parameters.nonpad_kv_seqlen_data[i],
" is out of range [0, ", parameters.total_sequence_length, "]");
}
}
} else {
parameters.has_nonpad_kv_seqlen = false;
Expand Down
18 changes: 12 additions & 6 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1590,9 +1590,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish);

// Opset 23.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Attention);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, MLFloat16, Attention);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, BFloat16, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_MLFloat16, RMSNormalization);
Expand Down Expand Up @@ -1631,6 +1631,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,

// Opset 24.
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, float, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, MLFloat16, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, BFloat16, Attention);

#endif

Expand Down Expand Up @@ -2669,9 +2672,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish)>,

// Opset 23
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, MLFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, BFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_MLFloat16, RMSNormalization)>,
Expand Down Expand Up @@ -2709,6 +2712,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {

// Opset 24
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, float, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, MLFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, BFloat16, Attention)>,
#endif
};

Expand Down
65 changes: 62 additions & 3 deletions onnxruntime/core/providers/cuda/llm/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,29 @@ namespace cuda {

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Attention, \
kOnnxDomain, \
24, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("U", BuildKernelDefConstraints<bool, T>()), \
Attention<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)

#undef REGISTER_KERNEL_TYPED

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
Attention, \
kOnnxDomain, \
23, \
23, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
Expand Down Expand Up @@ -95,7 +115,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
y_shape,
present_key_shape,
present_value_shape,
output_qk_shape)
output_qk_shape,
true /* skip_nonpad_data_validation: data is on GPU */)
.IsOK(),
"Output shapes for Attention could not be computed.");

Expand Down Expand Up @@ -381,7 +402,17 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// masks to seqlens_k directly (bypassing ONNX right-aligned broadcasting). This differs from
// the MHA path below, where 2D masks follow ONNX broadcasting: [A, B] → [1, 1, A, B], so
// 2D = (q_seq_len, total_seq_len) with both batch and heads broadcast.
if (attn_mask != nullptr && attn_mask->IsDataType<bool>()) {
if (parameters.has_nonpad_kv_seqlen) {
// Convert nonpad_kv_seqlen (int64, GPU) to seqlens_k (int32, GPU).
// GQA convention: seqlens_k[i] = nonpad_kv_seqlen[i] - 1 (last valid index, not count).
ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToSeqlensK(
nonpad_kv_seqlen->Data<int64_t>(),
seqlens_k_buffer.get(),
parameters.batch_size,
parameters.total_sequence_length,
cuda_stream,
device_prop.maxThreadsPerBlock));
} else if (attn_mask != nullptr && attn_mask->IsDataType<bool>()) {
// Get mask dimensions for broadcasting
// attn_mask can be 2D, 3D, or 4D and broadcasts to (batch_size, num_heads, q_seq_len, total_seq_len)
const auto& mask_shape = attn_mask->Shape();
Expand Down Expand Up @@ -568,7 +599,35 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// Set additional fields
data.bias = nullptr; // New Attention op doesn't have bias
IAllocatorUniquePtr<void> converted_mask_buffer;
if (nullptr != attn_mask) {
IAllocatorUniquePtr<void> nonpad_kv_bias_buffer;
if (parameters.has_nonpad_kv_seqlen) {
if (attn_mask != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"Using both nonpad_kv_seqlen and attn_mask simultaneously is not yet supported "
"in MHA path of Attention op (CUDA).");
}
// Generate attention_bias from nonpad_kv_seqlen: (B, q_seq, T) where
// position t < nonpad_kv_seqlen[b] → 0.0, position t >= nonpad_kv_seqlen[b] → -inf.
// Broadcasts over heads (broadcast_attn_bias_dim_1 = true).
using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType<T>::type;
int64_t bias_elements = static_cast<int64_t>(parameters.batch_size) *
parameters.q_sequence_length *
parameters.total_sequence_length;
nonpad_kv_bias_buffer = GetScratchBuffer<void>(bias_elements * sizeof(NativeCudaT), context->GetComputeStream());
auto cuda_stream = static_cast<cudaStream_t>(context->GetComputeStream()->GetHandle());
ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToAttentionBias<NativeCudaT>(
nonpad_kv_seqlen->Data<int64_t>(),
reinterpret_cast<NativeCudaT*>(nonpad_kv_bias_buffer.get()),
parameters.batch_size,
parameters.q_sequence_length,
parameters.total_sequence_length,
contribop_parameters.mask_filter_value,
cuda_stream,
GetDeviceProp().maxThreadsPerBlock));
data.attention_bias = reinterpret_cast<const CudaT*>(nonpad_kv_bias_buffer.get());
contribop_parameters.broadcast_attn_bias_dim_0 = false;
contribop_parameters.broadcast_attn_bias_dim_1 = true;
} else if (nullptr != attn_mask) {
if (attn_mask->IsDataType<bool>()) {
// Convert boolean mask to additive attention bias: true -> 0.0, false -> mask_filter_value.
// The conversion is element-wise and preserves the original shape, so the broadcast flags
Expand Down
99 changes: 99 additions & 0 deletions onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -173,5 +173,104 @@
template Status LaunchConvertBoolMaskToAttentionBias<__nv_bfloat16>(
const bool*, __nv_bfloat16*, int64_t, float, cudaStream_t, int);

// CUDA kernel to convert nonpad_kv_seqlen (int64) to seqlens_k (int32) for GQA.
// GQA convention: seqlens_k = nonpad_kv_seqlen - 1 (last valid index, not count).
// A value of 0 in seqlens_k represents an empty KV sequence for that batch element.
__global__ void ConvertNonpadKvSeqlenToSeqlensKKernel(
const int64_t* __restrict__ nonpad_kv_seqlen,
int* __restrict__ seqlens_k,
const int batch_size,
const int total_sequence_length) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < batch_size) {
int64_t val = nonpad_kv_seqlen[idx];
// Clamp to valid range [0, total_sequence_length] before int64→int32 cast.
val = max(static_cast<int64_t>(0), min(val, static_cast<int64_t>(total_sequence_length)));
int seqlen = static_cast<int>(val) - 1;
// Clamp to non-negative so that 0 cleanly represents an empty KV sequence.
if (seqlen < 0) {
seqlen = 0;
}
seqlens_k[idx] = seqlen;
}
}

Status LaunchConvertNonpadKvSeqlenToSeqlensK(
const int64_t* nonpad_kv_seqlen,
int* seqlens_k,
int batch_size,
int total_sequence_length,
cudaStream_t stream,
int max_threads_per_block) {
if (batch_size == 0) {
return Status::OK();
}

int threads = std::min(batch_size, max_threads_per_block);
int blocks = (batch_size + threads - 1) / threads;

ConvertNonpadKvSeqlenToSeqlensKKernel<<<blocks, threads, 0, stream>>>(
nonpad_kv_seqlen, seqlens_k, batch_size, total_sequence_length);

return CUDA_CALL(cudaGetLastError());
}

// CUDA kernel to convert nonpad_kv_seqlen to an additive attention bias.
// Generates (batch_size, q_seq_len, total_seq_len) output where:
// position t < nonpad_kv_seqlen[b] → 0.0 (attend)
// position t >= nonpad_kv_seqlen[b] → mask_filter_value (mask out)
// The same mask row is repeated for each query position within a batch.
template <typename T>
__global__ void ConvertNonpadKvSeqlenToAttentionBiasKernel(
const int64_t* __restrict__ nonpad_kv_seqlen,
T* __restrict__ attention_bias,
const int batch_size,
const int q_seq_len,
const int total_seq_len,
const float mask_filter_value) {
int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
int64_t total = static_cast<int64_t>(batch_size) * q_seq_len * total_seq_len;
for (; idx < total; idx += static_cast<int64_t>(gridDim.x) * blockDim.x) {
int b = static_cast<int>(idx / (static_cast<int64_t>(q_seq_len) * total_seq_len));
int t = static_cast<int>(idx % total_seq_len);
// Clamp nonpad_kv_seqlen to [0, total_seq_len] for safety.
int64_t valid_len = max(static_cast<int64_t>(0), min(nonpad_kv_seqlen[b], static_cast<int64_t>(total_seq_len)));
attention_bias[idx] = (t < static_cast<int>(valid_len)) ? T(0.0f) : T(mask_filter_value);
}
}

template <typename T>
Status LaunchConvertNonpadKvSeqlenToAttentionBias(
const int64_t* nonpad_kv_seqlen,
T* attention_bias,
int batch_size,
int q_seq_len,
int total_seq_len,
float mask_filter_value,
cudaStream_t stream,
int max_threads_per_block) {
int64_t total = static_cast<int64_t>(batch_size) * q_seq_len * total_seq_len;
if (total == 0) {
return Status::OK();
}

int threads = static_cast<int>(std::min(static_cast<int64_t>(max_threads_per_block), total));
int64_t blocks = (total + threads - 1) / threads;
constexpr int64_t kMaxGridDimX = 65535;
unsigned int grid_size = static_cast<unsigned int>(std::min(blocks, kMaxGridDimX));

Check warning on line 260 in onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for min [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu:260: Add #include <algorithm> for min [build/include_what_you_use] [4]

ConvertNonpadKvSeqlenToAttentionBiasKernel<T><<<grid_size, threads, 0, stream>>>(
nonpad_kv_seqlen, attention_bias, batch_size, q_seq_len, total_seq_len, mask_filter_value);

return CUDA_CALL(cudaGetLastError());
}

template Status LaunchConvertNonpadKvSeqlenToAttentionBias<float>(
const int64_t*, float*, int, int, int, float, cudaStream_t, int);
template Status LaunchConvertNonpadKvSeqlenToAttentionBias<__half>(
const int64_t*, __half*, int, int, int, float, cudaStream_t, int);
template Status LaunchConvertNonpadKvSeqlenToAttentionBias<__nv_bfloat16>(
const int64_t*, __nv_bfloat16*, int, int, int, float, cudaStream_t, int);

} // namespace cuda
} // namespace onnxruntime
46 changes: 46 additions & 0 deletions onnxruntime/core/providers/cuda/llm/attention_mask_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,51 @@ Status LaunchConvertBoolMaskToAttentionBias(
cudaStream_t stream,
int max_threads_per_block);

// Convert nonpad_kv_seqlen (int64, per-batch valid KV lengths) to seqlens_k (int32) for GQA.
// GQA convention: seqlens_k[i] = nonpad_kv_seqlen[i] - 1 (last valid index, not count).
//
// Parameters:
// nonpad_kv_seqlen: Input int64 tensor on GPU, shape [batch_size]
// seqlens_k: Output int32 buffer on GPU, shape [batch_size]
// batch_size: Number of batches
// total_sequence_length: Max KV sequence length (for bounds clamping)
// stream: CUDA stream
// max_threads_per_block: Maximum threads per block
Status LaunchConvertNonpadKvSeqlenToSeqlensK(
const int64_t* nonpad_kv_seqlen,
int* seqlens_k,
int batch_size,
int total_sequence_length,
cudaStream_t stream,
int max_threads_per_block);
Comment on lines +88 to +97

Copilot AI Mar 3, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The header docs for LaunchConvertNonpadKvSeqlenToFlashSeqlensK don’t clarify the allowed range, and the implementation currently treats 0 as invalid. Since nonpad_kv_seqlen is a count, 0 should be representable (empty KV) in the flash path. Please update the contract here to explicitly allow 0 (and align with the implementation change to clamp/assert on [0, total_sequence_length]).

Copilot uses AI. Check for mistakes.

// Convert nonpad_kv_seqlen to an additive attention bias for the MHA unfused path.
// Generates a (batch_size, q_seq_len, total_seq_len) tensor where:
// position t < nonpad_kv_seqlen[b] → 0.0 (attend)
// position t >= nonpad_kv_seqlen[b] → mask_filter_value (mask out)
//
// The output is used as attention_bias with broadcast_attn_bias_dim_0=false,
// broadcast_attn_bias_dim_1=true (broadcasts over heads).
//
// Parameters:
// nonpad_kv_seqlen: Input int64 tensor on GPU, shape [batch_size]
// attention_bias: Output buffer on GPU, shape [batch_size * q_seq_len * total_seq_len]
// batch_size: Number of batches
// q_seq_len: Query sequence length
// total_seq_len: Total KV sequence length
// mask_filter_value: Value for masked positions (typically -inf)
// stream: CUDA stream
// max_threads_per_block: Maximum threads per block
template <typename T>
Status LaunchConvertNonpadKvSeqlenToAttentionBias(
const int64_t* nonpad_kv_seqlen,
T* attention_bias,
int batch_size,
int q_seq_len,
int total_seq_len,
float mask_filter_value,
cudaStream_t stream,
int max_threads_per_block);

} // namespace cuda
} // namespace onnxruntime
Loading
Loading