Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,24 @@ export const validateInputs = (
passPastInKv = true;
}
}
// Spec requires 1D shape (batch_size), but older model builders may add unit
// dimensions (e.g. [B, 1] instead of [B]). Allow shapes where each dim is 1 or batchSize.
const seqlLens = inputs.length > 4 ? inputs[5] : undefined;
if (seqlLens && seqlLens.dims.length !== 1 && seqlLens.dims[0] !== batchSize) {
throw new Error('Input "seqlens" is expected to have 1 dimension and the same dim 0 as batch_size');
if (seqlLens) {
if (seqlLens.dims.length === 0) {
throw new Error('seqlens_k must be at least 1D, got scalar.');
}
const seqlLenSize = seqlLens.dims.reduce((a, b) => a * b, 1);
if (seqlLenSize !== batchSize) {
throw new Error(`seqlens_k must have batch_size (${batchSize}) elements, got ${seqlLenSize}.`);
}
for (let i = 0; i < seqlLens.dims.length; i++) {
if (seqlLens.dims[i] !== 1 && seqlLens.dims[i] !== batchSize) {
throw new Error(
`seqlens_k has unexpected shape. Each dimension must be 1 or batch_size (${batchSize}), got dims[${i}] = ${seqlLens.dims[i]}.`,
);
}
}
}
Comment thread
vraspar marked this conversation as resolved.
const totalSequenceLength = -1;
const maxSequenceLength = -1;
Expand Down
78 changes: 78 additions & 0 deletions js/web/test/data/ops/group-query-attention.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -1409,5 +1409,83 @@
]
}
]
},
{
// Backward compat: seqlens_k shape [1, 1] accepted for batch_size=1.
// Older model builders (e.g. qwen3-0.6b) emit this instead of [1].
"name": "GroupQueryAttention Legacy2D SeqlensK",
"operator": "GroupQueryAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "num_heads", "data": 1, "type": "int" },
{ "name": "kv_num_heads", "data": 1, "type": "int" }
],
"cases": [
{
"name": "T[0]",
"inputs": [
// query
{
"data": [0, 1, 2, 3, 4, 5, 6, 7],
"dims": [1, 1, 8],
"type": "float32"
},
// key
{
"data": [16, 17, 18, 19, 20, 21, 22, 23],
"dims": [1, 1, 8],
"type": "float32"
},
// value
{
"data": [32, 33, 34, 35, 36, 37, 38, 39],
"dims": [1, 1, 8],
"type": "float32"
},
// past key, BNSH
{
"data": [],
"dims": [1, 1, 0, 8],
"type": "float32"
},
// past value, BNSH
{
"data": [],
"dims": [1, 1, 0, 8],
"type": "float32"
},
// seqlens_k -- legacy [1, 1] shape instead of [1]
{
"data": [1],
"dims": [1, 1],
"type": "int32"
},
// total_sequence_length
{
"data": [1],
"dims": [1],
"type": "int32"
}
],
"outputs": [
// output
{
"data": [32, 33, 34, 35, 36, 37, 38, 39],
"dims": [1, 1, 8],
"type": "float32"
},
{
// present key, BNSH
"data": [16, 17, 18, 19, 20, 21, 22, 23],
"dims": [1, 1, 1, 8],
"type": "float32"
},
{
// present value, BNSH
"data": [32, 33, 34, 35, 36, 37, 38, 39],
"dims": [1, 1, 1, 8],
"type": "float32"
}
]
}
]
}
]
18 changes: 16 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,24 @@ Status CheckInputs(const T* query,
"Input 'past_key' and 'past_value' shall be both present or both absent.");
}

