Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
T* k_rotary = packed_qkv ? nullptr : K.GetMutable<Tensor>()->MutableData<T>();
if (do_rotary_) {
ORT_ENFORCE(cos_cache != nullptr && sin_cache != nullptr, "cos_cache and sin_cache must be provided when do_rotary is true");
// Validation of seqlens_k against rotary cache size is now performed in CheckInputs()
// to ensure all execution providers (CPU, CUDA, etc.) get the protection in one place.

// Initialize rotary parameters
rotary_embedding_helper::RotaryParameters rotary_params = {};
rotary_params.batch_size = batch_size;
Expand Down
16 changes: 16 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,22 @@ Status CheckInputs(const T* query,
int rotary_dim = 0;
if (cos_cache != nullptr && sin_cache != nullptr) {
ORT_RETURN_IF_ERROR(CheckRotaryCaches(cos_cache, sin_cache, head_size, total_sequence_length, rotary_dim));

// Validate seqlens_k against rotary cache size when rotary embeddings are enabled.
// This prevents OOB access when deriving position IDs from seqlens_k during rotary embedding.
const bool is_seqlens_k_on_cpu = (seqlens_k->Location().device.Type() == OrtDevice::CPU);
if (is_seqlens_k_on_cpu) {
const int rotary_cache_max_seq = static_cast<int>(std::min(cos_cache->Shape().GetDims()[0],
sin_cache->Shape().GetDims()[0]));
const int32_t* seqlens_k_data = seqlens_k->template Data<int32_t>();
for (int b = 0; b < batch_size; b++) {
if (seqlens_k_data[b] < 0 || seqlens_k_data[b] >= rotary_cache_max_seq) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"seqlens_k[", b, "] = ", seqlens_k_data[b],
" is out of range for rotary cache dimension 0 (", rotary_cache_max_seq, ")");
}
}
}
} else if (cos_cache != nullptr || sin_cache != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
Expand Down
169 changes: 169 additions & 0 deletions onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,5 +307,174 @@ TEST(GroupQueryAttentionTest, SeqlensKWrongLength) {
{}, nullptr, &execution_providers);
}

// Regression test: a seqlens_k entry that fits inside the KV cache but is
// larger than cos_cache.shape[0] must be rejected when do_rotary is enabled.
// Without the check, the position id derived from seqlens_k would index past
// the end of the cos/sin cache, leaking heap memory into the output.
TEST(GroupQueryAttentionTest, SeqlensKExceedsCosCache_OOB) {
  constexpr int kNumHeads = 1;
  constexpr int kKvNumHeads = 1;
  constexpr int kHeadSize = 16;  // rotary requires head_size to be a multiple of 16
  constexpr int kHiddenSize = kNumHeads * kHeadSize;
  constexpr int kKvHiddenSize = kKvNumHeads * kHeadSize;
  constexpr int kRotaryHalfDim = kHeadSize / 2;  // cos/sin cache dim-1 = 8

  constexpr int kCosCacheMaxSeq = 4;  // deliberately small rotary cache
  constexpr int kPastSeqLen = 16;     // KV cache is comfortably larger
  constexpr int kSeqlensK = 10;       // 10 < 16: KV ok; 10 >= 4: rotary cache OOB
  constexpr int kTotalSeqLen = 4;     // satisfies CheckRotaryCaches (4 <= kCosCacheMaxSeq)

  // Small helper to build constant-filled tensors without repeating sizes.
  const auto filled = [](size_t count, float value) {
    return std::vector<float>(count, value);
  };

  OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
  tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(kNumHeads));
  tester.AddAttribute<int64_t>("kv_num_heads", static_cast<int64_t>(kKvNumHeads));
  tester.AddAttribute<int64_t>("do_rotary", static_cast<int64_t>(1));

  tester.AddInput<float>("query", {1, 1, kHiddenSize}, filled(kHiddenSize, 1.0f));
  tester.AddInput<float>("key", {1, 1, kKvHiddenSize}, filled(kKvHiddenSize, 1.0f));
  tester.AddInput<float>("value", {1, 1, kKvHiddenSize}, filled(kKvHiddenSize, 1.0f));

  // Past KV cache is large enough to make seqlens_k=10 legal on its own.
  const size_t kv_elems = static_cast<size_t>(kKvNumHeads) * kPastSeqLen * kHeadSize;
  tester.AddInput<float>("past_key", {1, kKvNumHeads, kPastSeqLen, kHeadSize},
                         filled(kv_elems, 0.5f));
  tester.AddInput<float>("past_value", {1, kKvNumHeads, kPastSeqLen, kHeadSize},
                         filled(kv_elems, 0.5f));

  tester.AddInput<int32_t>("seqlens_k", {1}, {kSeqlensK});
  tester.AddInput<int32_t>("total_sequence_length", {1}, {kTotalSeqLen});

  // Only 4 rows of rotary table — seqlens_k=10 points beyond it.
  const size_t rotary_elems = static_cast<size_t>(kCosCacheMaxSeq) * kRotaryHalfDim;
  tester.AddInput<float>("cos_cache", {kCosCacheMaxSeq, kRotaryHalfDim},
                         filled(rotary_elems, 1.0f));
  tester.AddInput<float>("sin_cache", {kCosCacheMaxSeq, kRotaryHalfDim},
                         filled(rotary_elems, 0.0f));

  tester.AddOptionalInputEdge<int64_t>();  // position_ids
  tester.AddOptionalInputEdge<float>();    // attention_bias
  tester.AddOptionalInputEdge<float>();    // head_sink

  tester.AddOutput<float>("output", {1, 1, kHiddenSize}, filled(kHiddenSize, 0.0f));
  tester.AddOutput<float>("present_key", {1, kKvNumHeads, kPastSeqLen, kHeadSize},
                          filled(kv_elems, 0.0f));
  tester.AddOutput<float>("present_value", {1, kKvNumHeads, kPastSeqLen, kHeadSize},
                          filled(kv_elems, 0.0f));

  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
  execution_providers.push_back(DefaultCpuExecutionProvider());
  tester.Run(OpTester::ExpectResult::kExpectFailure, "is out of range for rotary cache dimension 0",
             {}, nullptr, &execution_providers);
}

