From 2fb111ebcc4f7706231ccd3a342e3e906031bded Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 2 Apr 2026 12:26:57 +0000 Subject: [PATCH 1/4] fix target_ids out of boundary in TreeEnsemble* --- .../cpu/ml/tree_ensemble_attribute.h | 6 ++ .../providers/cpu/ml/treeregressor_test.cc | 86 ++++++++++++++++++- 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h index b072cf32fd342..62afa50d7af95 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h @@ -112,6 +112,12 @@ struct TreeEnsembleAttributesV3 { "target_class_weights and target_class_weights_as_tensor cannot both be non-empty"); ORT_ENFORCE(nodes_modes.size() < std::numeric_limits::max(), "nodes_modes size (", nodes_modes.size(), ") exceeds uint32_t max"); + + int64_t min_ids = *std::min_element(target_class_ids.begin(), target_class_ids.end()); + ORT_ENFORCE(min_ids >= 0, "target_ids or class_ids cannot have negative values (", min_ids, ")."); + int64_t max_ids = *std::max_element(target_class_ids.begin(), target_class_ids.end()); + ORT_ENFORCE(max_ids < n_targets_or_classes, "At least one value (", max_ids, + ") in target_ids or class_ids is greater than the number of targets or classes (", n_targets_or_classes, ")."); } std::string aggregate_function; diff --git a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc index e6c9f8435ffdd..cea2949ba5f58 100644 --- a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc +++ b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc @@ -858,7 +858,7 @@ TEST(MLOpTest, TreeRegressorBranchEq) { std::vector nodes_truenodeids = {1, -1, 4, -1, -1}; std::string post_transform = "NONE"; - std::vector target_ids = {0, 1, 2}; + std::vector target_ids = {0, 0, 0}; std::vector target_nodeids = {1, 3, 4}; std::vector target_treeids = {0, 0, 0}; std::vector target_weights = {10.0, 20.0, 30.0}; @@ -885,5 +885,89 @@ TEST(MLOpTest, TreeRegressorBranchEq) { test.Run(); } +TEST(MLOpTest, TreeRegressorNegativeTargetIds) { + OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain); + + // tree + int64_t n_targets = 1; + std::vector nodes_featureids = {0, 0, 0, 0, 1, 0, 0}; + std::vector nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF"}; + std::vector nodes_values = {1, 3, 4, 0, 5.5, 0, 0}; + + std::vector nodes_treeids = {0, 0, 0, 0, 0, 0, 0}; + std::vector nodes_nodeids = {0, 1, 2, 3, 4, 5, 6}; + std::vector nodes_falsenodeids = {1, 2, 3, 0, 5, 0, 0}; + std::vector nodes_truenodeids = {4, 4, 4, 0, 6, 0, 0}; + + std::string post_transform = "NONE"; + std::vector target_ids = {0, 0, -1}; + std::vector target_nodeids = {3, 5, 6}; + std::vector target_treeids = {0, 0, 0}; + std::vector target_weights = {-4.699999809265137, 17.700000762939453, 11.100000381469727}; + + // add attributes + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_treeids", nodes_treeids); + test.AddAttribute("nodes_nodeids", nodes_nodeids); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_values", nodes_values); + test.AddAttribute("nodes_modes", nodes_modes); + test.AddAttribute("target_treeids", target_treeids); + test.AddAttribute("target_nodeids", target_nodeids); + test.AddAttribute("target_ids", target_ids); + test.AddAttribute("target_weights", target_weights); + test.AddAttribute("n_targets", n_targets); + + // fill input data + std::vector X = {3.0f, 6.6f, 1.0f, 5.0f, 5.0f, 5.5f}; + std::vector Y = {17.700000762939453, 11.100000381469727, -4.699999809265137}; + test.AddInput("X", {3, 2}, X); + test.AddOutput("Y", {3, 1}, Y); + test.Run(OpTester::ExpectResult::kExpectFailure, "target_ids or class_ids cannot have negative values"); +} + +TEST(MLOpTest, TreeRegressorOutsideBoundaryTargetIds) { + OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain); + + // tree + int64_t n_targets = 1; + std::vector nodes_featureids = {0, 0, 0, 0, 1, 0, 0}; + std::vector nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF"}; + std::vector nodes_values = {1, 3, 4, 0, 5.5, 0, 0}; + + std::vector nodes_treeids = {0, 0, 0, 0, 0, 0, 0}; + std::vector nodes_nodeids = {0, 1, 2, 3, 4, 5, 6}; + std::vector nodes_falsenodeids = {1, 2, 3, 0, 5, 0, 0}; + std::vector nodes_truenodeids = {4, 4, 4, 0, 6, 0, 0}; + + std::string post_transform = "NONE"; + std::vector target_ids = {0, 0, 1}; + std::vector target_nodeids = {3, 5, 6}; + std::vector target_treeids = {0, 0, 0}; + std::vector target_weights = {-4.699999809265137, 17.700000762939453, 11.100000381469727}; + + // add attributes + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_treeids", nodes_treeids); + test.AddAttribute("nodes_nodeids", nodes_nodeids); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_values", nodes_values); + test.AddAttribute("nodes_modes", nodes_modes); + test.AddAttribute("target_treeids", target_treeids); + test.AddAttribute("target_nodeids", target_nodeids); + test.AddAttribute("target_ids", target_ids); + test.AddAttribute("target_weights", target_weights); + test.AddAttribute("n_targets", n_targets); + + // fill input data + std::vector X = {3.0f, 6.6f, 1.0f, 5.0f, 5.0f, 5.5f}; + std::vector Y = {17.700000762939453, 11.100000381469727, -4.699999809265137}; + test.AddInput("X", {3, 2}, X); + test.AddOutput("Y", {3, 1}, Y); + test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (1) in target_ids or class_ids is greater than the number of targets or classes"); +} + } // namespace test } // namespace onnxruntime From 07be6a1e21e82a979ba7dfc8d7c9c1cf1b2034a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 2 Apr 2026 14:41:35 +0000 Subject: [PATCH 2/4] lint --- onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h index 62afa50d7af95..3bff42bd4ed06 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h @@ -112,7 +112,7 @@ struct TreeEnsembleAttributesV3 { "target_class_weights and target_class_weights_as_tensor cannot both be non-empty"); ORT_ENFORCE(nodes_modes.size() < std::numeric_limits::max(), "nodes_modes size (", nodes_modes.size(), ") exceeds uint32_t max"); - + int64_t min_ids = *std::min_element(target_class_ids.begin(), target_class_ids.end()); ORT_ENFORCE(min_ids >= 0, "target_ids or class_ids cannot have negative values (", min_ids, ")."); int64_t max_ids = *std::max_element(target_class_ids.begin(), target_class_ids.end()); From 639d8b77ca6089000f6e62d494e99863d2762d74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 2 Apr 2026 15:56:06 +0000 Subject: [PATCH 3/4] fix error message --- onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h index 3bff42bd4ed06..c3b1c52720fca 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h @@ -117,7 +117,8 @@ struct TreeEnsembleAttributesV3 { ORT_ENFORCE(min_ids >= 0, "target_ids or class_ids cannot have negative values (", min_ids, ")."); int64_t max_ids = *std::max_element(target_class_ids.begin(), target_class_ids.end()); ORT_ENFORCE(max_ids < n_targets_or_classes, "At least one value (", max_ids, - ") in target_ids or class_ids is greater than the number of targets or classes (", n_targets_or_classes, ")."); + ") in target_ids or class_ids is greater or equal to the number of targets or classes (", + n_targets_or_classes, ")."); } std::string aggregate_function; From fa4ede7af4924775e7400214a281e4c61f7bb990 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 8 Apr 2026 09:51:19 +0000 Subject: [PATCH 4/4] fix error message --- onnxruntime/test/providers/cpu/ml/treeregressor_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc index cea2949ba5f58..8e84f5fd4dc9e 100644 --- a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc +++ b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc @@ -966,7 +966,7 @@ TEST(MLOpTest, TreeRegressorOutsideBoundaryTargetIds) { std::vector Y = {17.700000762939453, 11.100000381469727, -4.699999809265137}; test.AddInput("X", {3, 2}, X); test.AddOutput("Y", {3, 1}, Y); - test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (1) in target_ids or class_ids is greater than the number of targets or classes"); + test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (1) in target_ids or class_ids is greater or equal to the number of targets or classes (1)"); } } // namespace test