From b6773d9cf02bb6d5e41e382fa09873a344dc6071 Mon Sep 17 00:00:00 2001 From: vraspar Date: Mon, 6 Apr 2026 10:23:48 -0700 Subject: [PATCH] Add bounds validation for LinearClassifier coefficients to prevent OOB read in GEMM --- .../core/providers/cpu/ml/linearclassifier.cc | 14 +++++++ .../providers/cpu/ml/linearclassifer_test.cc | 42 +++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/onnxruntime/core/providers/cpu/ml/linearclassifier.cc b/onnxruntime/core/providers/cpu/ml/linearclassifier.cc index d256d60b766a8..e181ffb0c731b 100644 --- a/onnxruntime/core/providers/cpu/ml/linearclassifier.cc +++ b/onnxruntime/core/providers/cpu/ml/linearclassifier.cc @@ -36,6 +36,12 @@ LinearClassifier::LinearClassifier(const OpKernelInfo& info) using_strings_ = !classlabels_strings_.empty(); class_count_ = static_cast(intercepts_.size()); + + ORT_ENFORCE(class_count_ > 0, "LinearClassifier: intercepts must not be empty."); + ORT_ENFORCE(coefficients_.size() % static_cast(class_count_) == 0, + "LinearClassifier: coefficients size (", coefficients_.size(), + ") must be a multiple of the number of classes (", class_count_, ")."); + SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config_, info.GetConfigOptions()); } @@ -146,6 +152,14 @@ Status LinearClassifier::Compute(OpKernelContext* ctx) const { input_shape[0]) : narrow(input_shape[1]); + // Validate coefficients are large enough to prevent OOB read in GEMM. + const size_t expected_coefficients_size = SafeInt(class_count_) * SafeInt(num_features); + if (coefficients_.size() < expected_coefficients_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "LinearClassifier: coefficients length (", coefficients_.size(), + ") is less than classes (", class_count_, ") * features (", num_features, ")"); + } + Tensor* Y = ctx->Output(0, {num_batches}); int64_t output_classes = class_count_; diff --git a/onnxruntime/test/providers/cpu/ml/linearclassifer_test.cc b/onnxruntime/test/providers/cpu/ml/linearclassifer_test.cc index 3decb3f93888a..7fe75c68ccb2c 100644 --- a/onnxruntime/test/providers/cpu/ml/linearclassifer_test.cc +++ b/onnxruntime/test/providers/cpu/ml/linearclassifer_test.cc @@ -166,5 +166,47 @@ TEST(MLOpTest, LinearClassifierMulticlassInt32Input) { TEST(MLOpTest, LinearClassifierMulticlassDoubleInput) { LinearClassifierMulticlass(); } + +// Regression test: coefficients size doesn't match class_count * num_features. +TEST(MLOpTest, LinearClassifierInvalidCoefficientsSizeFails) { + OpTester test("LinearClassifier", 1, onnxruntime::kMLDomain); + + // 3 intercepts => class_count = 3, input has 2 features => expects 6 coefficients. + std::vector coefficients = {-0.22562418f, 0.34188559f, 0.68346153f}; + std::vector classes = {1, 2, 3}; + std::vector intercepts = {-3.91601811f, 0.42575697f, 0.13731251f}; + + test.AddAttribute("coefficients", coefficients); + test.AddAttribute("intercepts", intercepts); + test.AddAttribute("classlabels_ints", classes); + + test.AddInput("X", {1, 2}, {1.f, 0.f}); + test.AddOutput("Y", {1}, {0LL}); + test.AddOutput("Z", {1, 3}, {0.f, 0.f, 0.f}); + + test.Run(OpTester::ExpectResult::kExpectFailure, + "LinearClassifier: coefficients length (3) is less than classes (3) * features (2)"); +} + +// Regression test: coefficients not divisible by class_count. +TEST(MLOpTest, LinearClassifierCoefficientsSizeNotDivisibleByClassCountFails) { + OpTester test("LinearClassifier", 1, onnxruntime::kMLDomain); + + // 3 intercepts => class_count = 3, but 5 coefficients is not divisible by 3. + std::vector coefficients = {1.f, 2.f, 3.f, 4.f, 5.f}; + std::vector classes = {1, 2, 3}; + std::vector intercepts = {0.1f, 0.2f, 0.3f}; + + test.AddAttribute("coefficients", coefficients); + test.AddAttribute("intercepts", intercepts); + test.AddAttribute("classlabels_ints", classes); + + test.AddInput("X", {1, 2}, {1.f, 0.f}); + test.AddOutput("Y", {1}, {0LL}); + test.AddOutput("Z", {1, 3}, {0.f, 0.f, 0.f}); + + test.Run(OpTester::ExpectResult::kExpectFailure, + "coefficients size (5) must be a multiple of the number of classes (3)"); +} } // namespace test } // namespace onnxruntime