Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ struct TreeEnsembleAttributesV3 {
"target_class_weights and target_class_weights_as_tensor cannot both be non-empty");
ORT_ENFORCE(nodes_modes.size() < std::numeric_limits<uint32_t>::max(),
"nodes_modes size (", nodes_modes.size(), ") exceeds uint32_t max");

int64_t min_ids = *std::min_element(target_class_ids.begin(), target_class_ids.end());
ORT_ENFORCE(min_ids >= 0, "target_ids or class_ids cannot have negative values (", min_ids, ").");
int64_t max_ids = *std::max_element(target_class_ids.begin(), target_class_ids.end());
ORT_ENFORCE(max_ids < n_targets_or_classes, "At least one value (", max_ids,
Comment thread
xadupre marked this conversation as resolved.
") in target_ids or class_ids is greater or equal to the number of targets or classes (",
n_targets_or_classes, ").");
}

std::string aggregate_function;
Expand Down
86 changes: 85 additions & 1 deletion onnxruntime/test/providers/cpu/ml/treeregressor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ TEST(MLOpTest, TreeRegressorBranchEq) {
std::vector<int64_t> nodes_truenodeids = {1, -1, 4, -1, -1};

std::string post_transform = "NONE";
std::vector<int64_t> target_ids = {0, 1, 2};
std::vector<int64_t> target_ids = {0, 0, 0};
std::vector<int64_t> target_nodeids = {1, 3, 4};
std::vector<int64_t> target_treeids = {0, 0, 0};
std::vector<float> target_weights = {10.0, 20.0, 30.0};
Expand All @@ -885,5 +885,89 @@ TEST(MLOpTest, TreeRegressorBranchEq) {
test.Run();
}

TEST(MLOpTest, TreeRegressorNegativeTargetIds) {
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);

// tree
int64_t n_targets = 1;
std::vector<int64_t> nodes_featureids = {0, 0, 0, 0, 1, 0, 0};
std::vector<std::string> nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF"};
std::vector<float> nodes_values = {1, 3, 4, 0, 5.5, 0, 0};

std::vector<int64_t> nodes_treeids = {0, 0, 0, 0, 0, 0, 0};
std::vector<int64_t> nodes_nodeids = {0, 1, 2, 3, 4, 5, 6};
std::vector<int64_t> nodes_falsenodeids = {1, 2, 3, 0, 5, 0, 0};
std::vector<int64_t> nodes_truenodeids = {4, 4, 4, 0, 6, 0, 0};

std::string post_transform = "NONE";
std::vector<int64_t> target_ids = {0, 0, -1};
std::vector<int64_t> target_nodeids = {3, 5, 6};
std::vector<int64_t> target_treeids = {0, 0, 0};
std::vector<float> target_weights = {-4.699999809265137, 17.700000762939453, 11.100000381469727};

// add attributes
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
test.AddAttribute("nodes_treeids", nodes_treeids);
test.AddAttribute("nodes_nodeids", nodes_nodeids);
test.AddAttribute("nodes_featureids", nodes_featureids);
test.AddAttribute("nodes_values", nodes_values);
test.AddAttribute("nodes_modes", nodes_modes);
test.AddAttribute("target_treeids", target_treeids);
test.AddAttribute("target_nodeids", target_nodeids);
test.AddAttribute("target_ids", target_ids);
test.AddAttribute("target_weights", target_weights);
test.AddAttribute("n_targets", n_targets);

// fill input data
std::vector<float> X = {3.0f, 6.6f, 1.0f, 5.0f, 5.0f, 5.5f};
std::vector<float> Y = {17.700000762939453, 11.100000381469727, -4.699999809265137};
test.AddInput<float>("X", {3, 2}, X);
test.AddOutput<float>("Y", {3, 1}, Y);
test.Run(OpTester::ExpectResult::kExpectFailure, "target_ids or class_ids cannot have negative values");
}

TEST(MLOpTest, TreeRegressorOutsideBoundaryTargetIds) {
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);

// tree
int64_t n_targets = 1;
std::vector<int64_t> nodes_featureids = {0, 0, 0, 0, 1, 0, 0};
std::vector<std::string> nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF"};
std::vector<float> nodes_values = {1, 3, 4, 0, 5.5, 0, 0};

std::vector<int64_t> nodes_treeids = {0, 0, 0, 0, 0, 0, 0};
std::vector<int64_t> nodes_nodeids = {0, 1, 2, 3, 4, 5, 6};
std::vector<int64_t> nodes_falsenodeids = {1, 2, 3, 0, 5, 0, 0};
std::vector<int64_t> nodes_truenodeids = {4, 4, 4, 0, 6, 0, 0};

std::string post_transform = "NONE";
std::vector<int64_t> target_ids = {0, 0, 1};
std::vector<int64_t> target_nodeids = {3, 5, 6};
std::vector<int64_t> target_treeids = {0, 0, 0};
std::vector<float> target_weights = {-4.699999809265137, 17.700000762939453, 11.100000381469727};

// add attributes
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
test.AddAttribute("nodes_treeids", nodes_treeids);
test.AddAttribute("nodes_nodeids", nodes_nodeids);
test.AddAttribute("nodes_featureids", nodes_featureids);
test.AddAttribute("nodes_values", nodes_values);
test.AddAttribute("nodes_modes", nodes_modes);
test.AddAttribute("target_treeids", target_treeids);
test.AddAttribute("target_nodeids", target_nodeids);
test.AddAttribute("target_ids", target_ids);
test.AddAttribute("target_weights", target_weights);
test.AddAttribute("n_targets", n_targets);

// fill input data
std::vector<float> X = {3.0f, 6.6f, 1.0f, 5.0f, 5.0f, 5.5f};
std::vector<float> Y = {17.700000762939453, 11.100000381469727, -4.699999809265137};
test.AddInput<float>("X", {3, 2}, X);
test.AddOutput<float>("Y", {3, 1}, Y);
test.Run(OpTester::ExpectResult::kExpectFailure, "At least one value (1) in target_ids or class_ids is greater or equal to the number of targets or classes (1)");
}

} // namespace test
} // namespace onnxruntime
Loading