Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,21 @@
T* k_rotary = packed_qkv ? nullptr : K.GetMutable<Tensor>()->MutableData<T>();
if (do_rotary_) {
ORT_ENFORCE(cos_cache != nullptr && sin_cache != nullptr, "cos_cache and sin_cache must be provided when do_rotary is true");
// Validate seqlens_k values against cos/sin cache size to prevent OOB in rotary embedding lookup.
// Use the minimum of cos_cache and sin_cache dim-0 since CheckRotaryCaches does not enforce equality.
{
const int rotary_cache_max_seq = static_cast<int>(std::min(cos_cache->Shape().GetDims()[0],

Check warning on line 150 in onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for min [build/include_what_you_use] [4] Raw Output: onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc:150: Add #include <algorithm> for min [build/include_what_you_use] [4]
sin_cache->Shape().GetDims()[0]));
const int32_t* seqlens_k_data = seqlens_k->Data<int32_t>();
for (int b = 0; b < batch_size; b++) {
// position_id = seqlens_k[b] (in token generation), must be < cache rows
if (seqlens_k_data[b] >= rotary_cache_max_seq) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"seqlens_k[", b, "] = ", seqlens_k_data[b],
" exceeds rotary cache dimension 0 (", rotary_cache_max_seq, ")");
}
}
}
// Initialize rotary parameters
rotary_embedding_helper::RotaryParameters rotary_params = {};
rotary_params.batch_size = batch_size;
Expand Down
110 changes: 110 additions & 0 deletions onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,5 +307,115 @@ TEST(GroupQueryAttentionTest, SeqlensKWrongLength) {
{}, nullptr, &execution_providers);
}

// Regression test for rotary-cache bounds validation (do_rotary = 1).
// The seqlens_k entry used here fits inside the KV cache but is not a valid row index
// of cos_cache/sin_cache. The operator must reject it; otherwise the position derived
// from seqlens_k would read past the end of the cos/sin cache and leak heap memory
// into the output.
TEST(GroupQueryAttentionTest, SeqlensKExceedsCosCache_OOB) {
  // Head layout: one query head, one KV head.
  constexpr int q_heads = 1;
  constexpr int kv_heads = 1;
  constexpr int head_dim = 16;  // rotary requires a multiple of 16
  constexpr int q_width = q_heads * head_dim;
  constexpr int kv_width = kv_heads * head_dim;
  constexpr int rotary_cols = head_dim / 2;  // dim-1 of the cos/sin cache (8)

  // Sequence layout: the rotary cache is deliberately smaller than the KV cache.
  constexpr int rotary_rows = 4;    // rows in cos_cache / sin_cache
  constexpr int kv_cache_len = 16;  // rows in past_key / past_value
  constexpr int seqlen_k = 10;      // OK for the KV cache (10 < 16), too big for cos (10 >= 4)
  constexpr int total_len = 4;      // small enough to pass CheckRotaryCaches (4 <= rotary_rows)

  constexpr int kv_cache_elems = kv_heads * kv_cache_len * head_dim;
  constexpr int rotary_elems = rotary_rows * rotary_cols;

  // Constant-filled payloads; the values are irrelevant because the run must fail.
  const std::vector<float> q_ones(q_width, 1.0f);
  const std::vector<float> kv_ones(kv_width, 1.0f);
  const std::vector<float> past_fill(kv_cache_elems, 0.5f);
  const std::vector<float> cos_fill(rotary_elems, 1.0f);
  const std::vector<float> sin_fill(rotary_elems, 0.0f);
  const std::vector<float> out_zeros(q_width, 0.0f);
  const std::vector<float> present_zeros(kv_cache_elems, 0.0f);

  OpTester gqa("GroupQueryAttention", 1, onnxruntime::kMSDomain);
  gqa.AddAttribute<int64_t>("num_heads", int64_t{q_heads});
  gqa.AddAttribute<int64_t>("kv_num_heads", int64_t{kv_heads});
  gqa.AddAttribute<int64_t>("do_rotary", int64_t{1});

  // Inputs are registered positionally, so the order below must not change.
  gqa.AddInput<float>("query", {1, 1, q_width}, q_ones);
  gqa.AddInput<float>("key", {1, 1, kv_width}, kv_ones);
  gqa.AddInput<float>("value", {1, 1, kv_width}, kv_ones);

  // The past KV cache has room for seqlens_k = 10.
  gqa.AddInput<float>("past_key", {1, kv_heads, kv_cache_len, head_dim}, past_fill);
  gqa.AddInput<float>("past_value", {1, kv_heads, kv_cache_len, head_dim}, past_fill);

  gqa.AddInput<int32_t>("seqlens_k", {1}, {seqlen_k});
  gqa.AddInput<int32_t>("total_sequence_length", {1}, {total_len});

  // Only 4 rows of cos/sin data are supplied, so seqlens_k = 10 is out of range.
  gqa.AddInput<float>("cos_cache", {rotary_rows, rotary_cols}, cos_fill);
  gqa.AddInput<float>("sin_cache", {rotary_rows, rotary_cols}, sin_fill);

  gqa.AddOptionalInputEdge<int64_t>();  // position_ids
  gqa.AddOptionalInputEdge<float>();    // attention_bias
  gqa.AddOptionalInputEdge<float>();    // head_sink

  // Placeholder outputs; they are never compared because failure is expected.
  gqa.AddOutput<float>("output", {1, 1, q_width}, out_zeros);
  gqa.AddOutput<float>("present_key", {1, kv_heads, kv_cache_len, head_dim}, present_zeros);
  gqa.AddOutput<float>("present_value", {1, kv_heads, kv_cache_len, head_dim}, present_zeros);

  // Run on the CPU execution provider only and expect the bounds-check error message.
  std::vector<std::unique_ptr<IExecutionProvider>> cpu_only;
  cpu_only.push_back(DefaultCpuExecutionProvider());
  gqa.Run(OpTester::ExpectResult::kExpectFailure, "exceeds rotary cache dimension 0",
          {}, nullptr, &cpu_only);
}

