Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Squ
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Tile);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Transpose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Unsqueeze);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Upsample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Upsample);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Upsample);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, int32_t, Upsample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Expand);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Scan);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale);
Expand Down Expand Up @@ -366,8 +366,8 @@ void RegisterOnnxOperatorKernels(std::function<void(KernelCreateInfo&&)> fn) {
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, Tile)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Transpose)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Unsqueeze)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, float, Upsample)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, int32_t, Upsample)>());
fn(BuildKernel<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Upsample)>());
fn(BuildKernel<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, int32_t, Upsample)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Expand)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, Scan)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale)>());
Expand Down
34 changes: 24 additions & 10 deletions onnxruntime/core/providers/cpu/tensor/upsample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ using namespace ::onnxruntime::common;
using namespace std;
namespace onnxruntime {

// Upsample is versioned 7-9: opset 7 takes scales as an attribute, opset 9
// moves them to a second input. Register the CPU kernels for float and int32.
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
    Upsample,
    7, 9,
    float,
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    Upsample<float>);

ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
    Upsample,
    7, 9,
    int32_t,
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<int32_t>()),
    Upsample<int32_t>);
Expand Down Expand Up @@ -191,24 +191,24 @@ void upsampleBilinear(
}

template <typename T>
Status Upsample<T>::Compute(OpKernelContext* context) const {
Status Upsample<T>::BaseCompute(OpKernelContext* context, const std::vector<float>& scales) const {
const Tensor* X = context->Input<Tensor>(0);
ONNXRUNTIME_ENFORCE(X != nullptr);

const std::vector<int64_t>& dims = X->Shape().GetDims();
if (dims.size() != scales_.size()) {
if (dims.size() != scales.size()) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Upsample: input tensor's dimension does not match the scales.");
}

std::vector<int64_t> Y_dims;
for (std::size_t i = 0; i < dims.size(); i++) {
Y_dims.push_back(static_cast<int64_t>(scales_[i] * dims[i]));
Y_dims.push_back(static_cast<int64_t>(scales[i] * dims[i]));
}
Tensor* Y = context->Output(0, Y_dims);

switch (mode_) {
case UpsampleMode::NN:
return upsampleNearest<T>(X->template Data<T>(), Y->template MutableData<T>(), X->Shape(), Y->Shape(), scales_);
return upsampleNearest<T>(X->template Data<T>(), Y->template MutableData<T>(), X->Shape(), Y->Shape(), scales);
case UpsampleMode::LINEAR: {
//What's the correct behavior of linear mode is not clear right now,
//Only support bilinear with 4D tensor to keep consistent with previous behavior
Expand All @@ -219,13 +219,27 @@ Status Upsample<T>::Compute(OpKernelContext* context) const {
const int64_t input_height = dims[2], input_width = dims[3];

upsampleBilinear(batch_size, num_channels, input_height, input_width,
scales_[2], scales_[3], X->template Data<T>(), Y->template MutableData<T>());
scales[2], scales[3], X->template Data<T>(), Y->template MutableData<T>());
return Status::OK();
//return upsampleLiner<T>(X->template Data<T>(), Y->template MutableData<T>(), X->Shape(), Y->Shape(), scales_);
}
default:
return Status(ONNXRUNTIME, FAIL, "Upsample: unexpected mode");
}
}


// Entry point: resolves the scale factors, then delegates to BaseCompute.
// Opset 7 (single input) uses the cached attribute scales; opset 9 reads the
// optional second input unless it was a constant initializer cached at
// construction time (scales_cached_).
template <typename T>
Status Upsample<T>::Compute(OpKernelContext* context) const {
  if (OpKernel::Node().InputDefs().size() == 1 || scales_cached_) {
    return BaseCompute(context, scales_);
  }

  // Dynamic scales: parse them from the runtime input tensor on every call.
  const Tensor* scales = context->Input<Tensor>(1);
  ONNXRUNTIME_ENFORCE(scales != nullptr);
  int64_t scales_size = scales->Shape().Size();
  // Pre-size the destination; ParseScalesData copies scales_size floats into it.
  std::vector<float> scales_array(scales_size);
  ParseScalesData(scales, scales_array);
  return BaseCompute(context, scales_array);
}

} // namespace onnxruntime
59 changes: 43 additions & 16 deletions onnxruntime/core/providers/cpu/tensor/upsample.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,25 @@ namespace onnxruntime {
constexpr const char* UpsampleModeNN = "nearest";
constexpr const char* UpsampleModeLinear = "linear";


// Interpolation mode selected by the "mode" node attribute.
enum UpsampleMode {
  NN = 0,      // nearest neighbour
  LINEAR = 1,  // linear interpolation
};

class UpsampleBase {
protected:
// Reads the mandatory "mode" attribute. For opset 7 nodes (single input) the
// scale factors come from the "scales" attribute and are validated here; for
// opset 9 nodes (two inputs) the scales arrive as a tensor input and are
// parsed/validated later by the derived kernel.
UpsampleBase(OpKernelInfo info) {
  std::string mode;
  ONNXRUNTIME_ENFORCE(info.GetAttr<std::string>("mode", &mode).IsOK());

  mode_ = StringToUpsampleMode(mode);

  if (info.GetInputCount() == 1) {
    ONNXRUNTIME_ENFORCE(info.GetAttrs<float>("scales", scales_).IsOK());
    ScalesValidation(scales_, mode_);
  }
}

UpsampleMode mode_;

std::vector<float> scales_;

UpsampleMode StringToUpsampleMode(const std::string& mode) {
Expand All @@ -49,15 +41,50 @@ class UpsampleBase {
UpsampleModeNN + "(default) or " + UpsampleModeLinear + ".");
}
}

// Validates a parsed set of scale factors against the selected mode.
// Every factor must be >= 1 (this operator only upsamples); linear mode is
// restricted to bilinear, i.e. a 4-D NCHW tensor with unit scales on N and C.
void ScalesValidation(const std::vector<float>& scales, const UpsampleMode mode) const {
  for (std::size_t i = 0; i < scales.size(); ++i) {
    ONNXRUNTIME_ENFORCE(scales[i] >= 1, "Scale value should be greater than or equal to 1.");
  }

  if (mode == UpsampleMode::LINEAR) {
    ONNXRUNTIME_ENFORCE(scales.size() == 4, "Upsample: linear mode upsample only support bilinear with 4 dimension.");
    ONNXRUNTIME_ENFORCE(scales[0] == 1 && scales[1] == 1,
                        "Upsample: linear mode upsample only support bilinear, the first 2 scales should be 1.");
  }
}
};