// Positive test: with do_rotary enabled, a seqlens_k that stays within the
// cos/sin cache bounds must be accepted and the op must run to completion.
TEST(GroupQueryAttentionTest, SeqlensKWithinCosCache_Rotary) {
  constexpr int kNumHeads = 1;
  constexpr int kKvNumHeads = 1;
  constexpr int kHeadSize = 16;  // rotary requires head_size to be a multiple of 16
  constexpr int kHiddenSize = kNumHeads * kHeadSize;
  constexpr int kKvHiddenSize = kKvNumHeads * kHeadSize;
  constexpr int kRotaryHalfDim = kHeadSize / 2;

  constexpr int kCosCacheMaxSeq = 16;  // rotary cache large enough for the request
  constexpr int kPastSeqLen = 16;
  constexpr int kSeqlensK = 3;    // within both caches: 3 < 16 and 3 < 16
  constexpr int kTotalSeqLen = 4; // seqlens_k + 1

  const auto filled = [](size_t count, float value) {
    return std::vector<float>(count, value);
  };

  OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
  tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(kNumHeads));
  tester.AddAttribute<int64_t>("kv_num_heads", static_cast<int64_t>(kKvNumHeads));
  tester.AddAttribute<int64_t>("do_rotary", static_cast<int64_t>(1));

  tester.AddInput<float>("query", {1, 1, kHiddenSize}, filled(kHiddenSize, 1.0f));
  tester.AddInput<float>("key", {1, 1, kKvHiddenSize}, filled(kKvHiddenSize, 1.0f));
  tester.AddInput<float>("value", {1, 1, kKvHiddenSize}, filled(kKvHiddenSize, 1.0f));

  const size_t kv_elems = static_cast<size_t>(kKvNumHeads) * kPastSeqLen * kHeadSize;
  tester.AddInput<float>("past_key", {1, kKvNumHeads, kPastSeqLen, kHeadSize},
                         filled(kv_elems, 0.5f));
  tester.AddInput<float>("past_value", {1, kKvNumHeads, kPastSeqLen, kHeadSize},
                         filled(kv_elems, 0.5f));

  tester.AddInput<int32_t>("seqlens_k", {1}, {kSeqlensK});
  tester.AddInput<int32_t>("total_sequence_length", {1}, {kTotalSeqLen});

  const size_t rotary_elems = static_cast<size_t>(kCosCacheMaxSeq) * kRotaryHalfDim;
  tester.AddInput<float>("cos_cache", {kCosCacheMaxSeq, kRotaryHalfDim},
                         filled(rotary_elems, 1.0f));
  tester.AddInput<float>("sin_cache", {kCosCacheMaxSeq, kRotaryHalfDim},
                         filled(rotary_elems, 0.0f));

  tester.AddOptionalInputEdge<int64_t>();  // position_ids
  tester.AddOptionalInputEdge<float>();    // attention_bias
  tester.AddOptionalInputEdge<float>();    // head_sink

  tester.AddOutput<float>("output", {1, 1, kHiddenSize}, filled(kHiddenSize, 0.0f));
  tester.AddOutput<float>("present_key", {1, kKvNumHeads, kPastSeqLen, kHeadSize},
                          filled(kv_elems, 0.0f));
  tester.AddOutput<float>("present_value", {1, kKvNumHeads, kPastSeqLen, kHeadSize},
                          filled(kv_elems, 0.0f));

  // This test only checks that the inputs are accepted and shapes line up;
  // numerical correctness is covered elsewhere, hence the huge tolerance.
  tester.SetOutputTolerance(1e6f);

  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
  execution_providers.push_back(DefaultCpuExecutionProvider());
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "",
             {}, nullptr, &execution_providers);
}

