8 changes: 4 additions & 4 deletions onnxruntime/core/providers/cpu/llm/attention.cc
@@ -86,13 +86,13 @@ template <>
 inline void ComputeAttentionSoftmaxInplace<MLFloat16>(MLFloat16* score, size_t N, size_t D, ThreadPool* tp, AllocatorPtr allocator) {
   ORT_ENFORCE(tp == nullptr, "No parallelized version of softmax for float16.");
   // MLAS lacks kernels for fp16 softmax, so we convert to float32 and call the float32 version.
-  const auto num_elements = SafeInt<size_t>(N) * D;
-  void* allocated_ptr = allocator->Alloc(num_elements * sizeof(float));
+  const size_t buffer_bytes = detail::Fp16SoftmaxTempBufferBytes(N, D);
+  void* allocated_ptr = allocator->Alloc(buffer_bytes);
   BufferUniquePtr float_buffer(allocated_ptr, BufferDeleter(allocator));
   float* ptr = reinterpret_cast<float*>(allocated_ptr);
-  MlasConvertHalfToFloatBuffer(score, ptr, num_elements);
+  MlasConvertHalfToFloatBuffer(score, ptr, N * D);
   MlasComputeSoftmax(ptr, ptr, N, D, false, false, 0.0f, tp);
-  MlasConvertFloatToHalfBuffer(ptr, score, num_elements);
+  MlasConvertFloatToHalfBuffer(ptr, score, N * D);
 }

 template <typename T>
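The hunk above swaps the inline size computation for the detail::Fp16SoftmaxTempBufferBytes helper introduced in attention.h below. To see the failure mode the helper guards against in isolation, here is a minimal standalone sketch (illustrative only, not part of this PR) of how a 32-bit product of N and D goes wrong, and how widening before multiplying avoids it:

// overflow_sketch.cc -- standalone illustration, not ORT code.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  constexpr int64_t n = 46341;  // 46341^2 = 2'147'488'281 > INT32_MAX
  static_assert(n * n > std::numeric_limits<int32_t>::max(), "overflow scenario");

  // A 32-bit int cannot hold the product: reinterpreted as signed int32
  // it comes out negative (the unsigned-to-signed conversion is modular).
  const auto as_int32 = static_cast<int32_t>(static_cast<uint32_t>(n * n));
  std::cout << "32-bit product:     " << as_int32 << "\n";  // -2147479015

  // Widening to size_t (64-bit here) before multiplying gives the real byte count.
  const size_t bytes = static_cast<size_t>(n) * static_cast<size_t>(n) * sizeof(float);
  std::cout << "correct byte count: " << bytes << "\n";  // 8589953124 (~8 GiB)
  return 0;
}

A negative or wrapped size handed to an allocator is exactly the kind of bug the SafeInt-based helper turns into a thrown exception instead.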
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/cpu/llm/attention.h
@@ -3,13 +3,25 @@

 #pragma once
 #include "core/common/common.h"
+#include "core/common/safeint.h"
 #include "core/framework/op_kernel.h"
 #include "core/platform/threadpool.h"
 #include "core/providers/cpu/llm/attention_parameters.h"
 #include "core/providers/cpu/mlas_backend_kernel_selector_config_utils.h"

 namespace onnxruntime {

+namespace detail {
+
+// Returns the number of bytes needed for the FP16 softmax temporary float buffer
+// (used when converting FP16 scores to float32 before calling MlasComputeSoftmax).
+// Uses SafeInt<size_t> to prevent integer overflow for large values of N*D.
+inline size_t Fp16SoftmaxTempBufferBytes(size_t n, size_t d) {
+  return SafeInt<size_t>(n) * d * sizeof(float);
+}
+
+}  // namespace detail
+
 // This value is used to mask out a value from the input as ``Softmax(-infinity, ...) = 0``.
 // If the mask is added, -infinity + x = -infinity.
 // infinity is replaced by lowest() because softmax implemented in MLAS
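For context on how the SafeInt arithmetic in this helper behaves, here is a standalone sketch (illustrative only) using the public SafeInt header that core/common/safeint.h builds on; note that inside ONNX Runtime the overflow surfaces as an OnnxRuntimeException, as the test changes below expect, rather than the library's default SafeIntException:

// safeint_sketch.cc -- standalone illustration using the single-header
// SafeInt library (https://github.com/dcleblanc/SafeInt), not ORT code.
#include <cstdint>
#include <iostream>
#include "SafeInt.hpp"

int main() {
  // SafeInt<uint32_t> stands in for size_t on a 32-bit build.
  try {
    const uint32_t bytes = SafeInt<uint32_t>(46341) * 46341 * sizeof(float);
    std::cout << bytes << "\n";  // never reached: 8'589'953'124 > UINT32_MAX
  } catch (const SafeIntException&) {
    std::cout << "overflow detected instead of an undersized allocation\n";
  }

  // With 64-bit arithmetic the same expression succeeds.
  const uint64_t ok = SafeInt<uint64_t>(46341) * 46341 * sizeof(float);
  std::cout << ok << "\n";  // 8589953124
  return 0;
}

The design point of extracting Fp16SoftmaxTempBufferBytes into a named helper is that this checked computation becomes unit-testable without ever allocating the multi-gigabyte buffer it describes.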
79 changes: 22 additions & 57 deletions onnxruntime/test/providers/cpu/llm/attention_op_test.cc
@@ -5,10 +5,10 @@
 #include <limits>
 #include "gtest/gtest.h"
 #include "core/session/onnxruntime_cxx_api.h"
+#include "core/providers/cpu/llm/attention.h"
 #include "test/common/tensor_op_test_utils.h"
 #include "test/common/cuda_op_test_utils.h"
 #include "test/providers/provider_test_utils.h"
-#include "test/util/include/system_info.h"

 namespace onnxruntime {
 namespace test {
@@ -2408,66 +2408,31 @@ TEST(AttentionTest, Attention_NonPadKVSeqLen_WithFloatAttnMask_MultiBatch) {
 }

 // Regression test for CPU kernel integer overflow in FP16 softmax allocation.
-// ComputeAttentionSoftmaxInplace<MLFloat16> previously used int for N and D.
-// For large enough values of N and D, N * D could overflow int32.
-TEST(AttentionTest, AttentionCpuFp16SoftmaxLargeDimensions) {
-  // Skip if the machine has less than 16GB of physical RAM.
-  constexpr uint64_t required_ram_bytes = 16ULL * 1024 * 1024 * 1024;
-  if (const auto total_ram_bytes = GetTotalPhysicalMemoryBytes();
-      total_ram_bytes.has_value() && *total_ram_bytes < required_ram_bytes) {
-    GTEST_SKIP() << "Skipping: test requires >= 16GB RAM, machine has "
-                 << (*total_ram_bytes / (1024 * 1024)) << "MB";
-  }
-
-  constexpr int batch_size = 1;
-  constexpr int num_heads = 1;
-  constexpr int q_sequence_length = 46341;
-  constexpr int kv_sequence_length = 46341;
-  constexpr int head_size = 1;
-
-  // Verify at compile time that these dimensions trigger the overflow scenario.
-  static_assert(static_cast<int64_t>(q_sequence_length) * kv_sequence_length >
+// ComputeAttentionSoftmaxInplace<MLFloat16> previously used int for N and D,
+// so N * D could overflow int32 for large sequence lengths.
+// The fix uses detail::Fp16SoftmaxTempBufferBytes (see attention.h) which
+// applies SafeInt<size_t> to prevent overflow. This test exercises that helper
+// directly, so no large tensor allocations or op execution are required.
+TEST(AttentionTest, AttentionCpuFp16SoftmaxBufferSizeNoOverflow) {
+  // N*D exceeds INT_MAX (46341^2 > 2^31-1), which is the same scenario that
+  // previously caused int32 overflow in ComputeAttentionSoftmaxInplace<MLFloat16>.
+  constexpr size_t N = 46341;  // sqrt(INT_MAX) ~ 46340, so N*D > INT_MAX
+  constexpr size_t D = 46341;
+
+  static_assert(static_cast<int64_t>(N) * static_cast<int64_t>(D) >
                 static_cast<int64_t>(std::numeric_limits<int>::max()),
                 "Test dimensions must cause int32 overflow in N*D");

-  OpTester test("Attention", 23, onnxruntime::kOnnxDomain);
-
-  // 4D BNSH inputs
-  std::vector<int64_t> q_shape = {batch_size, num_heads, q_sequence_length, head_size};
-  std::vector<int64_t> k_shape = {batch_size, num_heads, kv_sequence_length, head_size};
-  std::vector<int64_t> v_shape = {batch_size, num_heads, kv_sequence_length, head_size};
-
-  constexpr int q_elements = batch_size * num_heads * q_sequence_length * head_size;
-  constexpr int kv_elements = batch_size * num_heads * kv_sequence_length * head_size;
-
-  // All-zero Q and K → all attention scores are 0, softmax produces uniform 1/kv_seq.
-  // All-one V → output is also all 1.0 (weighted average of 1s).
-  std::vector<float> q_data(q_elements, 0.0f);
-  std::vector<float> k_data(kv_elements, 0.0f);
-  std::vector<float> v_data(kv_elements, 1.0f);
-
-  test.AddInput<MLFloat16>("Q", q_shape, ToFloat16(q_data));
-  test.AddInput<MLFloat16>("K", k_shape, ToFloat16(k_data));
-  test.AddInput<MLFloat16>("V", v_shape, ToFloat16(v_data));
-  test.AddOptionalInputEdge<bool>();       // attn_mask
-  test.AddOptionalInputEdge<MLFloat16>();  // past_key
-  test.AddOptionalInputEdge<MLFloat16>();  // past_value
-
-  // Expected output: all 1.0 (uniform attention over all-ones V).
-  std::vector<int64_t> y_shape = {batch_size, num_heads, q_sequence_length, head_size};
-  std::vector<float> expected_y(q_elements, 1.0f);
-  test.AddOutput<MLFloat16>("Y", y_shape, ToFloat16(expected_y), false, 0, 3e-2f);
-  test.AddOptionalOutputEdge<MLFloat16>();  // present_key
-  test.AddOptionalOutputEdge<MLFloat16>();  // present_value
-
-  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-  execution_providers.push_back(DefaultCpuExecutionProvider());
-
-  if constexpr (sizeof(void*) == 4) {
-    // Expect overflow for 32-bit builds.
-    test.Run(OpTester::ExpectResult::kExpectFailure, "Integer overflow", {}, nullptr, &execution_providers);
+  // sizeof(void*) == 8 on all common 64-bit platforms (LP64 and LLP64/Windows).
+  // On 32-bit platforms (sizeof(void*) == 4), size_t is also 32-bit, so
+  // N*D*sizeof(float) overflows and SafeInt throws.
+  if constexpr (sizeof(void*) >= 8) {
+    // On 64-bit builds the computation must succeed and equal N * D * sizeof(float).
+    const size_t bytes = detail::Fp16SoftmaxTempBufferBytes(N, D);
+    EXPECT_EQ(bytes, N * D * sizeof(float));
   } else {
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+    // On 32-bit builds the multiplication overflows size_t; SafeInt must throw.
+    EXPECT_THROW(detail::Fp16SoftmaxTempBufferBytes(N, D), OnnxRuntimeException);
   }
 }

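As a final sanity check on the constants used in the test, a few lines of standalone arithmetic (a hypothetical snippet, not part of the PR) confirm both the overflow threshold and the 64-bit byte count the test asserts:

// constants_check.cc -- standalone check of the test's numbers.
#include <cstddef>

int main() {
  constexpr size_t N = 46341, D = 46341;
  static_assert(sizeof(size_t) == 8, "this check assumes a 64-bit build");
  static_assert(N * D == 2'147'488'281ULL, "just past INT32_MAX (2'147'483'647)");
  constexpr size_t bytes = N * D * sizeof(float);
  static_assert(bytes == 8'589'953'124ULL, "~8 GiB of float32 scratch");
  return 0;
}

That ~8 GiB figure is also why the deleted end-to-end test needed a 16GB-RAM guard; exercising the helper directly removes that requirement entirely.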