diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
index 5698bcb659f20..5a4f6795865c0 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -144,6 +144,9 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
   T* k_rotary = packed_qkv ? nullptr : K.GetMutable<Tensor>()->MutableData<T>();
   if (do_rotary_) {
     ORT_ENFORCE(cos_cache != nullptr && sin_cache != nullptr, "cos_cache and sin_cache must be provided when do_rotary is true");
+    // Validation of seqlens_k against rotary cache size is now performed in CheckInputs()
+    // to ensure all execution providers (CPU, CUDA, etc.) get the protection in one place.
+
     // Initialize rotary parameters
     rotary_embedding_helper::RotaryParameters rotary_params = {};
     rotary_params.batch_size = batch_size;
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
index f5399e307fbca..bfb454dfc05cf 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -284,6 +284,22 @@ Status CheckInputs(const T* query,
   int rotary_dim = 0;
   if (cos_cache != nullptr && sin_cache != nullptr) {
     ORT_RETURN_IF_ERROR(CheckRotaryCaches(cos_cache, sin_cache, head_size, total_sequence_length, rotary_dim));
+
+    // Validate seqlens_k against rotary cache size when rotary embeddings are enabled.
+    // This prevents OOB access when deriving position IDs from seqlens_k during rotary embedding.
+    const bool is_seqlens_k_on_cpu = (seqlens_k->Location().device.Type() == OrtDevice::CPU);
+    if (is_seqlens_k_on_cpu) {
+      const int rotary_cache_max_seq = static_cast<int>(std::min(cos_cache->Shape().GetDims()[0],
+                                                                 sin_cache->Shape().GetDims()[0]));
+      const int32_t* seqlens_k_data = seqlens_k->template Data<int32_t>();
+      for (int b = 0; b < batch_size; b++) {
+        if (seqlens_k_data[b] < 0 || seqlens_k_data[b] >= rotary_cache_max_seq) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                                 "seqlens_k[", b, "] = ", seqlens_k_data[b],
+                                 " is out of range for rotary cache dimension 0 (", rotary_cache_max_seq, ")");
+        }
+      }
+    }
   } else if (cos_cache != nullptr || sin_cache != nullptr) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                            "Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
index 0690094031bb8..eb025683e1813 100644
--- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
@@ -307,5 +307,174 @@ TEST(GroupQueryAttentionTest, SeqlensKWrongLength) {
              {}, nullptr, &execution_providers);
 }
 
