diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index 3b6ad238cc826..8fddfe5adfc9a 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -160,13 +160,15 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
   bool use_memory_efficient_attention =
       !use_flash_attention &&
       !disable_memory_efficient_attention_ &&
-      local_window_size_ == -1 &&
       (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
       has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.head_size);
   if (!use_flash_attention && !use_memory_efficient_attention && local_window_size_ != -1) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Local attention UNSUPPORTED for sm < 80 on CUDA.");
   }
+  if (use_memory_efficient_attention && local_window_size_ != -1) {
+    parameters.local_window_size = -1;
+  }
 
   // allocate buffers
   size_t kv_buffer_bytes = 0;
   // need a buffer if we must ungroup kv
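
For context, here is a minimal standalone sketch of the control flow this hunk produces. The names `Params` and `ResolveLocalWindow` are illustrative stand-ins, not ORT's actual types, and the exception stands in for returning an ORT status; the sketch only mirrors the two conditionals touched by the patch.

  #include <stdexcept>

  struct Params {
    int local_window_size = -1;  // -1 means full (non-local) attention
  };

  // Before this patch, a non-default local window disqualified the memory
  // efficient kernel outright (local_window_size_ == -1 was part of its
  // eligibility check). After the patch, that kernel may still be selected;
  // since it does not implement sliding-window attention, the window is
  // cleared so the kernel runs full attention instead. The INVALID_ARGUMENT
  // error is raised only when neither flash attention nor the memory
  // efficient kernel is available.
  void ResolveLocalWindow(bool use_flash_attention, bool use_memory_efficient_attention,
                          int local_window_size, Params& parameters) {
    if (!use_flash_attention && !use_memory_efficient_attention && local_window_size != -1) {
      throw std::invalid_argument("Local attention UNSUPPORTED for sm < 80 on CUDA.");
    }
    if (use_memory_efficient_attention && local_window_size != -1) {
      parameters.local_window_size = -1;  // fall back to full attention
    }
  }

Note that clearing the window is a silent fallback: on devices where only the memory efficient path is available, a requested local window is dropped rather than rejected, so attention is computed over the full sequence.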