Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of Set Membership in TreeEnsemble #21222

Closed
wants to merge 11 commits into from
Prev Previous commit
Next Next commit
Fix windows error
bili2002 committed Sep 18, 2024
commit 1b15a1d529ae9e3c483dfdf5c310f0d7287aeea9
37 changes: 22 additions & 15 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
Original file line number Diff line number Diff line change
@@ -119,6 +119,17 @@ std::enable_if_t<
return dst;
}

// Reinterprets the bits of `val` as an unsigned integer of the same width.
// Only 4-byte and 8-byte types are accepted: the former yield uint32_t, the latter uint64_t.
// Any other size is rejected at compile time.
template <typename T>
std::conditional_t<sizeof(T) == sizeof(uint32_t), uint32_t, uint64_t> bit_cast_int(T val) {
  constexpr bool is_32_bits = sizeof(T) == sizeof(uint32_t);
  constexpr bool is_64_bits = sizeof(T) == sizeof(uint64_t);
  static_assert(is_32_bits || is_64_bits);

  // Same selection as the declared return type, so a single bit_cast call covers both widths.
  using SameWidthInt = std::conditional_t<is_32_bits, uint32_t, uint64_t>;
  return bit_cast<SameWidthInt>(val);
}

template <typename InputType, typename ThresholdType, typename OutputType>
Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(const OpKernelInfo& info) {
std::vector<ThresholdType> base_values_as_tensor, nodes_hitrates_as_tensor,
@@ -376,11 +387,8 @@ bool TreeEnsembleCommon<InputType, ThresholdType, OutputType>::CheckIfSubtreesAr
}

if (cmodes[left_id] == NODE_MODE::LEAF) {
const auto left_tree_node = node_tree_ids[left_id];
const auto left_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(left_tree_node, uint32_t(0)))->second;

const auto right_tree_node = node_tree_ids[right_id];
const auto right_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(right_tree_node, uint32_t(0)))->second;
const auto left_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[left_id], uint32_t(0)))->second;
const auto right_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[right_id], uint32_t(0)))->second;

if (target_class_weights_as_tensor.empty()) {
return target_class_weights[left_target_node] == target_class_weights[right_target_node];
@@ -439,7 +447,7 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
}

node.value_or_unique_weight = 0;
const auto node_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];
const ThresholdType node_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];
if (node.flags == NODE_MODE::BRANCH_EQ && CANMASK(node_threshold, ThresholdType)) {
UpdateThreshold(node_threshold, node.value_or_unique_weight);
node.flags = NODE_MODE::BRANCH_MEMBER;
@@ -452,7 +460,7 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
}
nodes_.push_back(std::move(node));
if (nodes_[node_pos].is_not_leaf()) {
auto falsenode_id = falsenode_ids[i];
size_t falsenode_id = falsenode_ids[i];

// Categoricals are represented as a chain of `EQ` nodes where the subtree for the true child is identical for all nodes in the chain
// Below we are folding together these nodes into one of mode `BRANCH_MEMBER`
@@ -461,7 +469,7 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
// and the one of the feature (the mask has only one bit set on the place for its value)
// Beware that if a category is bigger than the threshold type, the node stays as `EQ` and no combination is done
if (nodes_[node_pos].flags == NODE_MODE::BRANCH_MEMBER) {
auto falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];
ThresholdType falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];

while (cmodes[falsenode_id] == NODE_MODE::BRANCH_EQ && nodes_[node_pos].feature_id == nodes_featureids[falsenode_id] &&
CANMASK(falsenode_threshold, ThresholdType) &&
@@ -783,16 +791,15 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
} \
}



// Check whether the feature value is set true in the mask
inline bool SetMembershipCheck(double val, double mask) {
const auto val_as_int = static_cast<int64_t>(val);
return CANMASK(val_as_int, double) && (((1ll << (val_as_int - 1)) & bit_cast<uint64_t>(mask)) != 0);
// Check whether the bit for category `val` is set in `mask`.
//
// The bits of `mask` are read through bit_cast_int and treated as a bitset in which
// category `v` is stored at bit `v - 1`. Returns false whenever CANMASK(val, T2) rejects `val`.
// NOTE(review): CANMASK is defined outside this view - presumably it checks that `val` is an
// integer in [1, number of bits of T2]; confirm against its definition.
template <typename T1, typename T2>
inline bool SetMembershipCheck(T1 val, T2 mask) {
  // Validate before converting: casting a NaN or out-of-range floating-point value to int64_t
  // is undefined behavior, so the conversion only happens once `val` passed the CANMASK guard.
  if (!(CANMASK(val, T2))) {
    return false;
  }
  const int64_t bit_index = static_cast<int64_t>(val) - 1;
  // Unsigned one, so selecting the top bit of a 64-bit mask never shifts into a sign bit.
  return ((uint64_t{1} << bit_index) & bit_cast_int(mask)) != 0;
}

// Single-precision overload: tests whether the bit for category `val` is set in the 32 bits of `mask`.
// `val` is truncated to an integer first; CANMASK is evaluated on that truncated value.
inline bool SetMembershipCheck(float val, float mask) {
  const auto category = static_cast<int64_t>(val);
  if (!(CANMASK(category, float))) {
    return false;
  }
  // Category `c` is stored at bit `c - 1` of the mask.
  const auto category_bit = 1ll << (category - 1);
  return (category_bit & bit_cast<uint32_t>(mask)) != 0;
}

// NaN test for single-precision values; forwards to std::isnan.
inline bool _isnan_(float x) {
  const bool is_nan = std::isnan(x);
  return is_nan;
}

// NaN test for double-precision values; forwards to std::isnan.
inline bool _isnan_(double x) {
  const bool is_nan = std::isnan(x);
  return is_nan;
}