Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ struct TreeEnsembleAttributesV3 {
"target_class_weights and target_class_weights_as_tensor cannot both be non-empty");
ORT_ENFORCE(nodes_modes.size() < std::numeric_limits<uint32_t>::max(),
"nodes_modes size (", nodes_modes.size(), ") exceeds uint32_t max");

int64_t min_ids = *std::min_element(target_class_ids.begin(), target_class_ids.end());
ORT_ENFORCE(min_ids >= 0, "target_ids or class_ids cannot have negative values (", min_ids, ").");
int64_t max_ids = *std::max_element(target_class_ids.begin(), target_class_ids.end());
ORT_ENFORCE(max_ids < n_targets_or_classes, "At least one value (", max_ids,
Comment thread
xadupre marked this conversation as resolved.
") in target_ids or class_ids is greater than the number of targets or classes (", n_targets_or_classes, ").");
}

std::string aggregate_function;
Expand Down
86 changes: 85 additions & 1 deletion onnxruntime/test/providers/cpu/ml/treeregressor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ TEST(MLOpTest, TreeRegressorBranchEq) {
std::vector<int64_t> nodes_truenodeids = {1, -1, 4, -1, -1};

std::string post_transform = "NONE";
std::vector<int64_t> target_ids = {0, 1, 2};
std::vector<int64_t> target_ids = {0, 0, 0};
std::vector<int64_t> target_nodeids = {1, 3, 4};
std::vector<int64_t> target_treeids = {0, 0, 0};
std::vector<float> target_weights = {10.0, 20.0, 30.0};
Expand All @@ -885,5 +885,89 @@ TEST(MLOpTest, TreeRegressorBranchEq) {
test.Run();
}

TEST(MLOpTest, TreeRegressorNegativeTargetIds) {
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);

// tree
int64_t n_targets = 1;
std::vector<int64_t> nodes_featureids = {0, 0, 0, 0, 1, 0, 0};
std::vector<std::string> nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF"};
std::vector<float> nodes_values = {1, 3, 4, 0, 5.5, 0, 0};

std::vector<int64_t> nodes_treeids = {0, 0, 0, 0, 0, 0, 0};
std::vector<int64_t> nodes_nodeids = {0, 1, 2, 3, 4, 5, 6};
std::vector<int64_t> nodes_falsenodeids = {1, 2, 3, 0, 5, 0, 0};
std::vector<int64_t> nodes_truenodeids = {4, 4, 4, 0, 6, 0, 0};

std::string post_transform = "NONE";
std::vector<int64_t> target_ids = {0, 0, -1};
std::vector<int64_t> target_nodeids = {3, 5, 6};
std::vector<int64_t> target_treeids = {0, 0, 0};
std::vector<float> target_weights = {-4.699999809265137, 17.700000762939453, 11.100000381469727};

// add attributes
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
test.AddAttribute("nodes_treeids", nodes_treeids);
test.AddAttribute("nodes_nodeids", nodes_nodeids);
test.AddAttribute("nodes_featureids", nodes_featureids);
test.AddAttribute("nodes_values", nodes_values);
test.AddAttribute("nodes_modes", nodes_modes);
test.AddAttribute("target_treeids", target_treeids);
test.AddAttribute("target_nodeids", target_nodeids);
test.AddAttribute("target_ids", target_ids);
test.AddAttribute("target_weights", target_weights);
test.AddAttribute("n_targets", n_targets);

// fill input data
std::vector<float> X = {3.0f, 6.6f, 1.0f, 5.0f, 5.0f, 5.5f};
std::vector<float> Y = {17.700000762939453, 11.100000381469727, -4.699999809265137};
test.AddInput<float>("X", {3, 2}, X);
test.AddOutput<float>("Y", {3, 1}, Y);
test.Run(OpTester::ExpectResult::kExpectFailure, "target_ids or class_ids cannot have negative values");
}

TEST(MLOpTest, TreeRegressorOutsideBoundaryTargetIds) {
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);

// tree
int64_t n_targets = 1;
std::vector<int64_t> nodes_featureids = {0, 0, 0, 0, 1, 0, 0};
std::vector<std::string> nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF"};
std::vector<float> nodes_values = {1, 3, 4, 0, 5.5, 0, 0};

std::vector<int64_t> nodes_treeids = {0, 0, 0, 0, 0, 0, 0};
std::vector<int64_t> nodes_nodeids = {0, 1, 2, 3, 4, 5, 6};
std::vector<int64_t> nodes_falsenodeids = {1, 2, 3, 0, 5, 0, 0};
std::vector<int64_t> nodes_truenodeids = {4, 4, 4, 0, 6, 0, 0};

std::string post_transform = "NONE";
std::vector<int64_t> target_ids = {0, 0, 1};
std::vector<int64_t> target_nodeids = {3, 5, 6};
std::vector<int64_t> target_treeids = {0, 0, 0};
std::vector<float> target_weights = {-4.699999809265137, 17.700000762939453, 11.100000381469727};

// add attributes
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
test.AddAttribute("nodes_treeids", nodes_treeids);
test.AddAttribute("nodes_nodeids", nodes_nodeids);
test.AddAttribute("nodes_featureids", nodes_featureids);
test.AddAttribute("nodes_values", nodes_values);
test.AddAttribute("nodes_modes", nodes_modes);
test.AddAttribute("target_treeids", target_treeids);
test.AddAttribute("target_nodeids", target_nodeids);
test.AddAttribute("target_ids", target_ids);
test.AddAttribute("target_weights", target_weights);
test.AddAttribute("n_targets", n_targets);

// fill input data
std::vector<float> X = {3.0f, 6.6f, 1.0f, 5.0f, 5.0f, 5.5f};
std::vector<float> Y = {17.700000762939453, 11.100000381469727, -4.699999809265137};
test.AddInput<float>("X", {3, 2}, X);
test.AddOutput<float>("Y", {3, 1}, Y);
test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (1) in target_ids or class_ids is greater than the number of targets or classes");
}

} // namespace test
} // namespace onnxruntime
Loading