Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 54 additions & 35 deletions onnxruntime/core/providers/cpu/nn/lrn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,11 @@
#include "core/providers/cpu/nn/lrn.h"
#include "core/providers/cpu/element_wise_ranged_transform.h"

#include "core/common/narrow.h"
#include "core/common/safeint.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
// TODO: fix the warnings
#if defined(_MSC_VER) && !defined(__clang__)
// Chance of arithmetic overflow could be reduced
#pragma warning(disable : 26451)
#endif

namespace onnxruntime {

namespace functors {
Expand All @@ -49,73 +46,95 @@ struct Powx {
}
};
} // namespace functors

template <>
Status LRN<float>::Compute(OpKernelContext* context) const {
  // Local Response Normalization over the channel dimension of an NCHW tensor:
  //   Y[n,c,h,w] = X[n,c,h,w] * (bias + alpha/size * sum_{c' in window(c)} X[n,c',h,w]^2)^(-beta)
  // computed with a sliding-window update (add head, subtract tail) per channel.
  const auto& X = context->RequiredInput<Tensor>(0);
  const auto& X_shape = X.Shape();

  Tensor* Y = context->Output(0, X_shape);

  // Supports NCHW image format only.
  ORT_ENFORCE(X_shape.NumDimensions() == 4);
  const ptrdiff_t N = narrow<ptrdiff_t>(X_shape[0]);
  const ptrdiff_t C = narrow<ptrdiff_t>(X_shape[1]);
  const ptrdiff_t H = narrow<ptrdiff_t>(X_shape[2]);
  const ptrdiff_t W = narrow<ptrdiff_t>(X_shape[3]);

  const ptrdiff_t X_size = narrow<ptrdiff_t>(X_shape.Size());

  if (X_size == 0) {
    // Nothing to compute for an empty tensor.
    return Status::OK();
  }

  // Note: `ptrdiff_t X_size` being set successfully implies that N*C*H*W will not overflow ptrdiff_t.
  const ptrdiff_t image_size = C * H * W;
  const ptrdiff_t pre_pad = (size_ - 1) / 2;  // size_ is enforced odd in the constructor.
  const int H_times_W = SafeInt<int>(H) * W;  // H_times_W is passed to math::Axpy() which takes an int.

  const auto* Xdata = X.Data<float>();
  auto* Ydata = Y->MutableData<float>();

  AllocatorPtr alloc;
  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));

  // scale_data holds (bias + alpha/size * window-sum-of-squares) for every element.
  void* sdata = alloc->Alloc(SafeInt<size_t>(sizeof(float)) * X_size);
  BufferUniquePtr scale_buffer(sdata, BufferDeleter(alloc));
  auto* scale_data = static_cast<float*>(scale_buffer.get());
  math::Set<float, CPUMathUtil>(X_size, bias_, scale_data, &CPUMathUtil::Instance());

  // One image of squared inputs, zero-padded by (size_ - 1) channels so the
  // sliding window never reads out of bounds.
  const ptrdiff_t padded_square_size = (SafeInt<ptrdiff_t>(C) + size_ - 1) * H * W;
  auto psdata = alloc->Alloc(SafeInt<size_t>(sizeof(float)) * padded_square_size);
  BufferUniquePtr padded_square_buffer(psdata, BufferDeleter(std::move(alloc)));
  auto* padded_square_data = static_cast<float*>(padded_square_buffer.get());
  math::Set<float, CPUMathUtil>(padded_square_size, 0.0f, padded_square_data, &CPUMathUtil::Instance());

  const float alpha_over_size = alpha_ / size_;
  // go through the images
  for (ptrdiff_t n = 0; n < N; ++n) {
    const ptrdiff_t n_times_image_size = n * image_size;

    // compute the padded square
    {
      const ptrdiff_t padded_square_data_offset = SafeInt<ptrdiff_t>(pre_pad) * H_times_W;
      math::Sqr<float, CPUMathUtil>(image_size, Xdata + n_times_image_size, padded_square_data + padded_square_data_offset,
                                    &CPUMathUtil::Instance());
    }
    // Create the first channel scale
    for (ptrdiff_t c = 0; c < size_; ++c) {
      const ptrdiff_t padded_square_data_offset = c * H_times_W;
      math::Axpy<float, CPUMathUtil>(H_times_W, alpha_over_size, padded_square_data + padded_square_data_offset,
                                     scale_data + n_times_image_size, &CPUMathUtil::Instance());
    }

    for (ptrdiff_t c = 1; c < C; ++c) {
      const ptrdiff_t this_scale_offset = n * image_size + c * H_times_W;
      float* this_scale_slice = scale_data + this_scale_offset;
      // copy previous scale
      memcpy(this_scale_slice, this_scale_slice - H_times_W, SafeInt<size_t>(H_times_W) * sizeof(float));
      // add head
      const ptrdiff_t padded_square_data_head_offset = (SafeInt<ptrdiff_t>(c) + size_ - 1) * H_times_W;
      math::Axpy<float, CPUMathUtil>(H_times_W, alpha_over_size, padded_square_data + padded_square_data_head_offset,
                                     this_scale_slice, &CPUMathUtil::Instance());
      // subtract tail
      const ptrdiff_t padded_square_data_tail_offset = (c - 1) * H_times_W;
      math::Axpy<float, CPUMathUtil>(H_times_W, -alpha_over_size, padded_square_data + padded_square_data_tail_offset,
                                     this_scale_slice, &CPUMathUtil::Instance());
    }
  }

  // Final elementwise step Y = X * scale^(-beta), parallelized over elements.
  concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
  using T = float;
  functors::Powx<T> f;
  f.input1 = scale_data;
  f.input2 = Xdata;
  f.b = -beta_;
  f.output = Ydata;
  concurrency::ThreadPool::TryParallelFor(tp, static_cast<std::ptrdiff_t>(X_size),
                                          {static_cast<float>(sizeof(T)), static_cast<float>(sizeof(T)), f.Cost()}, f);
  return Status::OK();
}
Expand Down
20 changes: 9 additions & 11 deletions onnxruntime/core/providers/cpu/nn/lrn.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

