From 2d909374efb4c72c9be39f925efb9ae8f4399ba5 Mon Sep 17 00:00:00 2001 From: vraspar Date: Wed, 1 Apr 2026 13:23:37 -0700 Subject: [PATCH 1/2] Fix heap out-of-bounds write in MaxPoolGrad via unchecked indices Add bounds validation for index values in MaxPoolGrad::Compute to prevent heap out-of-bounds writes when the indices tensor contains values outside the valid range [0, output_size). Each index is now validated with ORT_ENFORCE before being used as an offset into the output buffer. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../cpu/nn/pool_gradient_op_test.cc | 48 +++++++++++++++++++ .../training_ops/cpu/nn/pool_gradient_op.cc | 8 +++- 2 files changed, 54 insertions(+), 2 deletions(-) create mode 100644 orttraining/orttraining/test/training_ops/cpu/nn/pool_gradient_op_test.cc diff --git a/orttraining/orttraining/test/training_ops/cpu/nn/pool_gradient_op_test.cc b/orttraining/orttraining/test/training_ops/cpu/nn/pool_gradient_op_test.cc new file mode 100644 index 0000000000000..bf10ea01a8152 --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cpu/nn/pool_gradient_op_test.cc @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" + +namespace onnxruntime { +namespace test { + +// Verify that MaxPoolGrad rejects indices that exceed the output buffer size. +TEST(MaxPoolGradTest, IndicesOutOfRange) { + OpTester test("MaxPoolGrad", 9, kOnnxDomain); + + // dY: shape [1, 1, 2, 2] + test.AddInput("dY", {1, 1, 2, 2}, {1.0f, 1.0f, 1.0f, 1.0f}); + // Indices: same shape, last value 100 is out of range [0, 9) + test.AddInput("Indices", {1, 1, 2, 2}, {0, 1, 2, 100}); + // Expected dX: shape [1, 1, 3, 3] → 9 elements + test.AddOutput("dX", {1, 1, 3, 3}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Invalid index in MaxPoolGrad: index value 100 is out of range [0, 9).", + {}, nullptr, &execution_providers); +} + +// Verify that MaxPoolGrad rejects negative indices. +TEST(MaxPoolGradTest, IndicesNegative) { + OpTester test("MaxPoolGrad", 9, kOnnxDomain); + + // dY: shape [1, 1, 2, 2] + test.AddInput("dY", {1, 1, 2, 2}, {1.0f, 1.0f, 1.0f, 1.0f}); + // Indices: same shape, value -1 is negative and out of range + test.AddInput("Indices", {1, 1, 2, 2}, {0, 1, -1, 3}); + // Expected dX: shape [1, 1, 3, 3] → 9 elements + test.AddOutput("dX", {1, 1, 3, 3}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Invalid index in MaxPoolGrad: index value -1 is out of range [0, 9).", + {}, nullptr, &execution_providers); +} + +} // namespace test +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/nn/pool_gradient_op.cc b/orttraining/orttraining/training_ops/cpu/nn/pool_gradient_op.cc index 769b4d1bc2bd8..51ce1f13f58c3 100644 --- a/orttraining/orttraining/training_ops/cpu/nn/pool_gradient_op.cc +++ b/orttraining/orttraining/training_ops/cpu/nn/pool_gradient_op.cc @@ -54,9 +54,13 @@ Status MaxPoolGrad::Compute(OpKernelContext* context) const { EigenVectorMap(dX_data, narrow(dX_shape.Size())).setZero(); + const int64_t dX_size = dX_shape.Size(); for (int64_t i = 0; i < dY->Shape().Size(); ++i) { - T* p_dX_data = dX_data + indices_data[i]; - *p_dX_data += dY_data[i]; + int64_t index = indices_data[i]; + ORT_ENFORCE(index >= 0 && index < dX_size, + "Invalid index in MaxPoolGrad: index value ", index, + " is out of range [0, ", dX_size, ")."); + dX_data[index] += dY_data[i]; } return Status::OK(); From a9e95be0f6ac987df25507e6d344a88e32cc8cc2 Mon Sep 17 00:00:00 2001 From: vraspar Date: Wed, 1 Apr 2026 14:03:25 -0700 Subject: [PATCH 2/2] Address Copilot review: use ORT_MAKE_STATUS instead of ORT_ENFORCE Return INVALID_ARGUMENT status instead of aborting to avoid DoS when built with ORT_NO_EXCEPTIONS, where ORT_ENFORCE maps to abort(). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../orttraining/training_ops/cpu/nn/pool_gradient_op.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/training_ops/cpu/nn/pool_gradient_op.cc b/orttraining/orttraining/training_ops/cpu/nn/pool_gradient_op.cc index 51ce1f13f58c3..1c384c2397ab1 100644 --- a/orttraining/orttraining/training_ops/cpu/nn/pool_gradient_op.cc +++ b/orttraining/orttraining/training_ops/cpu/nn/pool_gradient_op.cc @@ -57,9 +57,11 @@ Status MaxPoolGrad::Compute(OpKernelContext* context) const { const int64_t dX_size = dX_shape.Size(); for (int64_t i = 0; i < dY->Shape().Size(); ++i) { int64_t index = indices_data[i]; - ORT_ENFORCE(index >= 0 && index < dX_size, - "Invalid index in MaxPoolGrad: index value ", index, - " is out of range [0, ", dX_size, ")."); + if (index < 0 || index >= dX_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid index in MaxPoolGrad: index value ", index, + " is out of range [0, ", dX_size, ")."); + } dX_data[index] += dY_data[i]; }