// Spec requires 1D shape (batch_size), but older model builders may add unit
// dimensions (e.g. [B, 1] instead of [B]). Allow shapes where each dim is 1 or batch_size.
Comment thread
vraspar marked this conversation as resolved.
if (seqlens_k->Shape().NumDimensions() == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"seqlens_k must be at least 1D, got scalar.");
}
const auto& seqlens_k_dim = seqlens_k->Shape().GetDims();
if (seqlens_k_dim.size() != 1 || seqlens_k_dim[0] != batch_size) {
if (seqlens_k->Shape().Size() != static_cast<int64_t>(batch_size)) {
Comment thread
edgchen1 marked this conversation as resolved.
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"seqlens_k must be shape (batch_size).");
"seqlens_k must have batch_size (", batch_size, ") elements, got ",
seqlens_k->Shape().Size(), ".");
}
for (size_t i = 0; i < seqlens_k_dim.size(); ++i) {
if (seqlens_k_dim[i] != 1 && seqlens_k_dim[i] != static_cast<int64_t>(batch_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"seqlens_k has unexpected shape. Each dimension must be 1 or batch_size (",
batch_size, "), got dim[", i, "] = ", seqlens_k_dim[i], ".");
}
Comment thread
vraspar marked this conversation as resolved.
}