template <typename T>
class Upsample : public UpsampleBase, public OpKernel {
public:
Upsample(OpKernelInfo info) : UpsampleBase(info), OpKernel(info) {
Upsample(OpKernelInfo info) : UpsampleBase(info), OpKernel(info), scales_cached_(false) {
if (info.GetInputCount() > 1) {
const Tensor* scale;
bool get_scale = info.TryGetConstantInput(1, &scale);

if (get_scale) {
ParseScalesData(scale, scales_);
scales_cached_ = true;
}
}
}

Status Compute(OpKernelContext* context) const override;

Status BaseCompute(OpKernelContext* context, const std::vector<float>& scales) const;

private:
void ParseScalesData(const Tensor* scale, std::vector<float>& scales) const {
const float* scale_data = scale->template Data<float>();
int64_t scales_size = scale->Shape().Size();
ONNXRUNTIME_ENFORCE(scales_size > 0, "scales size should be greater than 0.");
memcpy(scales.data(), scale_data, scales_size * sizeof(float));
ScalesValidation(scales, mode_);
}

private:
bool scales_cached_;
};

} // namespace onnxruntime
2 changes: 0 additions & 2 deletions onnxruntime/core/providers/cuda/tensor/upsample.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
namespace onnxruntime {
namespace cuda {

struct TVMState;

template <typename T>
class Upsample : public UpsampleBase, public CudaKernel {
public:
Expand Down
31 changes: 31 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,5 +319,36 @@ TEST(UpsampleOpTest, UpsampleOpNearestTest_1D) {
test.AddOutput<float>("Y", {10}, Y);
test.Run();
}

// Opset-9 nearest-neighbour 2x upsample: scales are supplied as the second
// input tensor instead of an attribute. 1x2x2x2 int32 input, scales
// {1,1,2,2} -> 1x2x4x4 output where each source pixel is replicated 2x2.
TEST(UpsampleOpTest, UpsampleOpNearest2XTest_opset9) {
  OpTester test("Upsample", 9);

  test.AddAttribute("mode", "nearest");

  constexpr int64_t N = 1, C = 2, H = 2, W = 2;
  const std::vector<int32_t> input = {1, 3,
                                      3, 5,

                                      3, 5,
                                      7, 9};
  const std::vector<float> scales{1.0f, 1.0f, 2.0f, 2.0f};

  test.AddInput<int32_t>("X", {N, C, H, W}, input);
  test.AddInput<float>("scales", {4}, scales);

  const std::vector<int32_t> expected = {
      1, 1, 3, 3,
      1, 1, 3, 3,
      3, 3, 5, 5,
      3, 3, 5, 5,

      3, 3, 5, 5,
      3, 3, 5, 5,
      7, 7, 9, 9,
      7, 7, 9, 9};

  test.AddOutput<int32_t>("Y", {N, C, (int64_t)(H * scales[2]), (int64_t)(W * scales[3])}, expected);
  test.Run();
}
} // namespace test
} // namespace onnxruntime