+// Regression: seqlens_k valid for KV cache but exceeding cos_cache.shape[0] must be rejected
+// when do_rotary is enabled. Without this check, the position ID derived from seqlens_k
+// would index out of bounds in the cos/sin cache, leaking heap memory into output.
+TEST(GroupQueryAttentionTest, SeqlensKExceedsCosCache_OOB) {
+  constexpr int num_heads = 1;
+  constexpr int kv_num_heads = 1;
+  constexpr int head_size = 16;  // must be multiple of 16 for rotary
+  constexpr int hidden_size = num_heads * head_size;
+  constexpr int kv_hidden_size = kv_num_heads * head_size;
+  constexpr int rotary_half_dim = head_size / 2;  // cos/sin cache dim-1 = 8
+
+  constexpr int cos_cache_max_seq = 4;  // small rotary cache
+  constexpr int past_seq_len = 16;      // large KV cache
+  constexpr int seqlens_k_val = 10;     // valid for KV (10 < 16) but OOB for cos (10 >= 4)
+  constexpr int total_seq_len = 4;      // passes CheckRotaryCaches (4 <= cos_cache_max_seq)
+
+  OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
+  tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
+  tester.AddAttribute<int64_t>("kv_num_heads", static_cast<int64_t>(kv_num_heads));
+  tester.AddAttribute<int64_t>("do_rotary", static_cast<int64_t>(1));
+
+  tester.AddInput<float>("query", {1, 1, hidden_size}, std::vector<float>(hidden_size, 1.0f));
+  tester.AddInput<float>("key", {1, 1, kv_hidden_size}, std::vector<float>(kv_hidden_size, 1.0f));
+  tester.AddInput<float>("value", {1, 1, kv_hidden_size}, std::vector<float>(kv_hidden_size, 1.0f));
+
+  // Past KV cache is large enough for seqlens_k=10
+  tester.AddInput<float>("past_key", {1, kv_num_heads, past_seq_len, head_size},
+                         std::vector<float>(kv_num_heads * past_seq_len * head_size, 0.5f));
+  tester.AddInput<float>("past_value", {1, kv_num_heads, past_seq_len, head_size},
+                         std::vector<float>(kv_num_heads * past_seq_len * head_size, 0.5f));
+
+  tester.AddInput<int32_t>("seqlens_k", {1}, {seqlens_k_val});
+  tester.AddInput<int32_t>("total_sequence_length", {1}, {total_seq_len});
+
+  // cos/sin cache with only 4 rows — seqlens_k=10 exceeds this
+  tester.AddInput<float>("cos_cache", {cos_cache_max_seq, rotary_half_dim},
+                         std::vector<float>(cos_cache_max_seq * rotary_half_dim, 1.0f));
+  tester.AddInput<float>("sin_cache", {cos_cache_max_seq, rotary_half_dim},
+                         std::vector<float>(cos_cache_max_seq * rotary_half_dim, 0.0f));
+
+  tester.AddOptionalInputEdge<int64_t>();  // position_ids
+  tester.AddOptionalInputEdge<float>();    // attention_bias
+  tester.AddOptionalInputEdge<float>();    // head_sink
+
+  tester.AddOutput<float>("output", {1, 1, hidden_size}, std::vector<float>(hidden_size, 0.0f));
+  tester.AddOutput<float>("present_key", {1, kv_num_heads, past_seq_len, head_size},
+                          std::vector<float>(kv_num_heads * past_seq_len * head_size, 0.0f));
+  tester.AddOutput<float>("present_value", {1, kv_num_heads, past_seq_len, head_size},
+                          std::vector<float>(kv_num_heads * past_seq_len * head_size, 0.0f));
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCpuExecutionProvider());
+  tester.Run(OpTester::ExpectResult::kExpectFailure, "is out of range for rotary cache dimension 0",
+             {}, nullptr, &execution_providers);
+}
+
+// Positive test: seqlens_k within cos/sin cache bounds with do_rotary enabled should succeed.
+TEST(GroupQueryAttentionTest, SeqlensKWithinCosCache_Rotary) {
+  constexpr int num_heads = 1;
+  constexpr int kv_num_heads = 1;
+  constexpr int head_size = 16;  // must be multiple of 16 for rotary
+  constexpr int hidden_size = num_heads * head_size;
+  constexpr int kv_hidden_size = kv_num_heads * head_size;
+  constexpr int rotary_half_dim = head_size / 2;
+
+  constexpr int cos_cache_max_seq = 16;  // rotary cache large enough
+  constexpr int past_seq_len = 16;
+  constexpr int seqlens_k_val = 3;  // valid: 3 < 16 (cos cache) and 3 < 16 (KV cache)
+  constexpr int total_seq_len = 4;  // seqlens_k + 1
+
+  OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
+  tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
+  tester.AddAttribute<int64_t>("kv_num_heads", static_cast<int64_t>(kv_num_heads));
+  tester.AddAttribute<int64_t>("do_rotary", static_cast<int64_t>(1));
+
+  tester.AddInput<float>("query", {1, 1, hidden_size}, std::vector<float>(hidden_size, 1.0f));
+  tester.AddInput<float>("key", {1, 1, kv_hidden_size}, std::vector<float>(kv_hidden_size, 1.0f));
+  tester.AddInput<float>("value", {1, 1, kv_hidden_size}, std::vector<float>(kv_hidden_size, 1.0f));
+
+  tester.AddInput<float>("past_key", {1, kv_num_heads, past_seq_len, head_size},
+                         std::vector<float>(kv_num_heads * past_seq_len * head_size, 0.5f));
+  tester.AddInput<float>("past_value", {1, kv_num_heads, past_seq_len, head_size},
+                         std::vector<float>(kv_num_heads * past_seq_len * head_size, 0.5f));
+
+  tester.AddInput<int32_t>("seqlens_k", {1}, {seqlens_k_val});
+  tester.AddInput<int32_t>("total_sequence_length", {1}, {total_seq_len});
+
+  tester.AddInput<float>("cos_cache", {cos_cache_max_seq, rotary_half_dim},
+                         std::vector<float>(cos_cache_max_seq * rotary_half_dim, 1.0f));
+  tester.AddInput<float>("sin_cache", {cos_cache_max_seq, rotary_half_dim},
+                         std::vector<float>(cos_cache_max_seq * rotary_half_dim, 0.0f));
+
+  tester.AddOptionalInputEdge<int64_t>();  // position_ids
+  tester.AddOptionalInputEdge<float>();    // attention_bias
+  tester.AddOptionalInputEdge<float>();    // head_sink
+
+  tester.AddOutput<float>("output", {1, 1, hidden_size}, std::vector<float>(hidden_size, 0.0f));
+  tester.AddOutput<float>("present_key", {1, kv_num_heads, past_seq_len, head_size},
+                          std::vector<float>(kv_num_heads * past_seq_len * head_size, 0.0f));
+  tester.AddOutput<float>("present_value", {1, kv_num_heads, past_seq_len, head_size},
+                          std::vector<float>(kv_num_heads * past_seq_len * head_size, 0.0f));
+
+  tester.SetOutputTolerance(1e6f);  // shape acceptance test, not numerical correctness
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCpuExecutionProvider());
+  tester.Run(OpTester::ExpectResult::kExpectSuccess, "",
+             {}, nullptr, &execution_providers);
+}
+
+// Multi-batch test: one valid and one OOB seqlens_k value.
+// Verifies the validation loop correctly identifies the offending batch index.
+TEST(GroupQueryAttentionTest, SeqlensKExceedsCosCache_MultiBatch) {
+  constexpr int num_heads = 1;
+  constexpr int kv_num_heads = 1;
+  constexpr int head_size = 16;
+  constexpr int hidden_size = num_heads * head_size;
+  constexpr int kv_hidden_size = kv_num_heads * head_size;
+  constexpr int rotary_half_dim = head_size / 2;
+
+  constexpr int cos_cache_max_seq = 4;
+  constexpr int past_seq_len = 16;
+  constexpr int total_seq_len = 4;
+  constexpr int batch_size = 2;
+
+  OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
+  tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
+  tester.AddAttribute<int64_t>("kv_num_heads", static_cast<int64_t>(kv_num_heads));
+  tester.AddAttribute<int64_t>("do_rotary", static_cast<int64_t>(1));
+
+  tester.AddInput<float>("query", {batch_size, 1, hidden_size},
+                         std::vector<float>(batch_size * hidden_size, 1.0f));
+  tester.AddInput<float>("key", {batch_size, 1, kv_hidden_size},
+                         std::vector<float>(batch_size * kv_hidden_size, 1.0f));
+  tester.AddInput<float>("value", {batch_size, 1, kv_hidden_size},
+                         std::vector<float>(batch_size * kv_hidden_size, 1.0f));
+
+  tester.AddInput<float>("past_key", {batch_size, kv_num_heads, past_seq_len, head_size},
+                         std::vector<float>(batch_size * kv_num_heads * past_seq_len * head_size, 0.5f));
+  tester.AddInput<float>("past_value", {batch_size, kv_num_heads, past_seq_len, head_size},
+                         std::vector<float>(batch_size * kv_num_heads * past_seq_len * head_size, 0.5f));
+
+  // seqlens_k: batch 0 is valid (3 < 4), batch 1 is OOB (10 >= 4)
+  tester.AddInput<int32_t>("seqlens_k", {batch_size}, {3, 10});
+  tester.AddInput<int32_t>("total_sequence_length", {1}, {total_seq_len});
+
+  tester.AddInput<float>("cos_cache", {cos_cache_max_seq, rotary_half_dim},
+                         std::vector<float>(cos_cache_max_seq * rotary_half_dim, 1.0f));
+  tester.AddInput<float>("sin_cache", {cos_cache_max_seq, rotary_half_dim},
+                         std::vector<float>(cos_cache_max_seq * rotary_half_dim, 0.0f));
+
+  tester.AddOptionalInputEdge<int64_t>();  // position_ids
+  tester.AddOptionalInputEdge<float>();    // attention_bias
+  tester.AddOptionalInputEdge<float>();    // head_sink
+
+  tester.AddOutput<float>("output", {batch_size, 1, hidden_size},
+                          std::vector<float>(batch_size * hidden_size, 0.0f));
+  tester.AddOutput<float>("present_key", {batch_size, kv_num_heads, past_seq_len, head_size},
+                          std::vector<float>(batch_size * kv_num_heads * past_seq_len * head_size, 0.0f));
+  tester.AddOutput<float>("present_value", {batch_size, kv_num_heads, past_seq_len, head_size},
+                          std::vector<float>(batch_size * kv_num_heads * past_seq_len * head_size, 0.0f));
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCpuExecutionProvider());
+  // Error should reference batch index 1: seqlens_k[1] = 10
+  tester.Run(OpTester::ExpectResult::kExpectFailure, "seqlens_k[1] = 10",
+             {}, nullptr, &execution_providers);
+}
+
 } // namespace test
 } // namespace onnxruntime