Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,19 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
T* k_rotary = packed_qkv ? nullptr : K.GetMutable<Tensor>()->MutableData<T>();
if (do_rotary_) {
ORT_ENFORCE(cos_cache != nullptr && sin_cache != nullptr, "cos_cache and sin_cache must be provided when do_rotary is true");
// Validate seqlens_k values against cos_cache size to prevent OOB in rotary embedding lookup.
{
const int cos_cache_max_seq = static_cast<int>(cos_cache->Shape().GetDims()[0]);
Comment thread
vraspar marked this conversation as resolved.
Outdated
const int32_t* seqlens_k_data = seqlens_k->Data<int32_t>();
for (int b = 0; b < batch_size; b++) {
// During token generation the rotary position id is seqlens_k[b], so it must be a valid row index into cos_cache (i.e. < number of rows).
if (seqlens_k_data[b] >= cos_cache_max_seq) {
Comment thread
apsonawane marked this conversation as resolved.
Outdated
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"seqlens_k[", b, "] = ", seqlens_k_data[b],
" exceeds cos_cache dimension 0 (", cos_cache_max_seq, ")");
Comment thread
apsonawane marked this conversation as resolved.
Outdated
}
}
}
Comment thread
apsonawane marked this conversation as resolved.
Outdated
// Initialize rotary parameters
rotary_embedding_helper::RotaryParameters rotary_params = {};
rotary_params.batch_size = batch_size;
Expand Down
Loading