Skip to content

Commit a2ba3cb

Browse files
xaduprebili2002cbourjau
authored
Implementation of TreeEnsemble ai.onnx.ml==5 (#22333)
### Description

Merges PRs #21851 and #21222. Implements the TreeEnsemble operator from ai.onnx.ml==5 (CPU).

---------

Co-authored-by: Bilyana Indzheva <[email protected]>
Co-authored-by: Bilyana Indzheva <[email protected]>
Co-authored-by: Christian Bourjau <[email protected]>
1 parent c97dd6e commit a2ba3cb

13 files changed

+1155
-349
lines changed

docs/OperatorKernels.md

+1
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,7 @@ Do not modify directly.*
453453
|SVMClassifier|*in* X:**T1**<br> *out* Y:**T2**<br> *out* Z:**tensor(float)**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int64), tensor(string)|
454454
|SVMRegressor|*in* X:**T**<br> *out* Y:**tensor(float)**|1+|**T** = tensor(float)|
455455
|Scaler|*in* X:**T**<br> *out* Y:**tensor(float)**|1+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
456+
|TreeEnsemble|*in* X:**T**<br> *out* Y:**T**|5+|**T** = tensor(double), tensor(float)|
456457
|TreeEnsembleClassifier|*in* X:**T1**<br> *out* Y:**T2**<br> *out* Z:**tensor(float)**|3+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int64), tensor(string)|
457458
|||[1, 2]|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int64), tensor(string)|
458459
|TreeEnsembleRegressor|*in* X:**T**<br> *out* Y:**tensor(float)**|3+|**T** = tensor(double), tensor(float)|

onnxruntime/core/providers/cpu/cpu_execution_provider.cc

+6
Original file line numberDiff line numberDiff line change
@@ -2925,6 +2925,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3,
29252925
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, int32_t, TreeEnsembleClassifier);
29262926
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, float, TreeEnsembleRegressor);
29272927
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double, TreeEnsembleRegressor);
2928+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, float, TreeEnsemble);
2929+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, double, TreeEnsemble);
29282930

29292931
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, float_string, LabelEncoder);
29302932
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, string_float, LabelEncoder);
@@ -3043,6 +3045,10 @@ Status RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) {
30433045
TreeEnsembleRegressor)>,
30443046
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double,
30453047
TreeEnsembleRegressor)>,
3048+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, float,
3049+
TreeEnsemble)>,
3050+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, double,
3051+
TreeEnsemble)>,
30463052

30473053
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, float_string,
30483054
LabelEncoder)>,

onnxruntime/core/providers/cpu/ml/ml_common.h

+31-27
Original file line numberDiff line numberDiff line change
@@ -20,44 +20,48 @@ enum class OUTPUT_MODE {
2020
ALL_SCORES
2121
};
2222

