Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 42 additions & 17 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,49 @@
target_class_nodeids = info.GetAttrsOrDefault<int64_t>("target_nodeids");
target_class_treeids = info.GetAttrsOrDefault<int64_t>("target_treeids");
target_class_weights = info.GetAttrsOrDefault<float>("target_weights");

ORT_ENFORCE(n_targets_or_classes > 0);
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_featureids.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_modes_string.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_nodeids.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_treeids.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_truenodeids.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_values.size() ||
nodes_falsenodeids.size() == nodes_values_as_tensor.size());
ORT_ENFORCE(target_class_ids.size() == target_class_nodeids.size());
ORT_ENFORCE(target_class_ids.size() == target_class_treeids.size());
ORT_ENFORCE(target_class_weights.empty() || target_class_ids.size() == target_class_weights.size());
ORT_ENFORCE(base_values.empty() || base_values_as_tensor.empty());
ORT_ENFORCE(nodes_hitrates.empty() || nodes_hitrates_as_tensor.empty());
ORT_ENFORCE(nodes_values.empty() || nodes_values_as_tensor.empty());
ORT_ENFORCE(target_class_weights.empty() || target_class_weights_as_tensor.empty());
ORT_ENFORCE(nodes_modes_string.size() < std::numeric_limits<uint32_t>::max());
}

ORT_ENFORCE(n_targets_or_classes > 0,
"n_targets_or_classes must be positive, got ", n_targets_or_classes);
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_featureids.size(),
"nodes_falsenodeids and nodes_featureids must have the same size, got ",
nodes_falsenodeids.size(), " and ", nodes_featureids.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_modes.size(),
"nodes_falsenodeids and nodes_modes must have the same size, got ",
nodes_falsenodeids.size(), " and ", nodes_modes.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_nodeids.size(),
"nodes_falsenodeids and nodes_nodeids must have the same size, got ",
nodes_falsenodeids.size(), " and ", nodes_nodeids.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_treeids.size(),
"nodes_falsenodeids and nodes_treeids must have the same size, got ",
nodes_falsenodeids.size(), " and ", nodes_treeids.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_truenodeids.size(),
"nodes_falsenodeids and nodes_truenodeids must have the same size, got ",
nodes_falsenodeids.size(), " and ", nodes_truenodeids.size());
ORT_ENFORCE(nodes_falsenodeids.size() == nodes_values.size() ||
nodes_falsenodeids.size() == nodes_values_as_tensor.size(),
Comment thread
vraspar marked this conversation as resolved.
Outdated
"nodes_falsenodeids size (", nodes_falsenodeids.size(),
") must match nodes_values (", nodes_values.size(),
") or nodes_values_as_tensor (", nodes_values_as_tensor.size(), ")");
ORT_ENFORCE(target_class_ids.size() == target_class_nodeids.size(),
"target_class_ids and target_class_nodeids must have the same size, got ",
target_class_ids.size(), " and ", target_class_nodeids.size());
ORT_ENFORCE(target_class_ids.size() == target_class_treeids.size(),
"target_class_ids and target_class_treeids must have the same size, got ",
target_class_ids.size(), " and ", target_class_treeids.size());
ORT_ENFORCE(target_class_weights.empty() || target_class_ids.size() == target_class_weights.size(),
"target_class_weights must be empty or match target_class_ids size, got ",
target_class_weights.size(), " and ", target_class_ids.size());
ORT_ENFORCE(base_values.empty() || base_values_as_tensor.empty(),
"base_values and base_values_as_tensor cannot both be non-empty");
ORT_ENFORCE(nodes_hitrates.empty() || nodes_hitrates_as_tensor.empty(),
"nodes_hitrates and nodes_hitrates_as_tensor cannot both be non-empty");
ORT_ENFORCE(nodes_values.empty() || nodes_values_as_tensor.empty(),
"nodes_values and nodes_values_as_tensor cannot both be non-empty");
ORT_ENFORCE(target_class_weights.empty() || target_class_weights_as_tensor.empty(),
"target_class_weights and target_class_weights_as_tensor cannot both be non-empty");
ORT_ENFORCE(nodes_modes.size() < std::numeric_limits<uint32_t>::max(),

Check warning on line 113 in onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <limits> for numeric_limits<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h:113: Add #include <limits> for numeric_limits<> [build/include_what_you_use] [4]
"nodes_modes size (", nodes_modes.size(), ") exceeds uint32_t max");
}

std::string aggregate_function;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,5 +367,87 @@ TEST(MLOpTest, TreeEnsembleClassifierBinaryProbabilities) {
test.Run();
}

