diff --git a/onnxruntime/core/providers/cpu/nn/lrn.cc b/onnxruntime/core/providers/cpu/nn/lrn.cc
index 218f083330312..e5c9acdb7eb96 100644
--- a/onnxruntime/core/providers/cpu/nn/lrn.cc
+++ b/onnxruntime/core/providers/cpu/nn/lrn.cc
@@ -18,14 +18,11 @@
 #include "core/providers/cpu/nn/lrn.h"
 
 #include "core/providers/cpu/element_wise_ranged_transform.h"
+#include "core/common/narrow.h"
 #include "core/common/safeint.h"
 #include "core/util/math.h"
 #include "core/util/math_cpuonly.h"
-// TODO: fix the warnings
-#if defined(_MSC_VER) && !defined(__clang__)
-// Chance of arithmetic overflow could be reduced
-#pragma warning(disable : 26451)
-#endif
+
 namespace onnxruntime {
 namespace functors {
@@ -49,36 +46,47 @@ struct Powx {
   }
 };
 }  // namespace functors
+
 template <>
 Status LRN<float>::Compute(OpKernelContext* context) const {
-  const auto* X = context->Input<Tensor>(0);
-  if (X == nullptr)
-    return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
+  const auto& X = context->RequiredInput<Tensor>(0);
+  const auto& X_shape = X.Shape();
 
-  Tensor* Y = context->Output(0, X->Shape());
+  Tensor* Y = context->Output(0, X_shape);
 
   // Supports NCHW image format.
-  ORT_ENFORCE(X->Shape().NumDimensions() == 4);
-  const int N = gsl::narrow_cast<int>(X->Shape()[0]);
-  const int C = gsl::narrow_cast<int>(X->Shape()[1]);
-  const int H = gsl::narrow_cast<int>(X->Shape()[2]);
-  const int W = gsl::narrow_cast<int>(X->Shape()[3]);
-  const int image_size = C * H * W;
-  const int pre_pad = (size_ - 1) / 2;
-
-  const auto* Xdata = X->Data<float>();
+
+  ORT_ENFORCE(X_shape.NumDimensions() == 4);
+  const ptrdiff_t N = narrow<ptrdiff_t>(X_shape[0]);
+  const ptrdiff_t C = narrow<ptrdiff_t>(X_shape[1]);
+  const ptrdiff_t H = narrow<ptrdiff_t>(X_shape[2]);
+  const ptrdiff_t W = narrow<ptrdiff_t>(X_shape[3]);
+
+  const ptrdiff_t X_size = narrow<ptrdiff_t>(X_shape.Size());
+
+  if (X_size == 0) {
+    // Nothing to compute.
+    return Status::OK();
+  }
+
+  // Note: `ptrdiff_t X_size` being set successfully implies that N*C*H*W will not overflow ptrdiff_t.
+
+  const ptrdiff_t image_size = C * H * W;
+  const ptrdiff_t pre_pad = (size_ - 1) / 2;
+  const int H_times_W = SafeInt<int>(H) * W;  // H_times_W is passed to math::Axpy() which takes an int.
+
+  const auto* Xdata = X.Data<float>();
   auto* Ydata = Y->MutableData<float>();
 
   AllocatorPtr alloc;
   ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
-  const int Xsize = gsl::narrow_cast<int>(X->Shape().Size());
-  auto sdata = alloc->Alloc(SafeInt<size_t>(sizeof(float)) * Xsize);
+  void* sdata = alloc->Alloc(SafeInt<size_t>(sizeof(float)) * X_size);
   BufferUniquePtr scale_buffer(sdata, BufferDeleter(alloc));
   auto* scale_data = static_cast<float*>(scale_buffer.get());
-  math::Set<float, CPUMathUtil>(Xsize, bias_, scale_data, &CPUMathUtil::Instance());
+  math::Set<float, CPUMathUtil>(X_size, bias_, scale_data, &CPUMathUtil::Instance());
 
-  const size_t padded_square_size = (static_cast<size_t>(C) + size_ - 1) * H * W;
+  const ptrdiff_t padded_square_size = (SafeInt<ptrdiff_t>(C) + size_ - 1) * H * W;
   auto psdata = alloc->Alloc(SafeInt<size_t>(sizeof(float)) * padded_square_size);
   BufferUniquePtr padded_square_buffer(psdata, BufferDeleter(std::move(alloc)));
   auto* padded_square_data = static_cast<float*>(padded_square_buffer.get());
@@ -86,28 +94,39 @@ Status LRN<float>::Compute(OpKernelContext* context) const {
   const float alpha_over_size = alpha_ / size_;
 
   // go through the images
-  for (int n = 0; n < N; ++n) {
+  for (ptrdiff_t n = 0; n < N; ++n) {
+    const ptrdiff_t n_times_image_size = n * image_size;
+
     // compute the padded square
-    math::Sqr<float, CPUMathUtil>(image_size, Xdata + image_size * n, padded_square_data + pre_pad * H * W,
-                                  &CPUMathUtil::Instance());
+    {
+      const ptrdiff_t padded_square_data_offset = SafeInt<ptrdiff_t>(pre_pad) * H_times_W;
+      math::Sqr<float, CPUMathUtil>(image_size, Xdata + n_times_image_size, padded_square_data + padded_square_data_offset,
+                                    &CPUMathUtil::Instance());
+    }
     // Create the first channel scale
-    for (int c = 0; c < size_; ++c) {
-      math::Axpy<float, CPUMathUtil>(H * W, alpha_over_size, padded_square_data + c * H * W,
-                                     scale_data + image_size * n, &CPUMathUtil::Instance());
+    for (ptrdiff_t c = 0; c < size_; ++c) {
+      const ptrdiff_t padded_square_data_offset = c * H_times_W;
+      math::Axpy<float, CPUMathUtil>(H_times_W, alpha_over_size, padded_square_data + padded_square_data_offset,
+                                     scale_data + n_times_image_size, &CPUMathUtil::Instance());
     }
 
-    for (int c = 1; c < C; ++c) {
-      float* this_scale_slice = scale_data + n * image_size + c * H * W;
+    for (ptrdiff_t c = 1; c < C; ++c) {
+      const ptrdiff_t this_scale_offset = n * image_size + c * H_times_W;
+
+      float* this_scale_slice = scale_data + this_scale_offset;
       // copy previous scale
-      memcpy(this_scale_slice, this_scale_slice - H * W, H * W * sizeof(float));
+      memcpy(this_scale_slice, this_scale_slice - H_times_W, SafeInt<size_t>(H_times_W) * sizeof(float));
       // add head
-      math::Axpy<float, CPUMathUtil>(H * W, alpha_over_size, padded_square_data + (c + size_ - 1) * H * W,
+      const ptrdiff_t padded_square_data_head_offset = (SafeInt<ptrdiff_t>(c) + size_ - 1) * H_times_W;
+      math::Axpy<float, CPUMathUtil>(H_times_W, alpha_over_size, padded_square_data + padded_square_data_head_offset,
                                      this_scale_slice, &CPUMathUtil::Instance());
       // subtract tail
-      math::Axpy<float, CPUMathUtil>(H * W, -alpha_over_size, padded_square_data + (c - 1) * H * W, this_scale_slice,
-                                     &CPUMathUtil::Instance());
+      const ptrdiff_t padded_square_data_tail_offset = (c - 1) * H_times_W;
+      math::Axpy<float, CPUMathUtil>(H_times_W, -alpha_over_size, padded_square_data + padded_square_data_tail_offset,
+                                     this_scale_slice, &CPUMathUtil::Instance());
     }
   }
+
   concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
   using T = float;
   functors::Powx<T> f;
@@ -115,7 +134,7 @@ Status LRN<float>::Compute(OpKernelContext* context) const {
   f.input2 = Xdata;
   f.b = -beta_;
   f.output = Ydata;
-  concurrency::ThreadPool::TryParallelFor(tp, static_cast<std::ptrdiff_t>(Xsize),
+  concurrency::ThreadPool::TryParallelFor(tp, static_cast<std::ptrdiff_t>(X_size),
                                           {static_cast<float>(sizeof(T)), static_cast<float>(sizeof(T)), f.Cost()}, f);
   return Status::OK();
 }
diff --git a/onnxruntime/core/providers/cpu/nn/lrn.h b/onnxruntime/core/providers/cpu/nn/lrn.h
index dc27672aa056d..56bae291c023a 100644
--- a/onnxruntime/core/providers/cpu/nn/lrn.h
+++ b/onnxruntime/core/providers/cpu/nn/lrn.h
@@ -3,10 +3,11 @@
 
 #pragma once
 
-#include <math.h>
+#include <cmath>
+#include <cstddef>
 
 #include "core/common/common.h"
-#include "core/common/exceptions.h"
+#include "core/common/narrow.h"
 #include "core/framework/op_kernel.h"
 
 namespace onnxruntime {
@@ -16,18 +17,15 @@ class LRN : public OpKernel {
  public:
   LRN(const OpKernelInfo& info) : OpKernel(info) {
     int64_t size;
-    ORT_ENFORCE(info.GetAttr("size", &size).IsOK());
-    size_ = gsl::narrow_cast<int>(size);
+    ORT_THROW_IF_ERROR(info.GetAttr("size", &size));
+    size_ = narrow<ptrdiff_t>(size);
     ORT_ENFORCE(size_ > 0);
     ORT_ENFORCE(size_ % 2 == 1);
 
-    ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
+    ORT_THROW_IF_ERROR(info.GetAttr("alpha", &alpha_));
     ORT_ENFORCE(alpha_ > 0.0f);
 
-    ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK());
+    ORT_THROW_IF_ERROR(info.GetAttr("beta", &beta_));
     ORT_ENFORCE(beta_ > 0.0f);
 
-    Status status = info.GetAttr("bias", &bias_);
-    if (!status.IsOK()) {
-      bias_ = 1.0f;
-    }
+    ORT_THROW_IF_ERROR(info.GetAttr("bias", &bias_));
   }
 
   Status Compute(OpKernelContext* p_op_kernel_context) const override;
@@ -36,6 +34,6 @@ class LRN : public OpKernel {
   float alpha_;
   float beta_;
   float bias_;
-  int size_;
+  ptrdiff_t size_;
 };
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/nn/lrn_op_test.cc b/onnxruntime/test/providers/cpu/nn/lrn_op_test.cc
index 87348a2ec6ed9..0d60b1088c3ad 100644
--- a/onnxruntime/test/providers/cpu/nn/lrn_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/lrn_op_test.cc
@@ -1,14 +1,47 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include <algorithm>
+#include <cmath>
+
 #include "gtest/gtest.h"
 #include "default_providers.h"
 #include "test/common/dnnl_op_test_utils.h"
 #include "test/providers/provider_test_utils.h"
+
 using namespace std;
 
 namespace onnxruntime {
 namespace test {
+// Compute reference LRN output using the ONNX formula:
+// Y[n,c,h,w] = X[n,c,h,w] * (bias + alpha/size * sum(X[n,j,h,w]^2))^(-beta)
+// where j ranges over [max(0, c - floor(size/2)), min(C-1, c + floor(size/2))].
+// Input shape must be NCHW.
+static vector<float> ComputeLRNReference(const vector<float>& X,
+                                         int64_t N, int64_t C, int64_t H, int64_t W,
+                                         float alpha, float beta, float bias, int64_t size) {
+  const int64_t total = N * C * H * W;
+  const int64_t pre_pad = (size - 1) / 2;
+  vector<float> expected(static_cast<size_t>(total));
+  for (int64_t n = 0; n < N; ++n) {
+    for (int64_t c = 0; c < C; ++c) {
+      for (int64_t h = 0; h < H; ++h) {
+        for (int64_t w = 0; w < W; ++w) {
+          float sum_sq = 0.0f;
+          for (int64_t j = std::max(int64_t{0}, c - pre_pad); j <= std::min(C - 1, c + pre_pad); ++j) {
+            float val = X[n * C * H * W + j * H * W + h * W + w];
+            sum_sq += val * val;
+          }
+          float scale = bias + (alpha / size) * sum_sq;
+          int64_t idx = n * C * H * W + c * H * W + h * W + w;
+          expected[idx] = X[idx] * std::pow(scale, -beta);
+        }
+      }
+    }
+  }
+  return expected;
+}
+
 TEST(LRNTest, LRN_1) {
   OpTester test("LRN");
   test.AddAttribute("alpha", .001f);
@@ -104,6 +137,168 @@ TEST(LRNTest, LRN_2) {
   test.Run();
 }
 
+// Test that size > C is handled correctly (window is clamped to valid channel range).
+TEST(LRNTest, SizeGreaterThanChannels) {
+  constexpr float alpha = 0.001f;
+  constexpr float beta = 0.75f;
+  constexpr float bias = 1.0f;
+  constexpr int64_t size = 5;
+
+  OpTester test("LRN");
+  test.AddAttribute("alpha", alpha);
+  test.AddAttribute("beta", beta);
+  test.AddAttribute("bias", bias);
+  test.AddAttribute("size", size);
+
+  // N=1, C=3, H=2, W=2 with size=5 > C=3
+  vector<float> X = {1.0f, 2.0f, 3.0f, 4.0f,
+                     5.0f, 6.0f, 7.0f, 8.0f,
+                     9.0f, 10.0f, 11.0f, 12.0f};
+  vector<int64_t> shape = {1, 3, 2, 2};
+
+  vector<float> expected = ComputeLRNReference(X, 1, 3, 2, 2, alpha, beta, bias, size);
+
+  test.AddInput<float>("X", shape, X);
+  test.AddOutput<float>("Y", shape, expected);
+  test.Run();
+}
+
+// Test with minimum valid size (size=3) where size == C for a 3-channel input.
+TEST(LRNTest, SizeEqualsChannels) {
+  constexpr float alpha = 0.0001f;
+  constexpr float beta = 0.75f;
+  constexpr float bias = 1.0f;
+  constexpr int64_t size = 3;
+
+  OpTester test("LRN");
+  test.AddAttribute("alpha", alpha);
+  test.AddAttribute("beta", beta);
+  test.AddAttribute("bias", bias);
+  test.AddAttribute("size", size);
+
+  // N=1, C=3, H=1, W=1
+  vector<float> X = {1.0f, 2.0f, 3.0f};
+  vector<int64_t> shape = {1, 3, 1, 1};
+
+  vector<float> expected = ComputeLRNReference(X, 1, 3, 1, 1, alpha, beta, bias, size);
+
+  test.AddInput<float>("X", shape, X);
+  test.AddOutput<float>("Y", shape, expected);
+  test.Run();
+}
+
+// Test with larger spatial dimensions to verify correctness with non-trivial H and W.
+TEST(LRNTest, LargerSpatialDims) {
+  constexpr float alpha = 0.001f;
+  constexpr float beta = 0.75f;
+  constexpr float bias = 1.0f;
+  constexpr int64_t size = 3;
+
+  OpTester test("LRN");
+  test.AddAttribute("alpha", alpha);
+  test.AddAttribute("beta", beta);
+  test.AddAttribute("bias", bias);
+  test.AddAttribute("size", size);
+
+  constexpr int64_t N = 1, C = 3, H = 128, W = 128;
+  constexpr int64_t total = N * C * H * W;
+  vector<float> X(total);
+  // Fill with a simple pattern
+  for (int64_t i = 0; i < total; ++i) {
+    X[i] = static_cast<float>(i % 7) * 0.1f + 0.1f;
+  }
+  vector<int64_t> shape = {N, C, H, W};
+
+  vector<float> expected = ComputeLRNReference(X, N, C, H, W, alpha, beta, bias, size);
+
+  test.AddInput<float>("X", shape, X);
+  test.AddOutput<float>("Y", shape, expected);
+  test.Run();
+}
+
+// Test with multiple batch items (N > 1) to cover the outer loop with n * image_size arithmetic.
+TEST(LRNTest, MultipleBatches) {
+  constexpr float alpha = 0.01f;
+  constexpr float beta = 0.5f;
+  constexpr float bias = 1.0f;
+  constexpr int64_t size = 3;
+
+  OpTester test("LRN");
+  test.AddAttribute("alpha", alpha);
+  test.AddAttribute("beta", beta);
+  test.AddAttribute("bias", bias);
+  test.AddAttribute("size", size);
+
+  constexpr int64_t N = 2, C = 3, H = 2, W = 2;
+  constexpr int64_t total = N * C * H * W;
+  vector<float> X(total);
+  for (int64_t i = 0; i < total; ++i) {
+    X[i] = static_cast<float>(i + 1) * 0.05f;
+  }
+  vector<int64_t> shape = {N, C, H, W};
+
+  vector<float> expected = ComputeLRNReference(X, N, C, H, W, alpha, beta, bias, size);
+
+  test.AddInput<float>("X", shape, X);
+  test.AddOutput<float>("Y", shape, expected);
+  test.Run();
+}
+
+// Test with more channels than size to exercise the sliding window (add head / subtract tail) path.
+TEST(LRNTest, ManyChannels) {
+  constexpr float alpha = 0.0001f;
+  constexpr float beta = 0.75f;
+  constexpr float bias = 1.0f;
+  constexpr int64_t size = 3;
+
+  OpTester test("LRN");
+  test.AddAttribute("alpha", alpha);
+  test.AddAttribute("beta", beta);
+  test.AddAttribute("bias", bias);
+  test.AddAttribute("size", size);
+
+  // C > size to exercise the c=1..C-1 loop with head/tail updates
+  constexpr int64_t N = 1, C = 8, H = 2, W = 2;
+  constexpr int64_t total = N * C * H * W;
+  vector<float> X(total);
+  for (int64_t i = 0; i < total; ++i) {
+    X[i] = static_cast<float>((i % 5) + 1) * 0.2f;
+  }
+  vector<int64_t> shape = {N, C, H, W};
+
+  vector<float> expected = ComputeLRNReference(X, N, C, H, W, alpha, beta, bias, size);
+
+  test.AddInput<float>("X", shape, X);
+  test.AddOutput<float>("Y", shape, expected);
+  test.Run();
+}
+
+// Test with all-zero input -- edge case where squared values are all zero.
+TEST(LRNTest, ZeroInput) {
+  constexpr float alpha = 0.001f;
+  constexpr float beta = 0.75f;
+  constexpr float bias = 1.0f;
+  constexpr int64_t size = 3;
+
+  OpTester test("LRN");
+  test.AddAttribute("alpha", alpha);
+  test.AddAttribute("beta", beta);
+  test.AddAttribute("bias", bias);
+  test.AddAttribute("size", size);
+
+  constexpr int64_t N = 1, C = 3, H = 2, W = 2;
+  constexpr int64_t total = N * C * H * W;
+  vector<float> X(total, 0.0f);
+  vector<int64_t> shape = {N, C, H, W};
+
+  // With all zeros: scale = bias = 1.0, Y = 0 * pow(1.0, -beta) = 0
+  vector<float> expected(total, 0.0f);
+
+  test.AddInput<float>("X", shape, X);
+  test.AddOutput<float>("Y", shape, expected);
+  test.Run();
+}
+
 #if defined(USE_DNNL)
 TEST(LRNTest, LRN_bfloat16_1) {
 #ifdef USE_DNNL