if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) {
Expand Down
168 changes: 98 additions & 70 deletions onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include <limits>
#include <optional>

#include "gtest/gtest.h"
#include "test/common/tensor_op_test_utils.h"
Expand All @@ -22,7 +23,8 @@ static void RunGQASeqlensKTest(
OpTester::ExpectResult expect,
const std::string& expected_message,
bool provide_past = false,
int past_seq_len = 0) {
int past_seq_len = 0,
const std::optional<std::vector<int64_t>>& seqlens_k_shape = std::nullopt) {
constexpr int num_heads = 1;
constexpr int kv_num_heads = 1;
constexpr int head_size = 8;
Expand Down Expand Up @@ -52,7 +54,10 @@ static void RunGQASeqlensKTest(
tester.AddOptionalInputEdge<float>(); // past_value
}

tester.AddInput<int32_t>("seqlens_k", {batch_size}, seqlens_k_data);
std::vector<int64_t> shape = seqlens_k_shape.has_value()
? *seqlens_k_shape
: std::vector<int64_t>{batch_size};
tester.AddInput<int32_t>("seqlens_k", shape, seqlens_k_data);
tester.AddInput<int32_t>("total_sequence_length", {1}, {total_seq_len});

tester.AddOptionalInputEdge<float>(); // cos_cache
Expand All @@ -73,8 +78,7 @@ static void RunGQASeqlensKTest(
{batch_size, kv_num_heads, declared_present_seqlen, head_size},
std::vector<float>(batch_size * kv_num_heads * declared_present_seqlen * head_size, 0.0f));

// For success tests, we only care that validation passes without crash;
// exact output values are not the focus of these security regression tests.
// Tolerance is intentionally loose: these tests validate shape acceptance, not output values.
if (expect == OpTester::ExpectResult::kExpectSuccess) {
tester.SetOutputTolerance(1e6f);
}
Expand Down Expand Up @@ -231,80 +235,104 @@ TEST(GroupQueryAttentionTest, TotalSeqLenNegative) {
"total_sequence_length must be positive");
}

// Shape validation: seqlens_k with wrong rank (2D instead of 1D) must be rejected.
TEST(GroupQueryAttentionTest, SeqlensKWrongRank) {
constexpr int num_heads = 1;
constexpr int kv_num_heads = 1;
constexpr int head_size = 8;
constexpr int hidden_size = num_heads * head_size;
constexpr int kv_hidden_size = kv_num_heads * head_size;

OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
tester.AddAttribute<int64_t>("kv_num_heads", static_cast<int64_t>(kv_num_heads));
// Backward compat: a 2D seqlens_k of shape {1, 1} must be accepted when
// batch_size=1. Older model builders (e.g. qwen3-0.6b) emit this instead of {1}.
TEST(GroupQueryAttentionTest, SeqlensKLegacy2DShape) {
  const std::vector<int32_t> seqlens_k{0};
  const std::vector<int64_t> legacy_shape{1, 1};  // batch dim plus a trailing unit dim
  RunGQASeqlensKTest(seqlens_k,
                     /*total_seq_len=*/1,
                     /*batch_size=*/1,
                     /*sequence_length=*/1,
                     OpTester::ExpectResult::kExpectSuccess,
                     /*expected_message=*/"",
                     /*provide_past=*/false,
                     /*past_seq_len=*/0,
                     legacy_shape);
}

tester.AddInput<float>("query", {1, 1, hidden_size}, std::vector<float>(hidden_size, 1.0f));
tester.AddInput<float>("key", {1, 1, kv_hidden_size}, std::vector<float>(kv_hidden_size, 1.0f));
tester.AddInput<float>("value", {1, 1, kv_hidden_size}, std::vector<float>(kv_hidden_size, 1.0f));
tester.AddOptionalInputEdge<float>(); // past_key
tester.AddOptionalInputEdge<float>(); // past_value
// 2D shape {1, 1} instead of {1}
tester.AddInput<int32_t>("seqlens_k", {1, 1}, {0});
tester.AddInput<int32_t>("total_sequence_length", {1}, {1});
tester.AddOptionalInputEdge<float>(); // cos_cache
tester.AddOptionalInputEdge<float>(); // sin_cache
tester.AddOptionalInputEdge<int64_t>(); // position_ids
tester.AddOptionalInputEdge<float>(); // attention_bias
tester.AddOptionalInputEdge<float>(); // head_sink
// Backward compat: a 2D seqlens_k of shape {2, 1} must be accepted when batch_size=2.
TEST(GroupQueryAttentionTest, SeqlensKLegacy2DShapeMultiBatch) {
  const std::vector<int32_t> seqlens_k{0, 0};
  const std::vector<int64_t> legacy_shape{2, 1};  // batch dim leading, unit dim trailing
  RunGQASeqlensKTest(seqlens_k,
                     /*total_seq_len=*/1,
                     /*batch_size=*/2,
                     /*sequence_length=*/1,
                     OpTester::ExpectResult::kExpectSuccess,
                     /*expected_message=*/"",
                     /*provide_past=*/false,
                     /*past_seq_len=*/0,
                     legacy_shape);
}

tester.AddOutput<float>("output", {1, 1, hidden_size}, std::vector<float>(hidden_size, 0.0f));
tester.AddOutput<float>("present_key", {1, kv_num_heads, 1, head_size},
std::vector<float>(kv_num_heads * head_size, 0.0f));
tester.AddOutput<float>("present_value", {1, kv_num_heads, 1, head_size},
std::vector<float>(kv_num_heads * head_size, 0.0f));
// Backward compat: a 2D seqlens_k of shape {1, 2} must be accepted when
// batch_size=2, i.e. the batch dimension may appear in the trailing position.
TEST(GroupQueryAttentionTest, SeqlensKLegacy2DShapeTrailingBatch) {
  const std::vector<int32_t> seqlens_k{0, 0};
  const std::vector<int64_t> legacy_shape{1, 2};  // unit dim leading, batch dim trailing
  RunGQASeqlensKTest(seqlens_k,
                     /*total_seq_len=*/1,
                     /*batch_size=*/2,
                     /*sequence_length=*/1,
                     OpTester::ExpectResult::kExpectSuccess,
                     /*expected_message=*/"",
                     /*provide_past=*/false,
                     /*past_seq_len=*/0,
                     legacy_shape);
}

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectFailure, "seqlens_k must be shape (batch_size)",
{}, nullptr, &execution_providers);
// A {2, 2} seqlens_k with batch_size=4 has the correct element count but an
// invalid factored shape (every dim must be 1 or batch_size), so it is rejected.
TEST(GroupQueryAttentionTest, SeqlensKInvalidFactoredShape) {
  const std::vector<int32_t> seqlens_k{0, 0, 0, 0};
  const std::vector<int64_t> bad_shape{2, 2};  // 4 elements, but neither dim is 1 or 4
  RunGQASeqlensKTest(seqlens_k,
                     /*total_seq_len=*/1,
                     /*batch_size=*/4,
                     /*sequence_length=*/1,
                     OpTester::ExpectResult::kExpectFailure,
                     "seqlens_k has unexpected shape",
                     /*provide_past=*/false,
                     /*past_seq_len=*/0,
                     bad_shape);
}

// Shape validation: seqlens_k with wrong length (2 elements for batch_size=1) must be rejected.
// Wrong element count (1D): 2 elements for batch_size=1.
TEST(GroupQueryAttentionTest, SeqlensKWrongLength) {
constexpr int num_heads = 1;
constexpr int kv_num_heads = 1;
constexpr int head_size = 8;
constexpr int hidden_size = num_heads * head_size;
constexpr int kv_hidden_size = kv_num_heads * head_size;

OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
tester.AddAttribute<int64_t>("kv_num_heads", static_cast<int64_t>(kv_num_heads));

tester.AddInput<float>("query", {1, 1, hidden_size}, std::vector<float>(hidden_size, 1.0f));
tester.AddInput<float>("key", {1, 1, kv_hidden_size}, std::vector<float>(kv_hidden_size, 1.0f));
tester.AddInput<float>("value", {1, 1, kv_hidden_size}, std::vector<float>(kv_hidden_size, 1.0f));
tester.AddOptionalInputEdge<float>(); // past_key
tester.AddOptionalInputEdge<float>(); // past_value
// Length 2 instead of 1 for batch_size=1
tester.AddInput<int32_t>("seqlens_k", {2}, {0, 0});
tester.AddInput<int32_t>("total_sequence_length", {1}, {1});
tester.AddOptionalInputEdge<float>(); // cos_cache
tester.AddOptionalInputEdge<float>(); // sin_cache
tester.AddOptionalInputEdge<int64_t>(); // position_ids
tester.AddOptionalInputEdge<float>(); // attention_bias
tester.AddOptionalInputEdge<float>(); // head_sink
RunGQASeqlensKTest(
/*seqlens_k_data=*/{0, 0},
/*total_seq_len=*/1,
/*batch_size=*/1,
/*sequence_length=*/1,
OpTester::ExpectResult::kExpectFailure,
"seqlens_k must have batch_size",
/*provide_past=*/false,
/*past_seq_len=*/0,
/*seqlens_k_shape=*/std::vector<int64_t>{2});
}

tester.AddOutput<float>("output", {1, 1, hidden_size}, std::vector<float>(hidden_size, 0.0f));
tester.AddOutput<float>("present_key", {1, kv_num_heads, 1, head_size},
std::vector<float>(kv_num_heads * head_size, 0.0f));
tester.AddOutput<float>("present_value", {1, kv_num_heads, 1, head_size},
std::vector<float>(kv_num_heads * head_size, 0.0f));
// A 2D seqlens_k of shape {2, 1} carries 2 elements; with batch_size=1 the
// element count is wrong, so validation must fail.
TEST(GroupQueryAttentionTest, SeqlensKWrongElementCount2D) {
  const std::vector<int32_t> seqlens_k{0, 0};
  const std::vector<int64_t> oversized_shape{2, 1};  // 2 elements vs batch_size=1
  RunGQASeqlensKTest(seqlens_k,
                     /*total_seq_len=*/1,
                     /*batch_size=*/1,
                     /*sequence_length=*/1,
                     OpTester::ExpectResult::kExpectFailure,
                     "seqlens_k must have batch_size",
                     /*provide_past=*/false,
                     /*past_seq_len=*/0,
                     oversized_shape);
}

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectFailure, "seqlens_k must be shape (batch_size)",
{}, nullptr, &execution_providers);
// A scalar (rank-0) seqlens_k is rejected even in the batch_size=1 case.
TEST(GroupQueryAttentionTest, SeqlensKScalarRejected) {
  const std::vector<int32_t> seqlens_k{0};
  const std::vector<int64_t> scalar_shape{};  // empty dims vector => rank-0 tensor
  RunGQASeqlensKTest(seqlens_k,
                     /*total_seq_len=*/1,
                     /*batch_size=*/1,
                     /*sequence_length=*/1,
                     OpTester::ExpectResult::kExpectFailure,
                     "seqlens_k must be at least 1D",
                     /*provide_past=*/false,
                     /*past_seq_len=*/0,
                     scalar_shape);
}

} // namespace test
Expand Down
Loading