diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc
index 632f2561d522c..96494d38a410f 100644
--- a/onnxruntime/core/providers/cpu/llm/attention.cc
+++ b/onnxruntime/core/providers/cpu/llm/attention.cc
@@ -12,6 +12,9 @@
 #include "core/util/math_cpuonly.h"
 #include "core/providers/cpu/math/gemm.h"
 
+#include <algorithm>
+#include <vector>
+
 using onnxruntime::attention_helper::AttentionParameters;
 using onnxruntime::attention_helper::QKMatMulOutputMode;
 using onnxruntime::concurrency::ThreadPool;
@@ -107,6 +110,95 @@ inline void ComputeAttentionSoftcapInplace(MLFloat16* scores, int sequence_lengt
   }
 }
 
+// Dispatches a GEMM operation across float and MLFloat16 types.
+// C = alpha * op(A) * op(B) + beta * C
+//
+// For float: delegates to math::GemmEx which calls MlasGemm (optimized SGEMM).
+// For MLFloat16:
+//   - If the hardware supports native fp16 GEMM for the given transpose combo
+//     (checked via MlasHGemmSupported), uses MlasGemm directly.
+//   - Otherwise, upcasts A/B/C to fp32, runs math::GemmEx (SGEMM), and downcasts
+//     the result back to fp16. This avoids Eigen's unoptimized fp16 codepath.
+//
+// The fp32 fallback handles strided C carefully: when ldc > N (e.g. 3D interleaved
+// heads where multiple heads share a row), conversion is done row-by-row (N elements
+// per row) to avoid overwriting adjacent heads' data. When ldc == N (contiguous,
+// the common 4D case), a single bulk conversion is used for efficiency.
+//
+// TODO(xadupre): Consider adding a MlasFlashAttention fast path for float32 when no masks, KV cache,
+// softcap, or nonpad_kv_seqlen are active. This fuses Q*K, softmax, and QK*V into a single
+// L2-cache-tiled pass. See MultiHeadAttention (contrib_ops/cpu/bert/multihead_attention.cc).
+template <typename T>
+inline void AttentionGemm(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
+                          int M, int N, int K,
+                          float alpha,
+                          const T* A, int lda,
+                          const T* B, int ldb,
+                          float beta,
+                          T* C, int ldc) {
+  if constexpr (std::is_same<T, float>::value) {
+    math::GemmEx<float, ThreadPool>(transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, nullptr);
+  } else if constexpr (std::is_same<T, MLFloat16>::value) {
+    if (MlasHGemmSupported(transA, transB)) {
+      MlasGemm(transA, transB, M, N, K, A, lda, B, ldb, C, ldc,
+               MLFloat16(alpha).val, MLFloat16(beta).val, nullptr);
+    } else {
+      // fp16 fallback: upcast to fp32, run optimized SGEMM, downcast result.
+      // Compute the exact contiguous span each matrix occupies: (rows-1)*stride + cols.
+      // This is the distance from the first element to the last accessed element + 1.
+      // Using rows*stride would overread when the pointer is offset into a larger
+      // interleaved buffer (e.g., 3D layout where lda > K for a non-first head).
+      size_t a_rows = (transA == CblasNoTrans) ? static_cast<size_t>(M) : static_cast<size_t>(K);
+      size_t a_cols = (transA == CblasNoTrans) ? static_cast<size_t>(K) : static_cast<size_t>(M);
+      size_t b_rows = (transB == CblasNoTrans) ? static_cast<size_t>(K) : static_cast<size_t>(N);
+      size_t b_cols = (transB == CblasNoTrans) ? static_cast<size_t>(N) : static_cast<size_t>(K);
+      size_t a_count = (a_rows > 0) ? (a_rows - 1) * static_cast<size_t>(lda) + a_cols : 0;
+      size_t b_count = (b_rows > 0) ? (b_rows - 1) * static_cast<size_t>(ldb) + b_cols : 0;
+      size_t c_count = (M > 0) ? static_cast<size_t>(M - 1) * static_cast<size_t>(ldc) + static_cast<size_t>(N) : 0;
+
+      std::vector<float> a_fp32(a_count);
+      std::vector<float> b_fp32(b_count);
+      std::vector<float> c_fp32(c_count);
+
+      // Upcast A and B in bulk (contiguous within each matrix's strided span).
+      MlasConvertHalfToFloatBuffer(A, a_fp32.data(), a_count);
+      MlasConvertHalfToFloatBuffer(B, b_fp32.data(), b_count);
+      if (beta != 0.0f) {
+        // C needs upcast only when beta != 0 (GEMM accumulates into C).
+        // When ldc == N the buffer is contiguous — use a single bulk conversion.
+        // When ldc > N (3D interleaved heads), convert only the N valid columns
+        // per row to avoid reading into adjacent heads' memory.
+        if (ldc == N) {
+          MlasConvertHalfToFloatBuffer(C, c_fp32.data(), c_count);
+        } else {
+          for (int row = 0; row < M; ++row) {
+            MlasConvertHalfToFloatBuffer(C + row * ldc, c_fp32.data() + row * ldc, static_cast<size_t>(N));
+          }
+        }
+      }
+
+      math::GemmEx<float, ThreadPool>(transA, transB, M, N, K,
+                                      alpha, a_fp32.data(), lda,
+                                      b_fp32.data(), ldb,
+                                      beta, c_fp32.data(), ldc, nullptr);
+
+      // Downcast result back to fp16.
+      // Same ldc == N check: bulk conversion when contiguous, row-by-row when
+      // strided to avoid overwriting adjacent heads' output data.
+      if (ldc == N) {
+        MlasConvertFloatToHalfBuffer(c_fp32.data(), C, c_count);
+      } else {
+        for (int row = 0; row < M; ++row) {
+          MlasConvertFloatToHalfBuffer(c_fp32.data() + row * ldc, C + row * ldc, static_cast<size_t>(N));
+        }
+      }
+    }
+  } else {
+    ORT_THROW("Unsupported data type for attention GEMM: ",
+              DataTypeImpl::ToString(DataTypeImpl::GetType<T>()));
+  }
+}
+
 template <typename T>
 Attention<T>::Attention(const OpKernelInfo& info) : AttentionBase<T>(info) {
   is_causal_ = static_cast<int64_t>(info.GetAttrOrDefault<int64_t>("is_causal", 0)) == 1;
@@ -139,6 +231,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
   const Tensor* attn_mask = context->Input<Tensor>(3);
   const Tensor* past_key = context->Input<Tensor>(4);
   const Tensor* past_value = context->Input<Tensor>(5);
+  const Tensor* nonpad_kv_seqlen = context->Input<Tensor>(6);  // optional, Opset 24
 
   AttentionParameters parameters;
   TensorShape y_shape;
@@ -154,6 +247,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
       attn_mask,
       past_key,
       past_value,
+      nonpad_kv_seqlen,
       is_causal_,
       softcap_,
       softmax_precision_,
@@ -354,90 +448,22 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
   // A: Q (B x N x) S x H (B x N x) S x H S x H
   // B: K' (B x N x) T x H (B x N x) H x T H x T
   // C: attention_probs (B x N x) S x T (B x N x) S x T S x T
-    if constexpr (std::is_same<T, float>::value) {
-      if
(parameters.transpose_output) { - math::GemmEx(CblasNoTrans, - CblasTrans, - parameters.q_sequence_length, // M - parameters.total_sequence_length, // N - parameters.head_size, // K - alpha, - Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size, - parameters.head_size * parameters.q_num_heads, // lda - transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k, - transposed_k ? parameters.head_size * parameters.kv_num_heads : parameters.head_size, // ldb - beta, - output, - parameters.total_sequence_length, // ldc - nullptr); - } else { - math::Gemm(CblasNoTrans, - CblasTrans, - parameters.q_sequence_length, // M - parameters.total_sequence_length, // N - parameters.head_size, // K - alpha, - Q + q_input_chunk_length * i, - k, - beta, - output, - nullptr); - } - } else if constexpr (std::is_same::value) { - if (MlasHGemmSupported(CblasNoTrans, CblasTrans)) { - MlasGemm(CblasNoTrans, - CblasTrans, - parameters.q_sequence_length, // M - parameters.total_sequence_length, // N - parameters.head_size, // K - parameters.transpose_output - ? Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size - : Q + q_input_chunk_length * i, - parameters.transpose_output - ? parameters.head_size * parameters.q_num_heads - : static_cast(parameters.head_size), // lda - transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k, - transposed_k - ? 
parameters.head_size * parameters.kv_num_heads - : static_cast(parameters.head_size), // ldb - output, - static_cast(parameters.past_sequence_length + parameters.kv_sequence_length), // ldc - MLFloat16(alpha).val, MLFloat16(beta).val, nullptr); - } else { - if (parameters.transpose_output) { - math::GemmEx(CblasNoTrans, - CblasTrans, - parameters.q_sequence_length, // M - parameters.total_sequence_length, // N - parameters.head_size, // K - MLFloat16(alpha), - Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size, - parameters.head_size * parameters.q_num_heads, // lda - transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k, - transposed_k ? parameters.head_size * parameters.kv_num_heads : parameters.head_size, // ldb - MLFloat16(beta), - output, - parameters.total_sequence_length, // ldc - nullptr); - } else { - TensorShape c_shape({parameters.q_sequence_length, parameters.total_sequence_length}); - Gemm_MLFloat16(CblasNoTrans, CblasTrans, - static_cast(parameters.q_sequence_length), // M - static_cast(parameters.total_sequence_length), // N - static_cast(parameters.head_size), // K - MLFloat16(alpha), - Q + q_input_chunk_length * i, - k, - MLFloat16(beta), - output, - &c_shape, - output, - nullptr); - } - } - } else { - ORT_THROW("Unsupported data type for attention Q*K multiplication: ", DataTypeImpl::ToString(DataTypeImpl::GetType())); - } + const T* q_ptr = parameters.transpose_output + ? Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size + : Q + q_input_chunk_length * i; + int q_lda = parameters.transpose_output + ? parameters.head_size * parameters.q_num_heads + : parameters.head_size; + const T* k_ptr = transposed_k + ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size + : k; + int k_ldb = transposed_k + ? 
parameters.head_size * parameters.kv_num_heads + : parameters.head_size; + + AttentionGemm(CblasNoTrans, CblasTrans, + parameters.q_sequence_length, parameters.total_sequence_length, parameters.head_size, + alpha, q_ptr, q_lda, k_ptr, k_ldb, beta, output, parameters.total_sequence_length); if (out_qk != nullptr && (parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQKMask || parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQK)) { @@ -448,6 +474,15 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, MlasEltwiseAdd(output, mask_data + mask_data_offset, output, probs_matrix_size); } } + // Apply nonpad_kv_seqlen masking (Opset 24+): mask out KV positions >= valid length per batch. + if (parameters.has_nonpad_kv_seqlen) { + int valid_kv_len = static_cast(parameters.nonpad_kv_seqlen_data[batch_i]); + for (int s = 0; s < parameters.q_sequence_length; ++s) { + std::fill(output + s * parameters.total_sequence_length + valid_kv_len, + output + (s + 1) * parameters.total_sequence_length, + mask_filter_value()); + } + } if (parameters.softcap > 0.0f) { if constexpr (std::is_same::value) { ComputeAttentionSoftcapInplace(output, static_cast(probs_matrix_size), parameters.softcap); @@ -587,104 +622,32 @@ void AttentionBase::ComputeVxAttentionScore(T* output, // bu } } + // Compute QK * V + ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * total_sequence_length * i; + const T* gemm_B; + int gemm_ldb; + T* gemm_C; + int gemm_ldc; + if (transpose_output) { - // transpose_output is false - ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * total_sequence_length * i; - - if constexpr (std::is_same::value) { - // V is transposed but not QK. We use GemmEx with a different value for ldb. 
- math::GemmEx(CblasNoTrans, - CblasNoTrans, - sequence_length, // M - v_head_size, // N - total_sequence_length, // K - 1.f, // alpha - attention_probs + attention_probs_offset, // QK - total_sequence_length, // lda - transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V - transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb - 0.f, // beta - output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size), - v_head_size * num_heads, // ldc - nullptr); - } else if constexpr (std::is_same::value) { - // This switch should probably be moved to math_cpu.h. - if (MlasHGemmSupported(CblasNoTrans, CblasNoTrans)) { - MlasGemm(CblasNoTrans, - CblasNoTrans, - sequence_length, // M - v_head_size, // N - total_sequence_length, // K - attention_probs + attention_probs_offset, - total_sequence_length, // lda - transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, - transposed_v ? static_cast(v_head_size * kv_num_heads) : static_cast(v_head_size), // ldb - output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size), - v_head_size * num_heads, // ldc - MLFloat16(1.f).val, MLFloat16(0.f).val, nullptr); - } else { - math::GemmEx(CblasNoTrans, - CblasNoTrans, - sequence_length, // M - v_head_size, // N - total_sequence_length, // K - MLFloat16(1.f), // alpha - attention_probs + attention_probs_offset, // QK - total_sequence_length, // lda - transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V - transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb - MLFloat16(0.f), // beta - output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size), - v_head_size * num_heads, // ldc - nullptr); - } - } else { - ORT_THROW("Unsupported data type for attention QK*V multiplication: ", - DataTypeImpl::ToString(DataTypeImpl::GetType())); - } + // 3D inputs: V may be in strided layout, use appropriate strides. 
+ gemm_B = transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v; + gemm_ldb = transposed_v ? v_head_size * kv_num_heads : v_head_size; + gemm_C = output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size); + gemm_ldc = v_head_size * num_heads; } else { - // transpose_output is false - ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * total_sequence_length * i; + // 4D inputs: V is already in head-contiguous layout. + gemm_B = v; + gemm_ldb = v_head_size; ptrdiff_t dest_offset = SafeInt(sequence_length) * v_head_size * i; - T* dest = output + dest_offset; - - if constexpr (std::is_same::value) { - math::MatMul(sequence_length, v_head_size, total_sequence_length, - attention_probs + attention_probs_offset, v, dest, nullptr); - } else if constexpr (std::is_same::value) { - if (MlasHGemmSupported(CblasNoTrans, CblasNoTrans)) { - MlasGemm(CblasNoTrans, - CblasNoTrans, - sequence_length, // M - v_head_size, // N - total_sequence_length, // K - attention_probs + attention_probs_offset, - total_sequence_length, // lda - v, - static_cast(v_head_size), // ldb - dest, - static_cast(v_head_size), // ldc - MLFloat16(1.f).val, MLFloat16(0.f).val, nullptr); - } else { - Gemm_MLFloat16(CblasNoTrans, - CblasNoTrans, - static_cast(sequence_length), // M - static_cast(v_head_size), // N - static_cast(total_sequence_length), // K - MLFloat16(1.f), // alpha - attention_probs + attention_probs_offset, - v, - MLFloat16(0.f), // beta - nullptr, - nullptr, - dest, - nullptr); - } - } else { - ORT_THROW("Unsupported data type for attention QK*V multiplication: ", - DataTypeImpl::ToString(DataTypeImpl::GetType())); - } + gemm_C = output + dest_offset; + gemm_ldc = v_head_size; } + + AttentionGemm(CblasNoTrans, CblasNoTrans, + sequence_length, v_head_size, total_sequence_length, + 1.0f, attention_probs + attention_probs_offset, total_sequence_length, + gemm_B, gemm_ldb, 0.0f, gemm_C, gemm_ldc); } }); } diff --git 
a/onnxruntime/core/providers/cpu/llm/attention_helper.h b/onnxruntime/core/providers/cpu/llm/attention_helper.h
index 4b1e22df4b2f6..e7df1a078472a 100644
--- a/onnxruntime/core/providers/cpu/llm/attention_helper.h
+++ b/onnxruntime/core/providers/cpu/llm/attention_helper.h
@@ -15,6 +15,7 @@ inline Status ComputeOutputShapeForAttention(
     const Tensor* attn_mask,
     const Tensor* past_key,
     const Tensor* past_value,
+    const Tensor* nonpad_kv_seqlen,
     bool is_causal,
     float softcap,
     int softmax_precision,
@@ -102,6 +103,29 @@ inline Status ComputeOutputShapeForAttention(
   }
   parameters.total_sequence_length = parameters.past_sequence_length + parameters.kv_sequence_length;
 
+  // Handle nonpad_kv_seqlen (Opset 24+)
+  if (nonpad_kv_seqlen != nullptr) {
+    ORT_ENFORCE(nonpad_kv_seqlen->Shape().NumDimensions() == 1,
+                "nonpad_kv_seqlen must be a 1D tensor");
+    ORT_ENFORCE(nonpad_kv_seqlen->Shape()[0] == parameters.batch_size,
+                "nonpad_kv_seqlen must have shape [batch_size], got ",
+                nonpad_kv_seqlen->Shape()[0], " vs batch_size=", parameters.batch_size);
+    ORT_ENFORCE(past_key == nullptr && past_value == nullptr,
+                "nonpad_kv_seqlen should not be used together with past_key and past_value inputs");
+    parameters.has_nonpad_kv_seqlen = true;
+    parameters.nonpad_kv_seqlen_data = nonpad_kv_seqlen->Data<int64_t>();
+    // Validate each value is in [0, total_sequence_length].
+    for (int i = 0; i < parameters.batch_size; ++i) {
+      ORT_ENFORCE(parameters.nonpad_kv_seqlen_data[i] >= 0 &&
+                      parameters.nonpad_kv_seqlen_data[i] <= parameters.total_sequence_length,
+                  "nonpad_kv_seqlen[", i, "] = ", parameters.nonpad_kv_seqlen_data[i],
+                  " is out of range [0, ", parameters.total_sequence_length, "]");
+    }
+  } else {
+    parameters.has_nonpad_kv_seqlen = false;
+    parameters.nonpad_kv_seqlen_data = nullptr;
+  }
+
   ORT_ENFORCE(parameters.q_num_heads % parameters.kv_num_heads == 0,
               "q_num_heads must be a multiple of kv_num_heads. This is required for grouped/multi-query and multi-headed attention.");
   ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 1] == parameters.total_sequence_length,
               "inconsistent total_sequence_length (between attn_mask and past_key and past_value)");
diff --git a/onnxruntime/core/providers/cpu/llm/attention_parameters.h b/onnxruntime/core/providers/cpu/llm/attention_parameters.h
index b8586ca4d63dc..6dd5beb4fed45 100644
--- a/onnxruntime/core/providers/cpu/llm/attention_parameters.h
+++ b/onnxruntime/core/providers/cpu/llm/attention_parameters.h
@@ -45,6 +45,10 @@ struct AttentionParameters {
   bool transpose_output;  // Whether to transpose the inputs and the outputs from BxNxSxH to BxSxNxH
                           // This covers the case where the inputs are 3D.
 
+  // nonpad_kv_seqlen (Opset 24+): per-batch valid KV sequence lengths, shape [batch_size]
+  bool has_nonpad_kv_seqlen = false;
+  const int64_t* nonpad_kv_seqlen_data = nullptr;
+
   // Checks the consistency of the parameters.
void checkParameters() const { ORT_ENFORCE(batch_size > 0, "Batch size must be greater than 0"); diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index ef0dd065db523..75531fd6303bd 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -67,6 +67,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* attn_mask = context->Input(3); const Tensor* past_key = context->Input(4); const Tensor* past_value = context->Input(5); + const Tensor* nonpad_kv_seqlen = context->Input(6); // optional, Opset 24 attention_helper::AttentionParameters parameters; TensorShape y_shape; @@ -81,6 +82,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { attn_mask, past_key, past_value, + nonpad_kv_seqlen, is_causal_, softcap_, softmax_precision_, diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc index 358f775cf2bc0..86bbfb172c7ee 100644 --- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -1414,5 +1414,208 @@ TEST(AttentionTest, AttentionNoPastWithPresentOutput) { ); } +// Test nonpad_kv_seqlen (Opset 24 feature). +// nonpad_kv_seqlen masks out KV positions >= valid length per batch element. 
+TEST(AttentionTest, Attention_NonPadKVSeqLen_4D) { + // batch_size=1, q_num_heads=1, kv_num_heads=1 + // q_seq_len=1, kv_seq_len=4, head_size=2, v_head_size=2 + // nonpad_kv_seqlen=[2] => only first 2 of 4 KV positions are valid + OpTester test("Attention", 24, onnxruntime::kOnnxDomain); + + // 4D inputs: (batch, heads, seq, head_size) + std::vector q_shape = {1, 1, 1, 2}; + std::vector k_shape = {1, 1, 4, 2}; + std::vector v_shape = {1, 1, 4, 2}; + + // Q and K all 1.0 so QK scores are uniform + std::vector q = {1.0f, 1.0f}; + std::vector k = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + // V: first 2 positions = 1.0, last 2 = 99.0 (should be masked out) + std::vector v = {1.0f, 1.0f, 1.0f, 1.0f, 99.0f, 99.0f, 99.0f, 99.0f}; + + test.AddInput("Q", q_shape, q); + test.AddInput("K", k_shape, k); + test.AddInput("V", v_shape, v); + test.AddOptionalInputEdge(); // attn_mask + test.AddOptionalInputEdge(); // past_key + test.AddOptionalInputEdge(); // past_value + test.AddInput("nonpad_kv_seqlen", {1}, {2}); + + // Uniform attention over 2 valid positions with V=1.0 => output is [1.0, 1.0] + std::vector expected_y = {1.0f, 1.0f}; + test.AddOutput("Y", {1, 1, 1, 2}, expected_y, false, 0, 1e-4f); + + // Per spec, present_key/present_value should not be used with nonpad_kv_seqlen. + test.AddOptionalOutputEdge(); // present_key + test.AddOptionalOutputEdge(); // present_value + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test nonpad_kv_seqlen with batch_size > 1 and different valid lengths per batch. 
+TEST(AttentionTest, Attention_NonPadKVSeqLen_MultiBatch_4D) { + // batch_size=2, q_num_heads=1, kv_num_heads=1 + // q_seq_len=1, kv_seq_len=4, head_size=2, v_head_size=2 + // nonpad_kv_seqlen=[2, 3] => batch 0 has 2 valid, batch 1 has 3 valid + OpTester test("Attention", 24, onnxruntime::kOnnxDomain); + + std::vector q_shape = {2, 1, 1, 2}; + std::vector k_shape = {2, 1, 4, 2}; + std::vector v_shape = {2, 1, 4, 2}; + + // Q and K all 1.0 + std::vector q = {1.0f, 1.0f, 1.0f, 1.0f}; + std::vector k(2 * 1 * 4 * 2, 1.0f); + + // V for batch 0: [1, 1], [1, 1], [99, 99], [99, 99] + // V for batch 1: [2, 2], [2, 2], [2, 2], [99, 99] + std::vector v = { + 1.0f, 1.0f, 1.0f, 1.0f, 99.0f, 99.0f, 99.0f, 99.0f, // batch 0 + 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 99.0f, 99.0f // batch 1 + }; + + test.AddInput("Q", q_shape, q); + test.AddInput("K", k_shape, k); + test.AddInput("V", v_shape, v); + test.AddOptionalInputEdge(); // attn_mask + test.AddOptionalInputEdge(); // past_key + test.AddOptionalInputEdge(); // past_value + test.AddInput("nonpad_kv_seqlen", {2}, {2, 3}); + + // Batch 0: uniform over 2 valid positions, V=1.0 => [1.0, 1.0] + // Batch 1: uniform over 3 valid positions, V=2.0 => [2.0, 2.0] + std::vector expected_y = {1.0f, 1.0f, 2.0f, 2.0f}; + test.AddOutput("Y", {2, 1, 1, 2}, expected_y, false, 0, 1e-4f); + + // Per spec, present_key/present_value should not be used with nonpad_kv_seqlen. + test.AddOptionalOutputEdge(); // present_key + test.AddOptionalOutputEdge(); // present_value + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Edge case: nonpad_kv_seqlen = 0 (all positions masked). 
+TEST(AttentionTest, Attention_NonPadKVSeqLen_AllMasked) { + OpTester test("Attention", 24, onnxruntime::kOnnxDomain); + + std::vector q_shape = {1, 1, 1, 2}; + std::vector k_shape = {1, 1, 4, 2}; + std::vector v_shape = {1, 1, 4, 2}; + + std::vector q = {1.0f, 1.0f}; + std::vector k(8, 1.0f); + // All V positions are "invalid" — result is uniform over 4 highly-negative-scored positions, + // so softmax is uniform 0.25 each. + std::vector v = {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f, 70.0f, 80.0f}; + + test.AddInput("Q", q_shape, q); + test.AddInput("K", k_shape, k); + test.AddInput("V", v_shape, v); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + test.AddInput("nonpad_kv_seqlen", {1}, {0}); + + // With all positions masked to -inf, softmax produces uniform weights and + // the result is the mean of all V rows: [(10+30+50+70)/4, (20+40+60+80)/4] = [40, 50]. + std::vector expected_y = {40.0f, 50.0f}; + test.AddOutput("Y", {1, 1, 1, 2}, expected_y, false, 0, 1e-3f); + test.AddOptionalOutputEdge(); + test.AddOptionalOutputEdge(); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Edge case: nonpad_kv_seqlen = total_sequence_length (no positions masked). +TEST(AttentionTest, Attention_NonPadKVSeqLen_NoneMasked) { + OpTester test("Attention", 24, onnxruntime::kOnnxDomain); + + std::vector q_shape = {1, 1, 1, 2}; + std::vector k_shape = {1, 1, 4, 2}; + std::vector v_shape = {1, 1, 4, 2}; + + std::vector q = {1.0f, 1.0f}; + std::vector k(8, 1.0f); + std::vector v = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + + test.AddInput("Q", q_shape, q); + test.AddInput("K", k_shape, k); + test.AddInput("V", v_shape, v); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + // All 4 KV positions are valid — no masking. 
+ test.AddInput("nonpad_kv_seqlen", {1}, {4}); + + // Uniform attention over all 4 positions: mean of V rows. + // [(1+3+5+7)/4, (2+4+6+8)/4] = [4.0, 5.0] + std::vector expected_y = {4.0f, 5.0f}; + test.AddOutput("Y", {1, 1, 1, 2}, expected_y, false, 0, 1e-3f); + test.AddOptionalOutputEdge(); + test.AddOptionalOutputEdge(); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Validation: negative nonpad_kv_seqlen should be rejected. +TEST(AttentionTest, Attention_NonPadKVSeqLen_NegativeValue) { + OpTester test("Attention", 24, onnxruntime::kOnnxDomain); + + std::vector q_shape = {1, 1, 1, 2}; + std::vector k_shape = {1, 1, 4, 2}; + std::vector v_shape = {1, 1, 4, 2}; + + test.AddInput("Q", q_shape, {1.0f, 1.0f}); + test.AddInput("K", k_shape, std::vector(8, 1.0f)); + test.AddInput("V", v_shape, std::vector(8, 1.0f)); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + test.AddInput("nonpad_kv_seqlen", {1}, {-1}); + + test.AddOutput("Y", {1, 1, 1, 2}, {0.0f, 0.0f}); + test.AddOptionalOutputEdge(); + test.AddOptionalOutputEdge(); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, "nonpad_kv_seqlen[0] = -1 is out of range", + {}, nullptr, &execution_providers); +} + +// Validation: nonpad_kv_seqlen exceeding total_sequence_length should be rejected. 
+TEST(AttentionTest, Attention_NonPadKVSeqLen_ExceedsTotalSeqLen) { + OpTester test("Attention", 24, onnxruntime::kOnnxDomain); + + std::vector q_shape = {1, 1, 1, 2}; + std::vector k_shape = {1, 1, 4, 2}; + std::vector v_shape = {1, 1, 4, 2}; + + test.AddInput("Q", q_shape, {1.0f, 1.0f}); + test.AddInput("K", k_shape, std::vector(8, 1.0f)); + test.AddInput("V", v_shape, std::vector(8, 1.0f)); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + test.AddInput("nonpad_kv_seqlen", {1}, {5}); // total_sequence_length=4 + + test.AddOutput("Y", {1, 1, 1, 2}, {0.0f, 0.0f}); + test.AddOptionalOutputEdge(); + test.AddOptionalOutputEdge(); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, "nonpad_kv_seqlen[0] = 5 is out of range", + {}, nullptr, &execution_providers); +} + } // namespace test } // namespace onnxruntime