23-
enum NODE_MODE : uint8_t {
24-
LEAF = 1,
25-
BRANCH_LEQ = 2,
26-
BRANCH_LT = 4,
27-
BRANCH_GTE = 6,
28-
BRANCH_GT = 8,
29-
BRANCH_EQ = 10,
30-
BRANCH_NEQ = 12
23+
enum NODE_MODE_ONNX : uint8_t {
24+
BRANCH_LEQ = 0,
25+
BRANCH_LT = 1,
26+
BRANCH_GTE = 2,
27+
BRANCH_GT = 3,
28+
BRANCH_EQ = 4,
29+
BRANCH_NEQ = 5,
30+
BRANCH_MEMBER = 6,
31+
LEAF = 7,
3132
};
3233

33-
static inline NODE_MODE MakeTreeNodeMode(const std::string& input) {
34+
static inline NODE_MODE_ONNX MakeTreeNodeMode(const std::string& input) {
3435
if (input == "BRANCH_LEQ") {
35-
return NODE_MODE::BRANCH_LEQ;
36+
return NODE_MODE_ONNX::BRANCH_LEQ;
3637
}
3738
if (input == "LEAF") {
38-
return NODE_MODE::LEAF;
39+
return NODE_MODE_ONNX::LEAF;
3940
}
4041
if (input == "BRANCH_LT") {
41-
return NODE_MODE::BRANCH_LT;
42+
return NODE_MODE_ONNX::BRANCH_LT;
4243
}
4344
if (input == "BRANCH_GTE") {
44-
return NODE_MODE::BRANCH_GTE;
45+
return NODE_MODE_ONNX::BRANCH_GTE;
4546
}
4647
if (input == "BRANCH_GT") {
47-
return NODE_MODE::BRANCH_GT;
48+
return NODE_MODE_ONNX::BRANCH_GT;
4849
}
4950
if (input == "BRANCH_EQ") {
50-
return NODE_MODE::BRANCH_EQ;
51+
return NODE_MODE_ONNX::BRANCH_EQ;
5152
}
52-
return NODE_MODE::BRANCH_NEQ;
53+
if (input == "BRANCH_MEMBER") {
54+
return NODE_MODE_ONNX::BRANCH_MEMBER;
55+
}
56+
return NODE_MODE_ONNX::BRANCH_NEQ;
5357
}
5458

55-
enum class POST_EVAL_TRANSFORM {
56-
NONE,
57-
LOGISTIC,
58-
SOFTMAX,
59-
SOFTMAX_ZERO,
60-
PROBIT
59+
enum class POST_EVAL_TRANSFORM : int64_t {
60+
NONE = 0,
61+
LOGISTIC = 1,
62+
SOFTMAX = 2,
63+
SOFTMAX_ZERO = 3,
64+
PROBIT = 4
6165
};
6266

6367
static inline POST_EVAL_TRANSFORM MakeTransform(const std::string& input) {
@@ -76,11 +80,11 @@ static inline POST_EVAL_TRANSFORM MakeTransform(const std::string& input) {
7680
return POST_EVAL_TRANSFORM::PROBIT;
7781
}
7882

79-
enum class AGGREGATE_FUNCTION {
80-
AVERAGE,
81-
SUM,
82-
MIN,
83-
MAX
83+
enum class AGGREGATE_FUNCTION : int64_t {
84+
AVERAGE = 0,
85+
SUM = 1,
86+
MIN = 2,
87+
MAX = 3
8488
};
8589

8690
static inline AGGREGATE_FUNCTION MakeAggregateFunction(const std::string& input) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/cpu/ml/tree_ensemble.h"
5+
#include "core/providers/cpu/ml/tree_ensemble_helper.h"
6+
#include "core/common/inlined_containers_fwd.h"
7+
8+
namespace onnxruntime {
9+
namespace ml {
10+
11+
// CPU kernel registrations for TreeEnsemble, version 5, one per element type (float, double).
// Each registration constrains "T" to a single tensor type and declares MayInplace(0, 0).
ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
    TreeEnsemble,
    5,
    float,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .MayInplace(0, 0),
    TreeEnsemble<float>);

ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
    TreeEnsemble,
    5,
    double,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::GetTensorType<double>())
        .MayInplace(0, 0),
    TreeEnsemble<double>);
24+
25+
template <typename T>
26+
TreeEnsemble<T>::TreeEnsemble(const OpKernelInfo& info) : OpKernel(info) {
27+
if constexpr (std::is_same<T, double>::value) {
28+
p_tree_ensemble_ = std::make_unique<detail::TreeEnsembleCommonV5<T, double>>();
29+
} else {
30+
p_tree_ensemble_ = std::make_unique<detail::TreeEnsembleCommonV5<T, float>>();
31+
}
32+
ORT_THROW_IF_ERROR(p_tree_ensemble_->Init(info));
33+
}
34+
35+
template <typename T>
36+
Status TreeEnsemble<T>::GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const {
37+
InlinedVector<std::string> names{
38+
"leaf_targetids", "leaf_weights", "membership_values", "nodes_falseleafs",
39+
"nodes_falsenodeids", "nodes_featureids", "nodes_hitrates", "nodes_missing_value_tracks_true",
40+
"nodes_modes", "nodes_splits", "nodes_trueleafs", "nodes_truenodeids"};
41+
removable_attributes.swap(names);
42+
return Status::OK();
43+
}
44+
45+
template <typename T>
46+
common::Status TreeEnsemble<T>::Compute(OpKernelContext* context) const {
47+
const auto* X = context->Input<Tensor>(0);
48+
if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
49+
if (X->Shape().NumDimensions() == 0) {
50+
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
51+
"Input shape needs to be at least a single dimension.");
52+
}
53+
int64_t N = X->Shape().NumDimensions() == 1 ? 1 : X->Shape()[0];
54+
Tensor* Y = context->Output(0, {N, p_tree_ensemble_->get_target_or_class_count()});
55+
return p_tree_ensemble_->compute(context, X, Y, NULL);
56+
}
57+
58+
} // namespace ml
59+
} // namespace onnxruntime
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include "tree_ensemble_common.h"
6+
7+
namespace onnxruntime {
8+
namespace ml {
9+
template <typename T>
10+
class TreeEnsemble final : public OpKernel {
11+
typedef T InputType; // input type
12+
typedef float OutputType; // output type
13+
public:
14+
explicit TreeEnsemble(const OpKernelInfo& info);
15+
common::Status Compute(OpKernelContext* context) const override;
16+
Status GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const override;
17+
18+
private:
19+
// Pointer on one instance of
20+
// detail::TreeEnsembleCommonV5<T, ThresholdType>
21+
// where ThresholdType is defined after accessing the attributes.
22+
std::unique_ptr<detail::TreeEnsembleCommonAttributes> p_tree_ensemble_;
23+
};
24+
} // namespace ml
25+
} // namespace onnxruntime

onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h

+37-3
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,40 @@ union PtrOrWeight {
7878
} weight_data;
7979
};
8080

