Merged
37 commits
020ac15
ONNX Attention thin-dispatcher: direct flash/MEA/unfused dispatch wit…
titaiwangms Mar 4, 2026
e270a6a
Update Operator Kernel document
titaiwangms Mar 4, 2026
6c50092
Fix cutlass FMHA crash when attention bias stride is unaligned
titaiwangms Mar 4, 2026
6ac08bf
Fix GQA decode eligibility, padding mask wiring, and 4D BNSH Q transpose
titaiwangms Mar 4, 2026
f198af2
Fix all-false mask crash in ConvertMaskToSeqlensKernel
titaiwangms Mar 4, 2026
9eb11e4
Replace host memcpy with device-side fill for CUDA graph capture
titaiwangms Mar 4, 2026
16e8b65
Handle bool mask in decode path to support variable padding
titaiwangms Mar 4, 2026
d3b4968
Fix NaN output for all-false bool masks in MEA path
titaiwangms Mar 4, 2026
e0760a8
Fix cutlass FMHA bias alignment crash for unaligned kv_sequence_length
titaiwangms Mar 4, 2026
a827a1b
Fix 3D mask test to use consistent per-batch padding semantics
titaiwangms Mar 4, 2026
9a0b546
Remove 11 redundant GQA tests from test_gqa.py
titaiwangms Mar 4, 2026
939a08c
Fix padding mask bugs: zero present buffers, decode offset, MEA 2D ex…
titaiwangms Mar 4, 2026
118546d
Add TODO for GQA unfused attention fallback
titaiwangms Mar 4, 2026
b8ea59e
Add TODO comments for GQA+float_mask and 4D present gaps
titaiwangms Mar 4, 2026
aca1cf8
Add TODO comments for softcap/softmax_precision and output_qk gaps
titaiwangms Mar 4, 2026
0990076
Revert "Add TODO comments for GQA+float_mask and 4D present gaps"
titaiwangms Mar 4, 2026
cb64751
Code cleanup: remove dead function, fix comments, CUDA-graph-safe 2D …
titaiwangms Mar 4, 2026
76b006a
Add test improvements: unfused MHA, 4D BNSH GQA, broadcast mask, floa…
titaiwangms Mar 4, 2026
c3f771a
lint
titaiwangms Mar 4, 2026
d6f16af
Fix 2D mask shape, add 4D BNSH present_kv, cleanup and docs
titaiwangms Mar 5, 2026
813dae2
Fix 2D mask shape, add 4D BNSH present_kv, cleanup and docs
titaiwangms Mar 5, 2026
68a1b02
Address PR review feedback: transpose helpers, assert fixes, SEGFAULT…
titaiwangms Mar 5, 2026
27ee9af
Refine comments, fix docstrings, and remove dead code
titaiwangms Mar 5, 2026
381cd83
Add clarifying comment for DispatchIsAligned bias alignment check
titaiwangms Mar 5, 2026
a158251
Fix SM skip thresholds in attention tests (T25)
titaiwangms Mar 6, 2026
73da07a
Address Copilot review: env var support, SM skip fix, nonpad+mask fal…
titaiwangms Mar 6, 2026
1e01894
Wire up nonpad_kv_seqlen + attn_mask composition in unfused path (T28)
titaiwangms Mar 6, 2026
31567e7
Fix 2D mask shape in GQA tests and add mask validation (T29)
titaiwangms Mar 6, 2026
c5a1ebb
Fix T28 review issues: guard mask dims, prevent divide-by-zero, fix s…
titaiwangms Mar 6, 2026
7111a4d
Add Python tests for nonpad_kv_seqlen + attn_mask combination (T31)
titaiwangms Mar 6, 2026
f8cd689
Address review feedback: BF16 fix, unfused nonpad+mask, test improvem…
titaiwangms Mar 6, 2026
4b87dd8
Fix test failures: reference mask shape, seqlens size, invalid configs
titaiwangms Mar 6, 2026
a6dce1a
Validate present_key/present_value outputs in TensorScatter attention…
titaiwangms Mar 6, 2026
a16baab
Address Copilot review round 3: log level, present_kv validation, cle…
titaiwangms Mar 6, 2026
cb2ae8c
Fix env var leak in tests: restore ORT_DISABLE_* after use
titaiwangms Mar 7, 2026
579af5f
lint
titaiwangms Mar 7, 2026
46570e5
Merge branch 'main' into titaiwang/design_attention_with_ai
titaiwangms Mar 9, 2026
3 changes: 2 additions & 1 deletion docs/OperatorKernels.md
@@ -655,7 +655,8 @@ Do not modify directly.*
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|Attention|*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *in* nonpad_kv_seqlen:**tensor(int64)**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**<br><br>or<br><br>*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**|23+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)|
|Attention|*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *in* nonpad_kv_seqlen:**tensor(int64)**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**<br><br>or<br><br>*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**|24+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)|
|||23|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)|
|AveragePool|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[11, 18]|**T** = tensor(double), tensor(float), tensor(float16)|
10 changes: 10 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -132,6 +132,16 @@ Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, c
Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const BFloat16* input, BFloat16* output, cudaStream_t stream, const int max_threads_per_block);

// BxNxSxH => BxSxNxH
Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const float* input, float* output, cudaStream_t stream, const int max_threads_per_block);

Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const half* input, half* output, cudaStream_t stream, const int max_threads_per_block);

Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const BFloat16* input, BFloat16* output, cudaStream_t stream, const int max_threads_per_block);

template <typename T>
Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size,
int sequence_length, int total_sequence_length,
20 changes: 20 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu
@@ -393,6 +393,26 @@ Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, c
max_threads_per_block, false, input, output);
}

// BxNxSxH => BxSxNxH (BNSH to BSNH) — reverse of Transpose_BSNH_to_BNSH.
// Reuses the existing TransposeCtx kernel which does exactly this transformation.
Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const float* input, float* output, cudaStream_t stream, const int max_threads_per_block) {
return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads,
max_threads_per_block, false, input, output);
}

Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const half* input, half* output, cudaStream_t stream, const int max_threads_per_block) {
return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads,
max_threads_per_block, false, input, output);
}

Status Transpose_BNSH_to_BSNH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const BFloat16* input, BFloat16* output, cudaStream_t stream, const int max_threads_per_block) {
return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads,
max_threads_per_block, false, input, output);
}

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
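As a reading aid, a CPU sketch of the index mapping the new overloads perform (the CUDA path goes through LaunchTransCtx; the shapes and buffer types below are illustrative only, not taken from this PR):

#include <vector>

// Reference loop for BxNxSxH => BxSxNxH; not the kernel, just the mapping.
void TransposeBNSHtoBSNH(int B, int N, int S, int H,
                         const std::vector<float>& in, std::vector<float>& out) {
  for (int b = 0; b < B; ++b)
    for (int n = 0; n < N; ++n)
      for (int s = 0; s < S; ++s)
        for (int h = 0; h < H; ++h)
          // input index follows [B, N, S, H]; output index follows [B, S, N, H]
          out[((b * S + s) * N + n) * H + h] = in[((b * N + n) * S + s) * H + h];
}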
@@ -258,6 +258,33 @@ void DispatchIsAligned(const MemoryEfficientAttentionParams& params) {
params.qk_head_size % AlignedAK::kAlignmentK == 0 &&
params.v_head_size % AlignedAK::kAlignmentV == 0;

// Bias stride alignment check: route to the unaligned kernel when bias strides
// don't satisfy the aligned kernel's kAlignmentQ requirement.
//
// kAlignmentQ is template-dependent (kernel_forward.h:414):
// isAligned=true: kAlignmentQ = DefaultConfig::kAlignmentA (8 for fp16/bf16 SM75+)
// isAligned=false: kAlignmentQ = GemmType::kMinimumAlignment (4 for fp16/bf16 SM75+)
// So check_supported (line 632) enforces DIFFERENT thresholds per path.
//
// The ONNX Attention kernel (core/providers/cuda/llm/attention.cc) gates MEA eligibility
// at kMinimumAlignment (4), allowing strides like 12 that the unaligned kernel handles.
// Without this check, such inputs would dispatch to the aligned kernel and crash, since 12 % 8 != 0.
// Contrib MHA gates at 4*sizeof(T)=8 for fp16, making this check redundant there.
if (params.attn_bias != nullptr) {
int num_keys = params.kv_sequence_length;
int num_queries = params.sequence_length;
int bias_strideM = num_keys;
// Broadcast dimensions use stride=0, which satisfies any alignment (0 % N == 0).
int bias_strideH = params.broadcast_attn_bias_dim_1 ? 0 : num_queries * num_keys;
int bias_strideB = params.broadcast_attn_bias_dim_0
? 0
: ((params.broadcast_attn_bias_dim_1 ? 1 : params.num_heads) * num_queries * num_keys);
is_aligned = is_aligned &&
bias_strideM % AlignedAK::kAlignmentQ == 0 &&
(params.num_heads <= 1 || bias_strideH % AlignedAK::kAlignmentQ == 0) &&
(params.batch_size <= 1 || bias_strideB % AlignedAK::kAlignmentQ == 0);
}

DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() {
LaunchCutlassFmha<T, ArchTag, kIsAligned::value, queries_per_block, keys_per_block, max_head_size>(params);
}));
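A standalone sketch of the stride computation and decision above (kAlignmentQ = 8 is the fp16/bf16 SM75+ value cited in the comment; all concrete sizes below are made-up examples):

#include <cstdio>

bool BiasStridesAligned(int batch, int heads, int num_queries, int num_keys,
                        bool bcast_dim0, bool bcast_dim1, int alignment_q = 8) {
  int stride_m = num_keys;
  int stride_h = bcast_dim1 ? 0 : num_queries * num_keys;
  int stride_b = bcast_dim0 ? 0 : (bcast_dim1 ? 1 : heads) * num_queries * num_keys;
  return stride_m % alignment_q == 0 &&
         (heads <= 1 || stride_h % alignment_q == 0) &&
         (batch <= 1 || stride_b % alignment_q == 0);
}

int main() {
  // kv_sequence_length = 12: 12 % 8 != 0, so dispatch must fall back to the
  // unaligned kernel (which only requires kMinimumAlignment = 4).
  std::printf("%d\n", BiasStridesAligned(2, 8, 16, 12, false, false));  // prints 0
  return 0;
}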
18 changes: 18 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -1147,6 +1147,24 @@ template Status QkvToContext<__nv_bfloat16, __nv_fp8_e4m3>(
GroupQueryAttentionData<__nv_bfloat16, __nv_fp8_e4m3>& data);
#endif

// Explicit instantiations for cross-TU usage by core/providers/cuda/llm/attention.cc
template Status LaunchUngroup<__half>(
const GroupQueryAttentionParameters& parameters,
float2* k_buff, float2* v_buff,
const float2* k_og, const float2* v_og,
const int buff_seqlen, const int og_seqlen,
const bool is_bsnh,
cudaStream_t stream,
const int max_threads_per_block);
template Status LaunchUngroup<__nv_bfloat16>(
const GroupQueryAttentionParameters& parameters,
float2* k_buff, float2* v_buff,
const float2* k_og, const float2* v_og,
const int buff_seqlen, const int og_seqlen,
const bool is_bsnh,
cudaStream_t stream,
const int max_threads_per_block);

template Status LaunchUnpackQKV<half, LAYOUT_BNSH>(const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block);
template Status LaunchUnpackQKV<__nv_bfloat16, LAYOUT_BNSH>(const __nv_bfloat16* packed_qkv, __nv_bfloat16* unpacked_q, __nv_bfloat16* unpacked_k, __nv_bfloat16* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block);

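For context, the cross-TU pattern those explicit instantiations rely on, in minimal form (the Scale template and file names are invented for illustration; only the mechanism matches):

// scale.h: declaration visible to all translation units
template <typename T> void Scale(T* data, int n, T factor);

// scale.cpp: definition plus explicit instantiations, so callers in other
// translation units can link against Scale<float>/Scale<double> without ever
// seeing the template body
template <typename T> void Scale(T* data, int n, T factor) {
  for (int i = 0; i < n; ++i) data[i] *= factor;
}
template void Scale<float>(float*, int, float);
template void Scale<double>(double*, int, double);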
10 changes: 10 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
@@ -122,6 +122,16 @@ struct GQABufferRequirements {
}
};

template <typename T>
// Also used by ONNX Attention (core/providers/cuda/llm/attention.cc) for GQA head expansion in MEA path.
Status LaunchUngroup(const GroupQueryAttentionParameters& parameters,
float2* k_buff, float2* v_buff,
const float2* k_og, const float2* v_og,
const int buff_seqlen, const int og_seqlen,
const bool is_bsnh,
cudaStream_t stream,
const int max_threads_per_block);

Status LaunchGetSequenceLengths(
const int* total_seq_lens_minus_one,
int* past_seq_lens,
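A CPU reference for the head expansion LaunchUngroup performs, to make the GQA-to-MHA step concrete (the BNSH layout, names, and the n / group mapping follow the usual GQA convention and are assumptions, not copied from the kernel):

#include <vector>

// Expand K or V from kv_heads to q_heads by repeating each KV head for the
// q_heads / kv_heads query heads that share it.
void UngroupHeads(int B, int q_heads, int kv_heads, int S, int H,
                  const std::vector<float>& kv_in, std::vector<float>& kv_out) {
  const int group = q_heads / kv_heads;
  for (int b = 0; b < B; ++b)
    for (int n = 0; n < q_heads; ++n)
      for (int s = 0; s < S; ++s)
        for (int h = 0; h < H; ++h)
          kv_out[((b * q_heads + n) * S + s) * H + h] =
              kv_in[((b * kv_heads + n / group) * S + s) * H + h];
}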
28 changes: 22 additions & 6 deletions onnxruntime/core/providers/cpu/llm/attention_helper.h
@@ -27,7 +27,8 @@
TensorShape& y_shape,
TensorShape& present_key_shape,
TensorShape& present_value_shape,
TensorShape& output_qk_shape) {
TensorShape& output_qk_shape,
bool skip_nonpad_data_validation = false) {
ORT_ENFORCE(Q != nullptr && K != nullptr && V != nullptr,
"Q, K, and V inputs must not be null");
int q_dims = onnxruntime::narrow<int>(Q->Shape().NumDimensions());
@@ -90,6 +91,13 @@

parameters.transpose_output = true; // whether to transpose the input/output with permutation (0, 2, 1, 3)
parameters.q_sequence_length = onnxruntime::narrow<int>(Q->Shape()[1]);

// Validate mask second-to-last dim matches q_sequence_length (same check as 4D path).
// For 2D mask [A, B]: A must equal q_seq. For 3D mask [A, B, C]: B must equal q_seq.
ORT_ENFORCE(attn_mask == nullptr ||
attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 2] == Q->Shape()[1],
"inconsistent q_sequence_length (between attn_mask and Q)");

parameters.head_size = onnxruntime::narrow<int>(Q->Shape()[2]) / parameters.q_num_heads;
parameters.kv_sequence_length = onnxruntime::narrow<int>(K->Shape()[1]);
parameters.v_head_size = onnxruntime::narrow<int>(V->Shape()[2]) / parameters.kv_num_heads;
@@ -115,18 +123,26 @@
parameters.has_nonpad_kv_seqlen = true;
parameters.nonpad_kv_seqlen_data = nonpad_kv_seqlen->Data<int64_t>();
// Validate each value is in [0, total_sequence_length].
for (int i = 0; i < parameters.batch_size; ++i) {
ORT_ENFORCE(parameters.nonpad_kv_seqlen_data[i] >= 0 &&
parameters.nonpad_kv_seqlen_data[i] <= parameters.total_sequence_length,
"nonpad_kv_seqlen[", i, "] = ", parameters.nonpad_kv_seqlen_data[i],
" is out of range [0, ", parameters.total_sequence_length, "]");
// Skip when data is on GPU (CUDA path sets skip_nonpad_data_validation=true).
if (!skip_nonpad_data_validation) {
for (int i = 0; i < parameters.batch_size; ++i) {
ORT_ENFORCE(parameters.nonpad_kv_seqlen_data[i] >= 0 &&
parameters.nonpad_kv_seqlen_data[i] <= parameters.total_sequence_length,
"nonpad_kv_seqlen[", i, "] = ", parameters.nonpad_kv_seqlen_data[i],
" is out of range [0, ", parameters.total_sequence_length, "]");
}
}
} else {
parameters.has_nonpad_kv_seqlen = false;
parameters.nonpad_kv_seqlen_data = nullptr;
}

ORT_ENFORCE(parameters.q_num_heads % parameters.kv_num_heads == 0, "q_num_heads must be a multiple of kv_num_heads. This is required for grouped/multi-query and multi-headed attention.");
// TODO: The ONNX spec allows attn_mask last dim to be shorter than total_sequence_length,
// with positions beyond the mask padded with -inf. Currently we enforce exact match.
// To support: change == to <=, allocate padded buffer, fill remainder with -inf.
// See ONNX spec: 'The last dimension can also be shorter than total_sequence_length
// and will be padded to total_sequence_length with negative infinity.'
ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 1] == parameters.total_sequence_length,
"inconsistent total_sequence_length (between attn_mask and past_key and past_value)");
ORT_ENFORCE(attn_mask == nullptr ||
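A sketch of the -inf padding the TODO above describes, to show what supporting a shorter mask would involve (not implemented in this PR; the function name and the float-mask assumption are illustrative):

#include <limits>
#include <vector>

// Pad one row of a float additive mask up to total_seq_len. Padded key
// positions get -inf so they contribute zero weight after softmax.
std::vector<float> PadMaskRow(const std::vector<float>& mask_row,
                              size_t total_seq_len) {
  std::vector<float> padded(total_seq_len,
                            -std::numeric_limits<float>::infinity());
  for (size_t i = 0; i < mask_row.size() && i < total_seq_len; ++i)
    padded[i] = mask_row[i];
  return padded;
}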
18 changes: 12 additions & 6 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1592,9 +1592,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish);

// Opset 23.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Attention);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, MLFloat16, Attention);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, BFloat16, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_MLFloat16, RMSNormalization);
@@ -1633,6 +1633,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,

// Opset 24.
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, float, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, MLFloat16, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, BFloat16, Attention);

#endif

@@ -2671,9 +2674,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish)>,

// Opset 23
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, MLFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, BFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_MLFloat16, RMSNormalization)>,
@@ -2711,6 +2714,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {

// Opset 24
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, TensorScatter)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, float, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, MLFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, BFloat16, Attention)>,
#endif
};

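For readers less familiar with the registration macros, the pattern applied in this file, sketched with abbreviated arguments (not a literal excerpt):

// Before this PR: a single open-ended registration starting at opset 23.
//   ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(..., 23, float, Attention)
// After: the opset-23 registration is closed to the range [23, 23] and a new
// open-ended registration covers opset 24 and later.
//   ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(..., 23, 23, float, Attention)
//   ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(..., 24, float, Attention)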