|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +#include "core/providers/cpu/ml/tree_ensemble.h" |
| 5 | +#include "core/providers/cpu/ml/tree_ensemble_helper.h" |
| 6 | +#include "core/common/inlined_containers_fwd.h" |
| 7 | + |
| 8 | +namespace onnxruntime { |
| 9 | +namespace ml { |
| 10 | + |
| 11 | +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( |
| 12 | + TreeEnsemble, |
| 13 | + 5, |
| 14 | + float, |
| 15 | + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()).MayInplace(0, 0), |
| 16 | + TreeEnsemble<float>); |
| 17 | + |
| 18 | +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( |
| 19 | + TreeEnsemble, |
| 20 | + 5, |
| 21 | + double, |
| 22 | + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<double>()).MayInplace(0, 0), |
| 23 | + TreeEnsemble<double>); |
| 24 | + |
| 25 | +template <typename T> |
| 26 | +TreeEnsemble<T>::TreeEnsemble(const OpKernelInfo& info) : OpKernel(info) { |
| 27 | + if constexpr (std::is_same<T, double>::value) { |
| 28 | + p_tree_ensemble_ = std::make_unique<detail::TreeEnsembleCommonV5<T, double>>(); |
| 29 | + } else { |
| 30 | + p_tree_ensemble_ = std::make_unique<detail::TreeEnsembleCommonV5<T, float>>(); |
| 31 | + } |
| 32 | + ORT_THROW_IF_ERROR(p_tree_ensemble_->Init(info)); |
| 33 | +} |
| 34 | + |
| 35 | +template <typename T> |
| 36 | +Status TreeEnsemble<T>::GetRemovableAttributes(InlinedVector<std::string>& removable_attributes) const { |
| 37 | + InlinedVector<std::string> names{ |
| 38 | + "leaf_targetids", "leaf_weights", "membership_values", "nodes_falseleafs", |
| 39 | + "nodes_falsenodeids", "nodes_featureids", "nodes_hitrates", "nodes_missing_value_tracks_true", |
| 40 | + "nodes_modes", "nodes_splits", "nodes_trueleafs", "nodes_truenodeids"}; |
| 41 | + removable_attributes.swap(names); |
| 42 | + return Status::OK(); |
| 43 | +} |
| 44 | + |
| 45 | +template <typename T> |
| 46 | +common::Status TreeEnsemble<T>::Compute(OpKernelContext* context) const { |
| 47 | + const auto* X = context->Input<Tensor>(0); |
| 48 | + if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); |
| 49 | + if (X->Shape().NumDimensions() == 0) { |
| 50 | + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, |
| 51 | + "Input shape needs to be at least a single dimension."); |
| 52 | + } |
| 53 | + int64_t N = X->Shape().NumDimensions() == 1 ? 1 : X->Shape()[0]; |
| 54 | + Tensor* Y = context->Output(0, {N, p_tree_ensemble_->get_target_or_class_count()}); |
| 55 | + return p_tree_ensemble_->compute(context, X, Y, NULL); |
| 56 | +} |
| 57 | + |
| 58 | +} // namespace ml |
| 59 | +} // namespace onnxruntime |
0 commit comments