81+
// Internal node-mode encoding stored in TreeNodeElement::flags.
// LEAF is the only value with bit 0 set, and every value fits in the low four bits.
enum NODE_MODE_ORT : uint8_t {
  LEAF = 0x1,
  BRANCH_LEQ = 0x2,
  BRANCH_LT = 0x4,
  BRANCH_GTE = 0x6,
  BRANCH_GT = 0x8,
  BRANCH_EQ = 0xA,
  BRANCH_NEQ = 0xC,
  BRANCH_MEMBER = 0xE,
};
91+
92+
// Maps a node mode in the ONNX numbering (NODE_MODE_ONNX) to the enumerator of the same
// name in the internal numbering (NODE_MODE_ORT).
// Any value that is not a named NODE_MODE_ONNX enumerator is reported through ORT_THROW.
inline NODE_MODE_ORT Convert_NODE_MODE_ONNX_to_ORT(NODE_MODE_ONNX node_mode) {
  switch (node_mode) {
    case NODE_MODE_ONNX::BRANCH_LEQ:
      return NODE_MODE_ORT::BRANCH_LEQ;
    case NODE_MODE_ONNX::BRANCH_LT:
      return NODE_MODE_ORT::BRANCH_LT;
    case NODE_MODE_ONNX::BRANCH_GTE:
      return NODE_MODE_ORT::BRANCH_GTE;
    case NODE_MODE_ONNX::BRANCH_GT:
      return NODE_MODE_ORT::BRANCH_GT;
    case NODE_MODE_ONNX::BRANCH_EQ:
      return NODE_MODE_ORT::BRANCH_EQ;
    case NODE_MODE_ONNX::BRANCH_NEQ:
      return NODE_MODE_ORT::BRANCH_NEQ;
    case NODE_MODE_ONNX::BRANCH_MEMBER:
      return NODE_MODE_ORT::BRANCH_MEMBER;
    case NODE_MODE_ONNX::LEAF:
      return NODE_MODE_ORT::LEAF;
    default:
      ORT_THROW("Unexpected value for node_mode");
  }
}
114+
81115
template <typename T>
82116
struct TreeNodeElement {
83117
int feature_id;
@@ -98,10 +132,10 @@ struct TreeNodeElement {
98132
// weight in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, the weight is also
99133
// stored in `value_or_unique_weight`.
100134
PtrOrWeight<T> truenode_or_weight;
101-
uint8_t flags;
135+
NODE_MODE_ORT flags;
102136

103-
inline NODE_MODE mode() const { return NODE_MODE(flags & 0xF); }
104-
inline bool is_not_leaf() const { return !(flags & NODE_MODE::LEAF); }
137+
inline NODE_MODE_ORT mode() const { return NODE_MODE_ORT(flags & 0xF); }
138+
inline bool is_not_leaf() const { return !(flags & NODE_MODE_ORT::LEAF); }
105139
inline bool is_missing_track_true() const { return flags & MissingTrack::kTrue; }
106140
};
107141

0 commit comments

Comments
 (0)