// Multi-batch test: batch 0 carries a valid seqlens_k and batch 1 an OOB one.
// Verifies that the validation loop pinpoints the offending batch index in
// the error message rather than rejecting (or accepting) the whole batch.
TEST(GroupQueryAttentionTest, SeqlensKExceedsCosCache_MultiBatch) {
  constexpr int kNumHeads = 1;
  constexpr int kKvNumHeads = 1;
  constexpr int kHeadSize = 16;
  constexpr int kHiddenSize = kNumHeads * kHeadSize;
  constexpr int kKvHiddenSize = kKvNumHeads * kHeadSize;
  constexpr int kRotaryHalfDim = kHeadSize / 2;

  constexpr int kCosCacheMaxSeq = 4;
  constexpr int kPastSeqLen = 16;
  constexpr int kTotalSeqLen = 4;
  constexpr int kBatchSize = 2;

  const auto filled = [](size_t count, float value) {
    return std::vector<float>(count, value);
  };

  OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
  tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(kNumHeads));
  tester.AddAttribute<int64_t>("kv_num_heads", static_cast<int64_t>(kKvNumHeads));
  tester.AddAttribute<int64_t>("do_rotary", static_cast<int64_t>(1));

  tester.AddInput<float>("query", {kBatchSize, 1, kHiddenSize},
                         filled(static_cast<size_t>(kBatchSize) * kHiddenSize, 1.0f));
  tester.AddInput<float>("key", {kBatchSize, 1, kKvHiddenSize},
                         filled(static_cast<size_t>(kBatchSize) * kKvHiddenSize, 1.0f));
  tester.AddInput<float>("value", {kBatchSize, 1, kKvHiddenSize},
                         filled(static_cast<size_t>(kBatchSize) * kKvHiddenSize, 1.0f));

  const size_t kv_elems =
      static_cast<size_t>(kBatchSize) * kKvNumHeads * kPastSeqLen * kHeadSize;
  tester.AddInput<float>("past_key", {kBatchSize, kKvNumHeads, kPastSeqLen, kHeadSize},
                         filled(kv_elems, 0.5f));
  tester.AddInput<float>("past_value", {kBatchSize, kKvNumHeads, kPastSeqLen, kHeadSize},
                         filled(kv_elems, 0.5f));

  // seqlens_k: batch 0 is valid (3 < 4), batch 1 is OOB (10 >= 4)
  tester.AddInput<int32_t>("seqlens_k", {kBatchSize}, std::vector<int32_t>{3, 10});
  tester.AddInput<int32_t>("total_sequence_length", {1}, {kTotalSeqLen});

  const size_t rotary_elems = static_cast<size_t>(kCosCacheMaxSeq) * kRotaryHalfDim;
  tester.AddInput<float>("cos_cache", {kCosCacheMaxSeq, kRotaryHalfDim},
                         filled(rotary_elems, 1.0f));
  tester.AddInput<float>("sin_cache", {kCosCacheMaxSeq, kRotaryHalfDim},
                         filled(rotary_elems, 0.0f));

  tester.AddOptionalInputEdge<int64_t>();  // position_ids
  tester.AddOptionalInputEdge<float>();    // attention_bias
  tester.AddOptionalInputEdge<float>();    // head_sink

  tester.AddOutput<float>("output", {kBatchSize, 1, kHiddenSize},
                          filled(static_cast<size_t>(kBatchSize) * kHiddenSize, 0.0f));
  tester.AddOutput<float>("present_key", {kBatchSize, kKvNumHeads, kPastSeqLen, kHeadSize},
                          filled(kv_elems, 0.0f));
  tester.AddOutput<float>("present_value", {kBatchSize, kKvNumHeads, kPastSeqLen, kHeadSize},
                          filled(kv_elems, 0.0f));

  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
  execution_providers.push_back(DefaultCpuExecutionProvider());
  // The diagnostic must call out batch index 1: seqlens_k[1] = 10
  tester.Run(OpTester::ExpectResult::kExpectFailure, "seqlens_k[1] = 10",
             {}, nullptr, &execution_providers);
}

} // namespace test
} // namespace onnxruntime
Loading