diff --git a/orttraining/orttraining/test/training_ops/cpu/nn/pool_gradient_op_test.cc b/orttraining/orttraining/test/training_ops/cpu/nn/pool_gradient_op_test.cc new file mode 100644 index 0000000000000..bf10ea01a8152 --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cpu/nn/pool_gradient_op_test.cc @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" + +namespace onnxruntime { +namespace test { + +// Verify that MaxPoolGrad rejects indices that exceed the output buffer size. +TEST(MaxPoolGradTest, IndicesOutOfRange) { + OpTester test("MaxPoolGrad", 9, kOnnxDomain); + + // dY: shape [1, 1, 2, 2] + test.AddInput("dY", {1, 1, 2, 2}, {1.0f, 1.0f, 1.0f, 1.0f}); + // Indices: same shape, last value 100 is out of range [0, 9) + test.AddInput("Indices", {1, 1, 2, 2}, {0, 1, 2, 100}); + // Expected dX: shape [1, 1, 3, 3] → 9 elements + test.AddOutput("dX", {1, 1, 3, 3}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Invalid index in MaxPoolGrad: index value 100 is out of range [0, 9).", + {}, nullptr, &execution_providers); +} + +// Verify that MaxPoolGrad rejects negative indices. +TEST(MaxPoolGradTest, IndicesNegative) { + OpTester test("MaxPoolGrad", 9, kOnnxDomain); + + // dY: shape [1, 1, 2, 2] + test.AddInput("dY", {1, 1, 2, 2}, {1.0f, 1.0f, 1.0f, 1.0f}); + // Indices: same shape, value -1 is negative and out of range + test.AddInput("Indices", {1, 1, 2, 2}, {0, 1, -1, 3}); + // Expected dX: shape [1, 1, 3, 3] → 9 elements + test.AddOutput("dX", {1, 1, 3, 3}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Invalid index in MaxPoolGrad: index value -1 is out of range [0, 9).", + {}, nullptr, &execution_providers); +} + +} // namespace test +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/nn/pool_gradient_op.cc b/orttraining/orttraining/training_ops/cpu/nn/pool_gradient_op.cc index 769b4d1bc2bd8..1c384c2397ab1 100644 --- a/orttraining/orttraining/training_ops/cpu/nn/pool_gradient_op.cc +++ b/orttraining/orttraining/training_ops/cpu/nn/pool_gradient_op.cc @@ -54,9 +54,15 @@ Status MaxPoolGrad::Compute(OpKernelContext* context) const { EigenVectorMap(dX_data, narrow(dX_shape.Size())).setZero(); + const int64_t dX_size = dX_shape.Size(); for (int64_t i = 0; i < dY->Shape().Size(); ++i) { - T* p_dX_data = dX_data + indices_data[i]; - *p_dX_data += dY_data[i]; + int64_t index = indices_data[i]; + if (index < 0 || index >= dX_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid index in MaxPoolGrad: index value ", index, + " is out of range [0, ", dX_size, ")."); + } + dX_data[index] += dY_data[i]; } return Status::OK();