// Positive counterpart to the rotary-cache bounds regression test: with do_rotary
// enabled and seqlens_k inside both the cos/sin cache and the KV cache, the operator
// is expected to run successfully.
TEST(GroupQueryAttentionTest, SeqlensKWithinCosCache_Rotary) {
  // Head layout: one query head, one KV head.
  constexpr int q_heads = 1;
  constexpr int kv_heads = 1;
  constexpr int head_dim = 16;  // rotary requires a multiple of 16
  constexpr int q_width = q_heads * head_dim;
  constexpr int kv_width = kv_heads * head_dim;
  constexpr int rotary_cols = head_dim / 2;

  // Sequence layout: rotary cache and KV cache are the same length here.
  constexpr int rotary_rows = 16;  // cos/sin cache is large enough
  constexpr int kv_cache_len = 16;
  constexpr int seqlen_k = 3;   // in range: 3 < 16 for both the cos cache and the KV cache
  constexpr int total_len = 4;  // seqlens_k + 1

  constexpr int kv_cache_elems = kv_heads * kv_cache_len * head_dim;
  constexpr int rotary_elems = rotary_rows * rotary_cols;

  OpTester gqa("GroupQueryAttention", 1, onnxruntime::kMSDomain);
  gqa.AddAttribute<int64_t>("num_heads", int64_t{q_heads});
  gqa.AddAttribute<int64_t>("kv_num_heads", int64_t{kv_heads});
  gqa.AddAttribute<int64_t>("do_rotary", int64_t{1});

  // Inputs are registered positionally, so the order below must not change.
  {
    const std::vector<float> q_ones(q_width, 1.0f);
    const std::vector<float> kv_ones(kv_width, 1.0f);
    gqa.AddInput<float>("query", {1, 1, q_width}, q_ones);
    gqa.AddInput<float>("key", {1, 1, kv_width}, kv_ones);
    gqa.AddInput<float>("value", {1, 1, kv_width}, kv_ones);
  }
  {
    const std::vector<float> past_fill(kv_cache_elems, 0.5f);
    gqa.AddInput<float>("past_key", {1, kv_heads, kv_cache_len, head_dim}, past_fill);
    gqa.AddInput<float>("past_value", {1, kv_heads, kv_cache_len, head_dim}, past_fill);
  }

  gqa.AddInput<int32_t>("seqlens_k", {1}, {seqlen_k});
  gqa.AddInput<int32_t>("total_sequence_length", {1}, {total_len});

  {
    const std::vector<float> cos_fill(rotary_elems, 1.0f);
    const std::vector<float> sin_fill(rotary_elems, 0.0f);
    gqa.AddInput<float>("cos_cache", {rotary_rows, rotary_cols}, cos_fill);
    gqa.AddInput<float>("sin_cache", {rotary_rows, rotary_cols}, sin_fill);
  }

  gqa.AddOptionalInputEdge<int64_t>();  // position_ids
  gqa.AddOptionalInputEdge<float>();    // attention_bias
  gqa.AddOptionalInputEdge<float>();    // head_sink

  // Expected outputs are zero-filled placeholders with the right shapes.
  {
    const std::vector<float> out_zeros(q_width, 0.0f);
    const std::vector<float> present_zeros(kv_cache_elems, 0.0f);
    gqa.AddOutput<float>("output", {1, 1, q_width}, out_zeros);
    gqa.AddOutput<float>("present_key", {1, kv_heads, kv_cache_len, head_dim}, present_zeros);
    gqa.AddOutput<float>("present_value", {1, kv_heads, kv_cache_len, head_dim}, present_zeros);
  }

  // Very loose tolerance: this is a shape acceptance test, not a numerical-correctness
  // test. Kept after the AddOutput calls, matching the original statement order.
  gqa.SetOutputTolerance(1e6f);

  // Run on the CPU execution provider only.
  std::vector<std::unique_ptr<IExecutionProvider>> cpu_only;
  cpu_only.push_back(DefaultCpuExecutionProvider());
  gqa.Run(OpTester::ExpectResult::kExpectSuccess, "",
          {}, nullptr, &cpu_only);
}

} // namespace test
} // namespace onnxruntime
Loading