TEST(MLOpTest, TreeEnsembleClassifierMismatchedClassArrays) {
OpTester test("TreeEnsembleClassifier", 1, onnxruntime::kMLDomain);

// Use a minimal valid tree: one root (BRANCH_LEQ) with two leaf children.
std::vector<int64_t> lefts = {1, -1, -1};
std::vector<int64_t> rights = {2, -1, -1};
std::vector<int64_t> treeids = {0, 0, 0};
std::vector<int64_t> nodeids = {0, 1, 2};
std::vector<int64_t> featureids = {0, -2, -2};
std::vector<float> thresholds = {0.5f, -2.f, -2.f};
std::vector<std::string> modes = {"BRANCH_LEQ", "LEAF", "LEAF"};
Comment thread
vraspar marked this conversation as resolved.

// Intentionally mismatched: class_nodeids has fewer elements than class_ids.
std::vector<int64_t> class_treeids = {0, 0};
std::vector<int64_t> class_nodeids = {1}; // only 1 element — mismatch!
std::vector<int64_t> class_classids = {0, 1};
std::vector<float> class_weights = {1.f, 1.f};
std::vector<int64_t> classes = {0, 1};

test.AddAttribute("nodes_truenodeids", lefts);
test.AddAttribute("nodes_falsenodeids", rights);
test.AddAttribute("nodes_treeids", treeids);
test.AddAttribute("nodes_nodeids", nodeids);
test.AddAttribute("nodes_featureids", featureids);
test.AddAttribute("nodes_values", thresholds);
test.AddAttribute("nodes_modes", modes);
test.AddAttribute("class_treeids", class_treeids);
test.AddAttribute("class_nodeids", class_nodeids);
test.AddAttribute("class_ids", class_classids);
test.AddAttribute("class_weights", class_weights);
test.AddAttribute("classlabels_int64s", classes);

std::vector<float> X = {1.f};
test.AddInput<float>("X", {1, 1}, X);
test.AddOutput<int64_t>("Y", {1}, {0});
test.AddOutput<float>("Z", {1, 2}, {0.f, 0.f});

test.Run(OpTester::ExpectResult::kExpectFailure,
"target_class_ids and target_class_nodeids must have the same size");
}


Comment thread
vraspar marked this conversation as resolved.
Outdated
TEST(MLOpTest, TreeEnsembleClassifierMismatchedNodeArrays) {
OpTester test("TreeEnsembleClassifier", 1, onnxruntime::kMLDomain);

// nodes_falsenodeids has 3 elements, but nodes_featureids has only 2 — mismatch!
std::vector<int64_t> lefts = {1, -1, -1};
std::vector<int64_t> rights = {2, -1, -1};
std::vector<int64_t> treeids = {0, 0, 0};
std::vector<int64_t> nodeids = {0, 1, 2};
std::vector<int64_t> featureids = {0, -2}; // only 2 elements — mismatch!
std::vector<float> thresholds = {0.5f, -2.f, -2.f};
std::vector<std::string> modes = {"BRANCH_LEQ", "LEAF", "LEAF"};

std::vector<int64_t> class_treeids = {0, 0};
std::vector<int64_t> class_nodeids = {1, 2};
std::vector<int64_t> class_classids = {0, 1};
std::vector<float> class_weights = {1.f, 1.f};
std::vector<int64_t> classes = {0, 1};

test.AddAttribute("nodes_truenodeids", lefts);
test.AddAttribute("nodes_falsenodeids", rights);
test.AddAttribute("nodes_treeids", treeids);
test.AddAttribute("nodes_nodeids", nodeids);
test.AddAttribute("nodes_featureids", featureids);
test.AddAttribute("nodes_values", thresholds);
test.AddAttribute("nodes_modes", modes);
test.AddAttribute("class_treeids", class_treeids);
test.AddAttribute("class_nodeids", class_nodeids);
test.AddAttribute("class_ids", class_classids);
test.AddAttribute("class_weights", class_weights);
test.AddAttribute("classlabels_int64s", classes);

std::vector<float> X = {1.f};
test.AddInput<float>("X", {1, 1}, X);
test.AddOutput<int64_t>("Y", {1}, {0});
test.AddOutput<float>("Z", {1, 2}, {0.f, 0.f});

test.Run(OpTester::ExpectResult::kExpectFailure,
"nodes_falsenodeids and nodes_featureids must have the same size");
}

} // namespace test
} // namespace onnxruntime
Loading