#pragma once

#include <gsl/gsl>
#include <cstddef>
#include <cstdint>

#include "core/common/common.h"
#include "core/common/exceptions.h"
#include "core/common/narrow.h"
#include "core/framework/op_kernel.h"
Comment thread
edgchen1 marked this conversation as resolved.

namespace onnxruntime {
Expand All @@ -16,18 +17,15 @@ class LRN : public OpKernel {
public:
LRN(const OpKernelInfo& info) : OpKernel(info) {
int64_t size;
ORT_ENFORCE(info.GetAttr<int64_t>("size", &size).IsOK());
size_ = gsl::narrow_cast<int>(size);
ORT_THROW_IF_ERROR(info.GetAttr<int64_t>("size", &size));
size_ = narrow<ptrdiff_t>(size);
ORT_ENFORCE(size_ > 0);
ORT_ENFORCE(size_ % 2 == 1);
ORT_ENFORCE(info.GetAttr<float>("alpha", &alpha_).IsOK());
ORT_THROW_IF_ERROR(info.GetAttr<float>("alpha", &alpha_));
ORT_ENFORCE(alpha_ > 0.0f);
ORT_ENFORCE(info.GetAttr<float>("beta", &beta_).IsOK());
ORT_THROW_IF_ERROR(info.GetAttr<float>("beta", &beta_));
ORT_ENFORCE(beta_ > 0.0f);
Status status = info.GetAttr<float>("bias", &bias_);
if (!status.IsOK()) {
bias_ = 1.0f;
}
ORT_THROW_IF_ERROR(info.GetAttr<float>("bias", &bias_));
Comment thread
edgchen1 marked this conversation as resolved.
Comment thread
edgchen1 marked this conversation as resolved.
}

Status Compute(OpKernelContext* p_op_kernel_context) const override;
Expand All @@ -36,6 +34,6 @@ class LRN : public OpKernel {
float alpha_;
float beta_;
float bias_;
int size_;
ptrdiff_t size_;
};
} // namespace onnxruntime
Loading
Loading