diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts index d218be3ce8b5f..9050c1bbb8816 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts @@ -193,9 +193,24 @@ export const validateInputs = ( passPastInKv = true; } } + // Spec requires 1D shape (batch_size), but older model builders may add unit + // dimensions (e.g. [B, 1] instead of [B]). Allow shapes where each dim is 1 or batchSize. const seqlLens = inputs.length > 4 ? inputs[5] : undefined; - if (seqlLens && seqlLens.dims.length !== 1 && seqlLens.dims[0] !== batchSize) { - throw new Error('Input "seqlens" is expected to have 1 dimension and the same dim 0 as batch_size'); + if (seqlLens) { + if (seqlLens.dims.length === 0) { + throw new Error('seqlens_k must be at least 1D, got scalar.'); + } + const seqlLenSize = seqlLens.dims.reduce((a, b) => a * b, 1); + if (seqlLenSize !== batchSize) { + throw new Error(`seqlens_k must have batch_size (${batchSize}) elements, got ${seqlLenSize}.`); + } + for (let i = 0; i < seqlLens.dims.length; i++) { + if (seqlLens.dims[i] !== 1 && seqlLens.dims[i] !== batchSize) { + throw new Error( + `seqlens_k has unexpected shape. Each dimension must be 1 or batch_size (${batchSize}), got dims[${i}] = ${seqlLens.dims[i]}.`, + ); + } + } } const totalSequenceLength = -1; const maxSequenceLength = -1; diff --git a/js/web/test/data/ops/group-query-attention.jsonc b/js/web/test/data/ops/group-query-attention.jsonc index f71e89f727cb1..83a5dc765280e 100644 --- a/js/web/test/data/ops/group-query-attention.jsonc +++ b/js/web/test/data/ops/group-query-attention.jsonc @@ -1409,5 +1409,83 @@ ] } ] + }, + { + // Backward compat: seqlens_k shape [1, 1] accepted for batch_size=1. + // Older model builders (e.g. qwen3-0.6b) emit this instead of [1]. 
+ "name": "GroupQueryAttention Legacy2D SeqlensK", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [1, 1, 8], + "type": "float32" + }, + // key + { + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 8], + "type": "float32" + }, + // value + { + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 8], + "type": "float32" + }, + // past key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // past value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k -- legacy [1, 1] shape instead of [1] + { + "data": [1], + "dims": [1, 1], + "type": "int32" + }, + // total_sequence_length + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 1, 8], + "type": "float32" + } + ] + } + ] } ] diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index f5399e307fbca..0269523e0f34e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -261,10 +261,24 @@ Status CheckInputs(const T* query, "Input 'past_key' and 'past_value' shall be both present or both absent."); } + // Spec requires 1D shape (batch_size), but older model builders may add unit + // dimensions (e.g. [B, 1] instead of [B]). Allow shapes where each dim is 1 or batch_size. 
+ if (seqlens_k->Shape().NumDimensions() == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "seqlens_k must be at least 1D, got scalar."); + } const auto& seqlens_k_dim = seqlens_k->Shape().GetDims(); - if (seqlens_k_dim.size() != 1 || seqlens_k_dim[0] != batch_size) { + if (seqlens_k->Shape().Size() != static_cast<int64_t>(batch_size)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "seqlens_k must be shape (batch_size)."); + "seqlens_k must have batch_size (", batch_size, ") elements, got ", + seqlens_k->Shape().Size(), "."); + } + for (size_t i = 0; i < seqlens_k_dim.size(); ++i) { + if (seqlens_k_dim[i] != 1 && seqlens_k_dim[i] != static_cast<int64_t>(batch_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "seqlens_k has unexpected shape. Each dimension must be 1 or batch_size (", + batch_size, "), got dim[", i, "] = ", seqlens_k_dim[i], "."); + } } if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) { diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 0690094031bb8..508d8d0f200ac 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include <optional> #include "gtest/gtest.h" #include "test/common/tensor_op_test_utils.h" @@ -22,7 +23,8 @@ static void RunGQASeqlensKTest( OpTester::ExpectResult expect, const std::string& expected_message, bool provide_past = false, - int past_seq_len = 0) { + int past_seq_len = 0, + const std::optional<std::vector<int64_t>>& seqlens_k_shape = std::nullopt) { constexpr int num_heads = 1; constexpr int kv_num_heads = 1; constexpr int head_size = 8; @@ -52,7 +54,10 @@ static void RunGQASeqlensKTest( tester.AddOptionalInputEdge(); // past_value } - tester.AddInput<int32_t>("seqlens_k", {batch_size}, seqlens_k_data); + std::vector<int64_t> shape = seqlens_k_shape.has_value() + ? 
*seqlens_k_shape + : std::vector<int64_t>{batch_size}; + tester.AddInput<int32_t>("seqlens_k", shape, seqlens_k_data); tester.AddInput<int32_t>("total_sequence_length", {1}, {total_seq_len}); tester.AddOptionalInputEdge(); // cos_cache @@ -73,8 +78,7 @@ {batch_size, kv_num_heads, declared_present_seqlen, head_size}, std::vector<float>(batch_size * kv_num_heads * declared_present_seqlen * head_size, 0.0f)); - // For success tests, we only care that validation passes without crash; - // exact output values are not the focus of these security regression tests. + // Tolerance is intentionally loose: these tests validate shape acceptance, not output values. if (expect == OpTester::ExpectResult::kExpectSuccess) { tester.SetOutputTolerance(1e6f); } @@ -231,80 +235,104 @@ TEST(GroupQueryAttentionTest, TotalSeqLenNegative) { "total_sequence_length must be positive"); } -// Shape validation: seqlens_k with wrong rank (2D instead of 1D) must be rejected. -TEST(GroupQueryAttentionTest, SeqlensKWrongRank) { - constexpr int num_heads = 1; - constexpr int kv_num_heads = 1; - constexpr int head_size = 8; - constexpr int hidden_size = num_heads * head_size; - constexpr int kv_hidden_size = kv_num_heads * head_size; - - OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast<int64_t>(num_heads)); - tester.AddAttribute("kv_num_heads", static_cast<int64_t>(kv_num_heads)); +// Backward compat: seqlens_k shape {1, 1} accepted for batch_size=1. +// Older model builders (e.g. qwen3-0.6b) emit this instead of {1}. 
+TEST(GroupQueryAttentionTest, SeqlensKLegacy2DShape) { + RunGQASeqlensKTest( + /*seqlens_k_data=*/{0}, + /*total_seq_len=*/1, + /*batch_size=*/1, + /*sequence_length=*/1, + OpTester::ExpectResult::kExpectSuccess, + "", + /*provide_past=*/false, + /*past_seq_len=*/0, + /*seqlens_k_shape=*/std::vector<int64_t>{1, 1}); +} - tester.AddInput<float>("query", {1, 1, hidden_size}, std::vector<float>(hidden_size, 1.0f)); - tester.AddInput<float>("key", {1, 1, kv_hidden_size}, std::vector<float>(kv_hidden_size, 1.0f)); - tester.AddInput<float>("value", {1, 1, kv_hidden_size}, std::vector<float>(kv_hidden_size, 1.0f)); - tester.AddOptionalInputEdge(); // past_key - tester.AddOptionalInputEdge(); // past_value - // 2D shape {1, 1} instead of {1} - tester.AddInput<int32_t>("seqlens_k", {1, 1}, {0}); - tester.AddInput<int32_t>("total_sequence_length", {1}, {1}); - tester.AddOptionalInputEdge(); // cos_cache - tester.AddOptionalInputEdge(); // sin_cache - tester.AddOptionalInputEdge(); // position_ids - tester.AddOptionalInputEdge(); // attention_bias - tester.AddOptionalInputEdge(); // head_sink +// Backward compat: seqlens_k shape {2, 1} accepted for batch_size=2. +TEST(GroupQueryAttentionTest, SeqlensKLegacy2DShapeMultiBatch) { + RunGQASeqlensKTest( + /*seqlens_k_data=*/{0, 0}, + /*total_seq_len=*/1, + /*batch_size=*/2, + /*sequence_length=*/1, + OpTester::ExpectResult::kExpectSuccess, + "", + /*provide_past=*/false, + /*past_seq_len=*/0, + /*seqlens_k_shape=*/std::vector<int64_t>{2, 1}); +} - tester.AddOutput<float>("output", {1, 1, hidden_size}, std::vector<float>(hidden_size, 0.0f)); - tester.AddOutput<float>("present_key", {1, kv_num_heads, 1, head_size}, - std::vector<float>(kv_num_heads * head_size, 0.0f)); - tester.AddOutput<float>("present_value", {1, kv_num_heads, 1, head_size}, - std::vector<float>(kv_num_heads * head_size, 0.0f)); +// Backward compat: seqlens_k shape {1, 2} accepted for batch_size=2. +// Batch dimension in trailing position. 
+TEST(GroupQueryAttentionTest, SeqlensKLegacy2DShapeTrailingBatch) { + RunGQASeqlensKTest( + /*seqlens_k_data=*/{0, 0}, + /*total_seq_len=*/1, + /*batch_size=*/2, + /*sequence_length=*/1, + OpTester::ExpectResult::kExpectSuccess, + "", + /*provide_past=*/false, + /*past_seq_len=*/0, + /*seqlens_k_shape=*/std::vector<int64_t>{1, 2}); +} - std::vector<std::unique_ptr<IExecutionProvider>> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectFailure, "seqlens_k must be shape (batch_size)", - {}, nullptr, &execution_providers); +// Shape {2, 2} with batch_size=4: correct element count but invalid factored shape. +TEST(GroupQueryAttentionTest, SeqlensKInvalidFactoredShape) { + RunGQASeqlensKTest( + /*seqlens_k_data=*/{0, 0, 0, 0}, + /*total_seq_len=*/1, + /*batch_size=*/4, + /*sequence_length=*/1, + OpTester::ExpectResult::kExpectFailure, + "seqlens_k has unexpected shape", + /*provide_past=*/false, + /*past_seq_len=*/0, + /*seqlens_k_shape=*/std::vector<int64_t>{2, 2}); } -// Shape validation: seqlens_k with wrong length (2 elements for batch_size=1) must be rejected. +// Wrong element count (1D): 2 elements for batch_size=1. 
TEST(GroupQueryAttentionTest, SeqlensKWrongLength) { - constexpr int num_heads = 1; - constexpr int kv_num_heads = 1; - constexpr int head_size = 8; - constexpr int hidden_size = num_heads * head_size; - constexpr int kv_hidden_size = kv_num_heads * head_size; - - OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast<int64_t>(num_heads)); - tester.AddAttribute("kv_num_heads", static_cast<int64_t>(kv_num_heads)); - - tester.AddInput<float>("query", {1, 1, hidden_size}, std::vector<float>(hidden_size, 1.0f)); - tester.AddInput<float>("key", {1, 1, kv_hidden_size}, std::vector<float>(kv_hidden_size, 1.0f)); - tester.AddInput<float>("value", {1, 1, kv_hidden_size}, std::vector<float>(kv_hidden_size, 1.0f)); - tester.AddOptionalInputEdge(); // past_key - tester.AddOptionalInputEdge(); // past_value - // Length 2 instead of 1 for batch_size=1 - tester.AddInput<int32_t>("seqlens_k", {2}, {0, 0}); - tester.AddInput<int32_t>("total_sequence_length", {1}, {1}); - tester.AddOptionalInputEdge(); // cos_cache - tester.AddOptionalInputEdge(); // sin_cache - tester.AddOptionalInputEdge(); // position_ids - tester.AddOptionalInputEdge(); // attention_bias - tester.AddOptionalInputEdge(); // head_sink + RunGQASeqlensKTest( + /*seqlens_k_data=*/{0, 0}, + /*total_seq_len=*/1, + /*batch_size=*/1, + /*sequence_length=*/1, + OpTester::ExpectResult::kExpectFailure, + "seqlens_k must have batch_size", + /*provide_past=*/false, + /*past_seq_len=*/0, + /*seqlens_k_shape=*/std::vector<int64_t>{2}); +} - tester.AddOutput<float>("output", {1, 1, hidden_size}, std::vector<float>(hidden_size, 0.0f)); - tester.AddOutput<float>("present_key", {1, kv_num_heads, 1, head_size}, - std::vector<float>(kv_num_heads * head_size, 0.0f)); - tester.AddOutput<float>("present_value", {1, kv_num_heads, 1, head_size}, - std::vector<float>(kv_num_heads * head_size, 0.0f)); +// Wrong element count (2D): shape {2, 1} has 2 elements but batch_size=1. 
+TEST(GroupQueryAttentionTest, SeqlensKWrongElementCount2D) { + RunGQASeqlensKTest( + /*seqlens_k_data=*/{0, 0}, + /*total_seq_len=*/1, + /*batch_size=*/1, + /*sequence_length=*/1, + OpTester::ExpectResult::kExpectFailure, + "seqlens_k must have batch_size", + /*provide_past=*/false, + /*past_seq_len=*/0, + /*seqlens_k_shape=*/std::vector<int64_t>{2, 1}); +} - std::vector<std::unique_ptr<IExecutionProvider>> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectFailure, "seqlens_k must be shape (batch_size)", - {}, nullptr, &execution_providers); +// Scalar seqlens_k must be rejected even when batch_size=1. +TEST(GroupQueryAttentionTest, SeqlensKScalarRejected) { + RunGQASeqlensKTest( + /*seqlens_k_data=*/{0}, + /*total_seq_len=*/1, + /*batch_size=*/1, + /*sequence_length=*/1, + OpTester::ExpectResult::kExpectFailure, + "seqlens_k must be at least 1D", + /*provide_past=*/false, + /*past_seq_len=*/0, + /*seqlens_k_shape=*/std::vector<int64_t>{}); } } // namespace test