diff --git a/onnxruntime/core/providers/cpu/nn/Unpool.cc b/onnxruntime/core/providers/cpu/nn/Unpool.cc index 5997dcedebfd7..b9016d979b4bb 100644 --- a/onnxruntime/core/providers/cpu/nn/Unpool.cc +++ b/onnxruntime/core/providers/cpu/nn/Unpool.cc @@ -96,15 +96,22 @@ Status MaxUnpool::Compute(OpKernelContext* context) const { } // unpool - int64_t total_elements = X_shape.Size(); + size_t total_elements = narrow(X_shape.Size()); + size_t output_size = narrow(shape.Size()); Tensor* Y = context->Output(0, shape); auto* Y_data = Y->MutableData(); - auto out = gsl::make_span(Y_data, narrow(Y->Shape().Size())); + auto out = gsl::make_span(Y_data, output_size); std::fill_n(out.data(), out.size(), 0.f); - for (auto cur_elem = 0; cur_elem < total_elements; ++cur_elem) { - out[narrow(I_data[narrow(cur_elem)])] = X_data[narrow(cur_elem)]; + for (size_t cur_elem = 0; cur_elem < total_elements; ++cur_elem) { + const int64_t idx = I_data[cur_elem]; + if (idx < 0 || idx >= static_cast(output_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Index value out of bounds. Got: ", idx, ". Valid range is [0, ", output_size, ")."); + } + + out[static_cast(idx)] = X_data[cur_elem]; } return Status::OK(); diff --git a/onnxruntime/test/providers/cpu/nn/unpool_op_test.cc b/onnxruntime/test/providers/cpu/nn/unpool_op_test.cc index a9985bdfac831..52e184ddedb40 100644 --- a/onnxruntime/test/providers/cpu/nn/unpool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/unpool_op_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "gtest/gtest.h" +#include "core/graph/constants.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" @@ -436,5 +437,32 @@ TEST(UnpoolTest, MaxUnPool_DefaultStrides) { test.Run(); } +TEST(UnpoolTest, MaxUnpoolInvalidIndices) { + OpTester test("MaxUnpool", 9); + + test.AddAttribute("strides", std::vector{2}); + test.AddAttribute("kernel_shape", vector{2}); + + std::vector t_vals = {1, 2, 3, 4}; + std::vector t_dims = {1, 1, 4}; + + std::vector i_vals = {1, 3, 4, 8}; // 8 is out of bounds + std::vector i_dims = {1, 1, 4}; + + std::vector expected_dims = {1, 1, 8}; + std::vector expected_vals = {0, 1, 0, 2, 3, 0, 4, 0}; + + std::vector inputDims = {3}; + + test.AddInput("xT", t_dims, t_vals); + test.AddInput("xI", i_dims, i_vals); + test.AddInput("output_shape", inputDims, expected_dims); + + test.AddOutput("Y", expected_dims, expected_vals); + std::vector> cpu_execution_provider; + cpu_execution_provider.push_back(DefaultCpuExecutionProvider()); + test.Run(BaseTester::ExpectResult::kExpectFailure, "Index value out of bounds", {}, nullptr, + &cpu_execution_provider); +} } // namespace test } // namespace onnxruntime