From ab12ca9c54cc06c1a8a64c8928a9c3b110112b80 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Wed, 11 Dec 2024 19:10:55 -0500 Subject: [PATCH] FEA Implement pruning using honest subsample data to fit the leaves (#286) * Honest pruning in honest tree classifier * TST turn off sklearn multi_output tag assignment --------- Signed-off-by: Adam Li Co-authored-by: Haoyin Xu --- doc/whats_new/v0.10.rst | 4 + examples/calibration/plot_honest_tree.py | 117 +++++ .../calibration/plot_overlapping_gaussians.py | 2 + treeple/__init__.py | 2 +- treeple/ensemble/_honest_forest.py | 17 +- treeple/experimental/tests/test_sdf.py | 3 + treeple/meson.build | 2 +- treeple/neighbors.py | 2 +- treeple/stats/forest.py | 2 +- treeple/stats/permuteforest.py | 53 ++- treeple/tests/test_honest_forest.py | 3 + treeple/tests/test_multiview_forest.py | 7 + treeple/tests/test_supervised_forest.py | 12 +- treeple/tests/test_unsupervised_forest.py | 8 +- treeple/tree/_honest_tree.py | 126 ++++- treeple/tree/honesty/__init__.py | 0 treeple/tree/honesty/_honest_prune.pxd | 62 +++ treeple/tree/honesty/_honest_prune.pyx | 429 ++++++++++++++++++ treeple/tree/honesty/meson.build | 23 + treeple/tree/meson.build | 2 +- treeple/tree/tests/meson.build | 1 + treeple/tree/tests/test_honest_prune.py | 72 +++ treeple/tree/tests/test_honest_tree.py | 13 +- treeple/tree/tests/test_multiview.py | 5 + treeple/tree/tests/test_tree.py | 7 + treeple/tree/tests/test_unsupervised_tree.py | 4 +- 26 files changed, 931 insertions(+), 47 deletions(-) create mode 100644 examples/calibration/plot_honest_tree.py create mode 100644 treeple/tree/honesty/__init__.py create mode 100644 treeple/tree/honesty/_honest_prune.pxd create mode 100644 treeple/tree/honesty/_honest_prune.pyx create mode 100644 treeple/tree/honesty/meson.build create mode 100644 treeple/tree/tests/test_honest_prune.py diff --git a/doc/whats_new/v0.10.rst b/doc/whats_new/v0.10.rst index 8e56f544f..65db76f41 100644 --- a/doc/whats_new/v0.10.rst +++ b/doc/whats_new/v0.10.rst @@ -17,6 +17,10 @@ Changelog ``bottleneck`` library for faster computation. By `Ryan Hausen`_ (:pr:`#306`) - |Feature| Added a sparse implementation of `treeple.stats.forest.build_colemen_forest` that uses the `scipy.sparse` module. By `Ryan Hausen`_ (:pr:`#317`) +- |Feature| :class:`treeple.tree.HonestTreeClassifier` now has a ``honest_method`` parameter + that enables the user to turn on pruning of the tree, such that there are no + empty leaf predictions. This brings the model closer to the implementation in GRF in R. + By `Adam Li`_ (:pr:`#286`) Code and Documentation Contributors diff --git a/examples/calibration/plot_honest_tree.py b/examples/calibration/plot_honest_tree.py new file mode 100644 index 000000000..963946ff1 --- /dev/null +++ b/examples/calibration/plot_honest_tree.py @@ -0,0 +1,117 @@ +""" +=========================================== +Comparison of Decision Tree and Honest Tree +=========================================== + +This example compares the :class:`treeple.tree.HonestTreeClassifier` from the +``treeple`` library with the :class:`sklearn.tree.DecisionTreeClassifier` +from scikit-learn on the Iris dataset. + +Both classifiers are fitted on the same dataset and their decision trees +are plotted side by side. 
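+
+Four trees are fit: a "dishonest" tree (``honest_method=None``), which behaves
+like a decision tree trained only on the structure subsample; an honest tree
+without pruning (``honest_method="apply"``), which keeps the learned structure
+and only refits the leaves, so empty leaves may remain; an honest tree with
+pruning (``honest_method="prune"``), which prunes the structure so that no
+leaf is left empty; and a plain scikit-learn decision tree for reference.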
+""" + +import matplotlib.pyplot as plt +from sklearn import config_context +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from sklearn.tree import DecisionTreeClassifier, plot_tree + +from treeple.tree import HonestTreeClassifier + +# Load the iris dataset +iris = load_iris() +X, y = iris.data, iris.target +X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.1, random_state=0) + +# Initialize classifiers +max_features = 0.3 + +dishonest_clf = HonestTreeClassifier( + honest_method=None, + max_features=max_features, + random_state=0, + honest_prior="ignore", +) +honest_noprune_clf = HonestTreeClassifier( + honest_method="apply", + max_features=max_features, + random_state=0, + honest_prior="ignore", +) +honest_clf = HonestTreeClassifier(honest_method="prune", max_features=max_features, random_state=0) +sklearn_clf = DecisionTreeClassifier(max_features=max_features, random_state=0) + +# Fit classifiers +dishonest_clf.fit(X_train, y_train) +honest_noprune_clf.fit(X_train, y_train) +honest_clf.fit(X_train, y_train) +sklearn_clf.fit(X_train, y_train) + +# Plotting the trees +fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(15, 5)) + +# .. note:: We skip parameter validation because internally the `plot_tree` +# function checks if the estimator is a DecisionTreeClassifier +# instance from scikit-learn, but the ``HonestTreeClassifier`` is +# a subclass of a forked version of the DecisionTreeClassifier. + +# Plot HonestTreeClassifier tree +ax = axes[2] +with config_context(skip_parameter_validation=True): + plot_tree(honest_clf, filled=True, ax=ax) +ax.set_title("HonestTreeClassifier") + +# Plot HonestTreeClassifier tree +ax = axes[1] +with config_context(skip_parameter_validation=True): + plot_tree(honest_noprune_clf, filled=False, ax=ax) +ax.set_title("HonestTreeClassifier (No pruning)") + +# Plot HonestTreeClassifier tree +ax = axes[0] +with config_context(skip_parameter_validation=True): + plot_tree(dishonest_clf, filled=False, ax=ax) +ax.set_title("HonestTreeClassifier (Dishonest)") + + +# Plot scikit-learn DecisionTreeClassifier tree +plot_tree(sklearn_clf, filled=True, ax=axes[3]) +axes[3].set_title("DecisionTreeClassifier") + +plt.show() + +# %% +# Discussion +# ---------- +# The HonestTreeClassifier is a variant of the DecisionTreeClassifier that +# provides honest inference. The honest inference is achieved by splitting the +# dataset into two parts: the training set and the validation set. The training +# set is used to build the tree, while the validation set is used to fit the +# leaf nodes for posterior prediction. This results in calibrated posteriors +# (see :ref:`sphx_glr_auto_examples_calibration_plot_overlapping_gaussians.py`). +# +# Compared to the ``honest_prior='apply'`` method, the ``honest_prior='prune'`` +# method builds a tree that will not contain empty leaves, and also leverages +# the validation set to check split conditions. Thus we see that the pruned +# honest tree is significantly smaller than the regular decision tree. + +# %% +# Evaluate predictions of the trees +# --------------------------------- +# When we do not prune, note that the honest tree will have empty leaves +# that predict the prior. In this case, ``honest_prior='ignore'`` is used +# to ignore these leaves when computing the posteriors, which will result +# in a posterior that is ``np.nan``. 
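+
+# %%
+# As a quick sanity check, we can count how many validation samples land in an
+# empty honest leaf by looking for ``np.nan`` rows in the posteriors of the
+# unpruned honest tree, and confirm that the pruned honest tree produces none.
+# This is only an illustrative check; it relies on nothing beyond
+# ``predict_proba`` and NumPy (imported here since the example above does not
+# use it directly).
+import numpy as np
+
+noprune_proba = honest_noprune_clf.predict_proba(X_val)
+pruned_proba = honest_clf.predict_proba(X_val)
+n_nan_noprune = int(np.isnan(noprune_proba).any(axis=1).sum())
+n_nan_pruned = int(np.isnan(pruned_proba).any(axis=1).sum())
+print("Empty-leaf predictions without pruning:", n_nan_noprune)
+print("Empty-leaf predictions with pruning:", n_nan_pruned)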
+ +# this is the same as a decision tree classifier that is trained on less data +print("\nDishonest posteriors: ", dishonest_clf.predict_proba(X_val)) + +# this is the honest tree with empty leaves that predict the prior +print("\nHonest tree without pruning: ", honest_noprune_clf.predict_proba(X_val)) + +# this is the honest tree that is pruned +print("\nHonest tree with pruning: ", honest_clf.predict_proba(X_val)) + +# this is a regular decision tree classifier from sklearn +print("\nDTC: ", sklearn_clf.predict_proba(X_val)) diff --git a/examples/calibration/plot_overlapping_gaussians.py b/examples/calibration/plot_overlapping_gaussians.py index 2da37fdaf..034ccac73 100644 --- a/examples/calibration/plot_overlapping_gaussians.py +++ b/examples/calibration/plot_overlapping_gaussians.py @@ -1,4 +1,6 @@ """ +.. _plot_overlapping_gaussians: + =================================================================== Plot honest forest calibrations on overlapping gaussian simulations =================================================================== diff --git a/treeple/__init__.py b/treeple/__init__.py index dafad7deb..989e6d58d 100644 --- a/treeple/__init__.py +++ b/treeple/__init__.py @@ -4,7 +4,7 @@ import os import sys -__version__ = "0.9.0dev0" +__version__ = "0.10.0dev0" logger = logging.getLogger(__name__) diff --git a/treeple/ensemble/_honest_forest.py b/treeple/ensemble/_honest_forest.py index 0650d9a31..5a996c3b8 100644 --- a/treeple/ensemble/_honest_forest.py +++ b/treeple/ensemble/_honest_forest.py @@ -270,6 +270,11 @@ class HonestForestClassifier(ForestClassifier, ForestClassifierMixin): Fraction of training samples used for estimates in the trees. The remaining samples will be used to learn the tree structure. A larger fraction creates shallower trees with lower variance estimates. + + honest_method : {"prune", "apply"}, default="prune" + Method for enforcing honesty. If "prune", the tree is pruned to enforce + honesty. If "apply", the tree is not pruned, but the leaf estimates are + adjusted to enforce honesty. tree_estimator : object, default=None Instantiated tree of type BaseDecisionTree from treeple. @@ -410,6 +415,13 @@ class labels (multi-output problem). 
_parameter_constraints: dict = { **ForestClassifier._parameter_constraints, + **HonestTreeClassifier._parameter_constraints, + "class_weight": [ + StrOptions({"balanced_subsample", "balanced"}), + dict, + list, + None, + ], } _parameter_constraints.pop("max_samples") _parameter_constraints["max_samples"] = [ @@ -453,6 +465,7 @@ def __init__( max_samples=None, honest_prior="ignore", honest_fraction=0.5, + honest_method="apply", tree_estimator=None, stratify=False, **tree_estimator_params, @@ -475,6 +488,7 @@ def __init__( "tree_estimator", "honest_fraction", "honest_prior", + "honest_method", "stratify", ), bootstrap=bootstrap, @@ -498,6 +512,7 @@ def __init__( self.ccp_alpha = ccp_alpha self.honest_fraction = honest_fraction self.honest_prior = honest_prior + self.honest_method = honest_method self.tree_estimator = tree_estimator self.stratify = stratify self._tree_estimator_params = tree_estimator_params @@ -730,7 +745,7 @@ def oob_samples_(self): def __sklearn_tags__(self): # XXX: nans should be supportable in HRF tags = super().__sklearn_tags__() - tags.classifier_tags.multi_output = False + # tags.classifier_tags.multi_output = False tags.input_tags.allow_nan = False return tags diff --git a/treeple/experimental/tests/test_sdf.py b/treeple/experimental/tests/test_sdf.py index 64c7e45a8..ee7b9eacc 100644 --- a/treeple/experimental/tests/test_sdf.py +++ b/treeple/experimental/tests/test_sdf.py @@ -115,6 +115,9 @@ def test_sklearn_compatible_estimator(estimator, check): # XXX: can include this "generalization" in the future if it's useful if check.func.__name__ in [ "check_class_weight_classifiers", + "check_sample_weight_equivalence", + "check_sample_weight_equivalence_on_dense_data", + "check_sample_weight_equivalence_on_sparse_data", ]: pytest.skip() check(estimator) diff --git a/treeple/meson.build b/treeple/meson.build index 4801d0536..5afd0c433 100644 --- a/treeple/meson.build +++ b/treeple/meson.build @@ -96,7 +96,7 @@ cc = meson.get_compiler('c') # 'source_fname', # numpy_nodepr_api) -# XXX: ENABLE WHEN DEBUGGING +# TODO XXX: ENABLE WHEN DEBUGGING boundscheck = 'False' scikit_learn_cython_args = [ diff --git a/treeple/neighbors.py b/treeple/neighbors.py index b16e732f9..c16aa3f9e 100644 --- a/treeple/neighbors.py +++ b/treeple/neighbors.py @@ -11,7 +11,7 @@ from treeple.tree._neighbors import _compute_distance_matrix, compute_forest_similarity_matrix -class NearestNeighborsMetaEstimator(BaseEstimator, MetaEstimatorMixin): +class NearestNeighborsMetaEstimator(MetaEstimatorMixin, BaseEstimator): """Meta-estimator for nearest neighbors. Uses a decision-tree, or forest model to compute distances between samples diff --git a/treeple/stats/forest.py b/treeple/stats/forest.py index 8588282ba..6e9ed5168 100644 --- a/treeple/stats/forest.py +++ b/treeple/stats/forest.py @@ -289,7 +289,7 @@ def build_oob_forest( # the Histogram Gradient Boosting Tree does, where the binning thresholds # are passed into the tree itself, thus allowing us to set the node feature # value thresholds within the tree itself. 
- if est.max_bins is not None: + if hasattr(est, "max_bins") and est.max_bins is not None: X = est._bin_data(X, is_training_data=False).astype(DTYPE) # Assign chunk of trees to jobs diff --git a/treeple/stats/permuteforest.py b/treeple/stats/permuteforest.py index e5c2c3f6f..69129677b 100644 --- a/treeple/stats/permuteforest.py +++ b/treeple/stats/permuteforest.py @@ -179,6 +179,11 @@ class PermutationHonestForestClassifier(HonestForestClassifier): remaining samples will be used to learn the tree structure. A larger fraction creates shallower trees with lower variance estimates. + honest_method : {"prune", "apply"}, default="prune" + Method for enforcing honesty. If "prune", the tree is pruned to enforce + honesty. If "apply", the tree is not pruned, but the leaf estimates are + adjusted to enforce honesty. + tree_estimator : object, default=None Type of decision tree classifier to use. By default `None`, which defaults to `treeple.tree.DecisionTreeClassifier`. Note @@ -298,35 +303,37 @@ def __init__( max_samples=None, honest_prior="empirical", honest_fraction=0.5, + honest_method="apply", tree_estimator=None, stratify=False, permute_per_tree=False, **tree_estimator_params, ): super().__init__( - n_estimators, - criterion, - splitter, - max_depth, - min_samples_split, - min_samples_leaf, - min_weight_fraction_leaf, - max_features, - max_leaf_nodes, - min_impurity_decrease, - bootstrap, - oob_score, - n_jobs, - random_state, - verbose, - warm_start, - class_weight, - ccp_alpha, - max_samples, - honest_prior, - honest_fraction, - tree_estimator, - stratify, + n_estimators=n_estimators, + criterion=criterion, + splitter=splitter, + max_depth=max_depth, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, + min_weight_fraction_leaf=min_weight_fraction_leaf, + max_features=max_features, + max_leaf_nodes=max_leaf_nodes, + min_impurity_decrease=min_impurity_decrease, + bootstrap=bootstrap, + oob_score=oob_score, + n_jobs=n_jobs, + random_state=random_state, + verbose=verbose, + warm_start=warm_start, + class_weight=class_weight, + ccp_alpha=ccp_alpha, + max_samples=max_samples, + honest_prior=honest_prior, + honest_fraction=honest_fraction, + honest_method=honest_method, + tree_estimator=tree_estimator, + stratify=stratify, **tree_estimator_params, ) self.permute_per_tree = permute_per_tree diff --git a/treeple/tests/test_honest_forest.py b/treeple/tests/test_honest_forest.py index 7a6ee4568..abb5a25cf 100644 --- a/treeple/tests/test_honest_forest.py +++ b/treeple/tests/test_honest_forest.py @@ -310,6 +310,9 @@ def test_sklearn_compatible_estimator(estimator, check): # for fitting the tree's splits if check.func.__name__ in [ "check_class_weight_classifiers", + "check_sample_weight_equivalence", + "check_sample_weight_equivalence_on_dense_data", + "check_sample_weight_equivalence_on_sparse_data", # TODO: this is an error. 
Somehow a segfault is raised when fit is called first and # then partial_fit "check_fit_score_takes_y", diff --git a/treeple/tests/test_multiview_forest.py b/treeple/tests/test_multiview_forest.py index dff44a863..52d7d86ea 100644 --- a/treeple/tests/test_multiview_forest.py +++ b/treeple/tests/test_multiview_forest.py @@ -18,6 +18,13 @@ ] ) def test_sklearn_compatible_estimator(estimator, check): + if check.func.__name__ in [ + # sample weights do not necessarily imply a sample is not used in clustering + "check_sample_weight_equivalence", + "check_sample_weight_equivalence_on_dense_data", + "check_sample_weight_equivalence_on_sparse_data", + ]: + pytest.skip() check(estimator) diff --git a/treeple/tests/test_supervised_forest.py b/treeple/tests/test_supervised_forest.py index ba5b29577..017c0c028 100644 --- a/treeple/tests/test_supervised_forest.py +++ b/treeple/tests/test_supervised_forest.py @@ -200,7 +200,17 @@ def test_sklearn_compatible_estimator(estimator, check): ObliqueRandomForestClassifier, PatchObliqueRandomForestClassifier, ), - ) and check.func.__name__ in ["check_fit_score_takes_y"]: + ) and check.func.__name__ in [ + "check_fit_score_takes_y", + ]: + pytest.skip() + + if check.func.__name__ in [ + # sample weights do not necessarily imply a sample is not used in clustering + "check_sample_weight_equivalence", + "check_sample_weight_equivalence_on_dense_data", + "check_sample_weight_equivalence_on_sparse_data", + ]: pytest.skip() check(estimator) diff --git a/treeple/tests/test_unsupervised_forest.py b/treeple/tests/test_unsupervised_forest.py index 96c6f3b17..0c186e138 100644 --- a/treeple/tests/test_unsupervised_forest.py +++ b/treeple/tests/test_unsupervised_forest.py @@ -33,9 +33,11 @@ def test_sklearn_compatible_estimator(estimator, check): if check.func.__name__ in [ # Cannot apply agglomerative clustering on < 2 samples "check_methods_subset_invariance", - # # sample weights do not necessarily imply a sample is not used in clustering - "check_sample_weights_invariance", - # # sample order is not preserved in predict + # sample weights do not necessarily imply a sample is not used in clustering + "check_sample_weight_equivalence", + "check_sample_weight_equivalence_on_dense_data", + "check_sample_weight_equivalence_on_sparse_data", + # sample order is not preserved in predict "check_methods_sample_order_invariance", ]: pytest.skip() diff --git a/treeple/tree/_honest_tree.py b/treeple/tree/_honest_tree.py index 26f5b7ca5..1bb0c7cd2 100644 --- a/treeple/tree/_honest_tree.py +++ b/treeple/tree/_honest_tree.py @@ -1,15 +1,32 @@ -# Adopted from: https://github.com/neurodata/honest-forests - +from copy import copy +from numbers import Integral import numpy as np -from sklearn.base import ClassifierMixin, MetaEstimatorMixin, _fit_context, clone +from sklearn.base import ClassifierMixin, MetaEstimatorMixin, _fit_context, clone, is_classifier from sklearn.model_selection import StratifiedShuffleSplit from sklearn.utils._param_validation import HasMethods, Interval, RealNotInt, StrOptions from sklearn.utils.multiclass import _check_partial_fit_first_call, check_classification_targets -from sklearn.utils.validation import check_is_fitted, check_X_y +from sklearn.utils.validation import check_is_fitted, check_random_state, check_X_y -from .._lib.sklearn.tree import DecisionTreeClassifier +from .._lib.sklearn.tree import DecisionTreeClassifier, _criterion, _tree from .._lib.sklearn.tree._classes import BaseDecisionTree +from .._lib.sklearn.tree._criterion import BaseCriterion 
+from .._lib.sklearn.tree._tree import Tree +from .honesty._honest_prune import HonestPruner, _build_pruned_tree_honesty + +CRITERIA_CLF = { + "gini": _criterion.Gini, + "log_loss": _criterion.Entropy, + "entropy": _criterion.Entropy, +} +CRITERIA_REG = { + "squared_error": _criterion.MSE, + "friedman_mse": _criterion.FriedmanMSE, + "absolute_error": _criterion.MAE, + "poisson": _criterion.Poisson, +} + +DOUBLE = _tree.DOUBLE class HonestTreeClassifier(MetaEstimatorMixin, ClassifierMixin, BaseDecisionTree): @@ -173,6 +190,13 @@ class frequency in the voting subsample. Whether or not to stratify sample when considering structure and leaf indices. By default False. + honest_method : {"apply", "prune"}, default="apply" + Method to use for fitting the leaf nodes. If "apply", the leaf nodes + are fit using the structure as is. In this case, empty leaves may occur + if not enough data. If "prune", the leaf nodes are fit + by pruning using the honest-set of data after the tree structure is built + using the structure-set of data. + **tree_estimator_params : dict Parameters to pass to the underlying base tree estimators. These must be parameters for ``tree_estimator``. @@ -283,9 +307,18 @@ class frequency in the voting subsample. ], "honest_fraction": [Interval(RealNotInt, 0.0, 1.0, closed="neither")], "honest_prior": [StrOptions({"empirical", "uniform", "ignore"})], + "honest_method": [StrOptions({"apply", "prune"}), None], "stratify": ["boolean"], "tree_estimator_params": ["dict"], } + _parameter_constraints.pop("max_features") + _parameter_constraints["max_features"] = [ + Interval(Integral, 1, None, closed="left"), + Interval(RealNotInt, 0.0, 1.0, closed="right"), + StrOptions({"sqrt", "log2"}), + "array-like", + None, + ] def __init__( self, @@ -306,6 +339,7 @@ def __init__( honest_fraction=0.5, honest_prior="empirical", stratify=False, + honest_method="apply", **tree_estimator_params, ): self.tree_estimator = tree_estimator @@ -326,6 +360,7 @@ def __init__( self.honest_fraction = honest_fraction self.honest_prior = honest_prior self.stratify = stratify + self.honest_method = honest_method # XXX: to enable this, we need to also reset the leaf node samples during `_set_leaf_nodes` self.store_leaf_values = False @@ -664,16 +699,59 @@ def _fit_leaves(self, X, y, sample_weight): y = y_encoded self.n_classes_ = np.array(self.n_classes_, dtype=np.intp) - # XXX: implement honest pruning - honest_method = "apply" - if honest_method == "apply": + if self.honest_method == "apply": # Fit leaves using other subsample honest_leaves = self.tree_.apply(X[self.honest_indices_]) # y-encoded ensures that y values match the indices of the classes self._set_leaf_nodes(honest_leaves, y, sample_weight) - elif honest_method == "prune": - raise NotImplementedError("Pruning is not yet implemented.") + elif self.honest_method == "prune": + if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: + y = np.ascontiguousarray(y, dtype=DOUBLE) + + n_samples = X.shape[0] + + # Build tree + criterion = self.criterion + if not isinstance(criterion, BaseCriterion): + if is_classifier(self): + criterion = CRITERIA_CLF[self.criterion](self.n_outputs_, self.n_classes_) + else: + criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples) + else: + # Make a deepcopy in case the criterion has mutable attributes that + # might be shared and modified concurrently during parallel fitting + criterion = copy.deepcopy(criterion) + + random_state = check_random_state(self.random_state) + pruner = HonestPruner( + 
criterion, + self.max_features_, + self.min_samples_leaf_, + self.min_weight_leaf_, + random_state, + self.monotonic_cst_, + self.tree_, + ) + + # build pruned tree + if is_classifier(self): + n_classes = np.atleast_1d(self.n_classes_) + pruned_tree = Tree(self.n_features_in_, n_classes, self.n_outputs_) + else: + pruned_tree = Tree( + self.n_features_in_, + # TODO: the tree shouldn't need this param + np.array([1] * self.n_outputs_, dtype=np.intp), + self.n_outputs_, + ) + + # get the leaves + missing_values_in_feature_mask = self._compute_missing_values_in_feature_mask(X) + _build_pruned_tree_honesty( + pruned_tree, self.tree_, pruner, X, y, sample_weight, missing_values_in_feature_mask + ) + self.tree_ = pruned_tree if self.n_outputs_ == 1: self.n_classes_ = self.n_classes_[0] @@ -693,12 +771,27 @@ def _set_leaf_nodes(self, leaf_ids, y, sample_weight): """ self.tree_.value[:, :, :] = 0 + # XXX: Note this method does not make these into a proportion of the leaf + # total_n_node_samples = 0.0 + # apply sample-weight to the leaf nodes + # seen_leaf_ids = set() for leaf_id, yval, y_weight in zip( leaf_ids, y[self.honest_indices_, :], sample_weight[self.honest_indices_] ): + # XXX: this treats the leaf node values as a sum of the leaf self.tree_.value[leaf_id][:, yval] += y_weight + # XXX: this normalizes the leaf node values to be a proportion of the leaf + # total_n_node_samples += y_weight + # if leaf_id in seen_leaf_ids: + # self.tree_.value[leaf_id][:, yval] += y_weight + # else: + # self.tree_.value[leaf_id][:, yval] = y_weight + # seen_leaf_ids.add(leaf_id) + # for leaf_id in seen_leaf_ids: + # self.tree_.value[leaf_id] /= total_n_node_samples + def _inherit_estimator_attributes(self): """Initialize necessary attributes from the provided tree estimator""" if hasattr(self.estimator_, "_inheritable_fitted_attribute"): @@ -821,3 +914,16 @@ def predict(self, X, check_input=True): check_is_fitted(self) X = self._validate_X_predict(X, check_input) return self.estimator_.predict(X, False) + + @property + def feature_importances_(self): + """Feature importances. + + This is the impurity-based feature importances. The higher, the more important + that the feature was used in constructing the structure. + + Note: this does not give the feature importances relative for setting the + leaf node posterior estimates. + """ + # TODO: technically, the feature importances is built rn using the structure set + return super().feature_importances_ diff --git a/treeple/tree/honesty/__init__.py b/treeple/tree/honesty/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/treeple/tree/honesty/_honest_prune.pxd b/treeple/tree/honesty/_honest_prune.pxd new file mode 100644 index 000000000..c609ef048 --- /dev/null +++ b/treeple/tree/honesty/_honest_prune.pxd @@ -0,0 +1,62 @@ +from ..._lib.sklearn.tree._criterion cimport Criterion +from ..._lib.sklearn.tree._partitioner cimport shift_missing_values_to_left_if_required +from ..._lib.sklearn.tree._splitter cimport SplitRecord, Splitter +from ..._lib.sklearn.tree._tree cimport Node, ParentInfo, Tree +from ..._lib.sklearn.utils._typedefs cimport float32_t, float64_t, int8_t, intp_t, uint8_t, uint32_t + + +# for each node, keep track of the node index and the parent index +# within the tree's node array +cdef struct PruningRecord: + intp_t node_idx + intp_t start + intp_t end + float64_t lower_bound + float64_t upper_bound + + +# TODO: this may break the notion of feature importances, as we don't set the node's impurity +# at the child nodes. 
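+
+# ``HonestPruner`` below reuses the ``Splitter`` machinery, but rather than
+# searching for new splits it replays the (feature, threshold) splits stored
+# in ``orig_tree`` on the honest subsample, and only checks whether each
+# stored split is still valid (non-degenerate and passing the usual pre- and
+# post-split conditions). ``_honest_prune`` then marks nodes whose splits fail
+# these checks as leaves of the pruned tree.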
+cdef class HonestPruner(Splitter): + cdef Tree tree # The tree to be pruned + cdef intp_t capacity # The maximum number of nodes in the pruned tree + cdef intp_t pos # The current position to split left/right children + cdef intp_t n_missing # The number of missing values in the feature currently considered + cdef uint8_t missing_go_to_left + + # TODO: only supports sparse for now. + cdef const float32_t[:, :] X + + cdef int init( + self, + object X, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, + const uint8_t[::1] missing_values_in_feature_mask, + ) except -1 + + # This function is not used, and should be disabled for pruners + cdef int node_split( + self, + ParentInfo* parent_record, + SplitRecord* split, + ) except -1 nogil + + cdef bint check_node_partition_conditions( + self, + SplitRecord* current_split, + float64_t lower_bound, + float64_t upper_bound + ) noexcept nogil + + cdef inline intp_t n_left_samples( + self + ) noexcept nogil + cdef inline intp_t n_right_samples( + self + ) noexcept nogil + + cdef int partition_samples( + self, + intp_t node_idx, + ) noexcept nogil diff --git a/treeple/tree/honesty/_honest_prune.pyx b/treeple/tree/honesty/_honest_prune.pyx new file mode 100644 index 000000000..c11812837 --- /dev/null +++ b/treeple/tree/honesty/_honest_prune.pyx @@ -0,0 +1,429 @@ +# cython: boundscheck=False +# cython: wraparound=False +# cython: initializedcheck=False + +import numpy as np + +cimport numpy as cnp + +cnp.import_array() + +from libc.math cimport isnan +from libc.stdlib cimport free, malloc +from libcpp.stack cimport stack + +from ..._lib.sklearn.tree._tree cimport ParentInfo, _build_pruned_tree + +TREE_LEAF = -1 +cdef intp_t _TREE_LEAF = TREE_LEAF +cdef float64_t INFINITY = np.inf + +cdef inline void _init_parent_record(ParentInfo* record) noexcept nogil: + record.n_constant_features = 0 + record.lower_bound = -INFINITY + record.upper_bound = INFINITY + + +def _build_pruned_tree_honesty( + Tree tree, + Tree orig_tree, + HonestPruner pruner, + object X, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, + const uint8_t[::1] missing_values_in_feature_mask=None, +): + """Prune an existing tree with honest splits. + + Parameters + ---------- + tree : Tree + The tree to be pruned. + orig_tree : Tree + The original tree to be pruned. + pruner : HonestPruner + The pruner to enforce honest splits. + X : array-like of shape (n_samples, n_features) + The input samples. + y : array-like of shape (n_samples,) + The target values. + sample_weight : array-like of shape (n_samples,) + The sample weights. + missing_values_in_feature_mask : array-like of shape (n_features,) + The mask of missing values in the features. + """ + cdef: + intp_t n_nodes = orig_tree.node_count + uint8_t[:] leaves_in_subtree = np.zeros( + shape=n_nodes, dtype=np.uint8) + + # initialize the pruner/splitter + pruner.init(X, y, sample_weight, missing_values_in_feature_mask) + + # apply pruning to the tree + _honest_prune(leaves_in_subtree, orig_tree, pruner) + + _build_pruned_tree(tree, orig_tree, leaves_in_subtree, + pruner.capacity) + + +cdef class HonestPruner(Splitter): + """Pruning to enforce honest splits are non-degenerate.""" + + def __cinit__( + self, + Criterion criterion, + intp_t max_features, + intp_t min_samples_leaf, + float64_t min_weight_leaf, + object random_state, + const int8_t[:] monotonic_cst, + Tree orig_tree, + *argv + ): + """ + Parameters + ---------- + criterion : Criterion + The criterion to measure the quality of a split. 
+ + max_features : intp_t + The maximal number of randomly selected features which can be + considered for a split. + + min_samples_leaf : intp_t + The minimal number of samples each leaf can have, where splits + which would result in having less samples in a leaf are not + considered. + + min_weight_leaf : float64_t + The minimal weight each leaf can have, where the weight is the sum + of the weights of each sample in it. + + random_state : object + The user inputted random state to be used for pseudo-randomness + + monotonic_cst : const int8_t[:] + Monotonicity constraints + + orig_tree : Tree + The original tree to be pruned. + """ + self.tree = orig_tree + self.capacity = 0 + + cdef int init( + self, + object X, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, + const uint8_t[::1] missing_values_in_feature_mask, + ) except -1: + Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) + self.X = X + + cdef int partition_samples( + self, + intp_t node_idx, + ) noexcept nogil: + """Partition samples for X at the threshold and feature index of `orig_tree`. + + If missing values are present, this method partitions `samples` + so that the `best_n_missing` missing values' indices are in the + right-most end of `samples`, that is `samples[end_non_missing:end]`. + """ + cdef float64_t threshold = self.tree.nodes[node_idx].threshold + cdef intp_t feature = self.tree.nodes[node_idx].feature + cdef intp_t n_missing = 0 + cdef intp_t pos = self.start + cdef intp_t p + cdef intp_t sample_idx + cdef intp_t current_end = self.end - n_missing + cdef const float32_t[:, :] X_ndarray = self.X + + # partition the samples one by one by swapping them to the left or right + # of pos depending on the feature value compared to the orig_tree threshold + for p in range(self.start, self.end): + sample_idx = self.samples[p] + + # missing-values are always placed at the right-most end + if isnan(X_ndarray[sample_idx, feature]): + self.samples[p], self.samples[current_end] = \ + self.samples[current_end], self.samples[p] + + n_missing += 1 + current_end -= 1 + elif p > pos and X_ndarray[sample_idx, feature] <= threshold: + self.samples[p], self.samples[pos] = \ + self.samples[pos], self.samples[p] + pos += 1 + + # this is the split point for left/right children + self.pos = pos + self.n_missing = n_missing + self.missing_go_to_left = self.tree.nodes[node_idx].missing_go_to_left + + cdef bint check_node_partition_conditions( + self, + SplitRecord* current_split, + float64_t lower_bound, + float64_t upper_bound + ) noexcept nogil: + """Check that the current node satisfies paritioning conditions. + + Parameters + ---------- + current_split : SplitRecord pointer + A pointer to a memory-allocated SplitRecord object which will be filled with the + split chosen. + """ + # update the criterion if we are checking split conditions + self.criterion.init_missing(self.n_missing) + self.criterion.reset() + self.criterion.update(self.pos) + + current_split.pos = self.pos + current_split.n_missing = self.n_missing + current_split.missing_go_to_left = self.missing_go_to_left + + # first check the presplit conditions + cdef bint invalid_split = self.check_presplit_conditions( + current_split, + self.n_missing, + self.missing_go_to_left + ) + + if invalid_split: + return 0 + + # TODO: make work with lower/upper bound. 
This will require passing + # lower/upper bound from the parent node into check_node_partition_conditions + # Reject if monotonicity constraints are not satisfied + if ( + self.with_monotonic_cst and + self.monotonic_cst[current_split.feature] != 0 and + not self.criterion.check_monotonicity( + self.monotonic_cst[current_split.feature], + lower_bound, + upper_bound, + ) + ): + return 0 + + # Note this is called after pre-split condition checks + # shift missing values to left if required, so we can check + # the split conditions + shift_missing_values_to_left_if_required( + current_split, + self.samples, + self.end + ) + + # next check the postsplit conditions that leverages the criterion + invalid_split = self.check_postsplit_conditions() + return invalid_split + + cdef inline intp_t n_left_samples( + self + ) noexcept nogil: + """Number of samples to send to the left child.""" + cdef intp_t n_left + + if self.missing_go_to_left: + n_left = self.pos - self.start + self.n_missing + else: + n_left = self.pos - self.start + return n_left + + cdef inline intp_t n_right_samples( + self + ) noexcept nogil: + """Number of samples to send to the right child.""" + cdef intp_t n_right + cdef intp_t end_non_missing = self.end - self.n_missing + if self.missing_go_to_left: + n_right = end_non_missing - self.pos + else: + n_right = end_non_missing - self.pos + self.n_missing + return n_right + + cdef int node_split( + self, + ParentInfo* parent_record, + SplitRecord* split, + ) except -1 nogil: + """Split nodes using the already constructed tree. + + This is a simpler version of splitting nodes during the construction of a tree. + Here, we only need to split the samples in the node based on the feature and + threshold of the node in the original tree. In addition, we track the relevant + information from the parent node, such as lower/upper bounds, and the parent's + impurity, and n_constant_features. + + Returns 0 if a split cannot be done, 1 if a split can be done + and -1 in case of failure to allocate memory (and raise MemoryError). + """ + raise NotImplementedError("node_split is not used in honest pruning") + + +cdef _honest_prune( + uint8_t[:] leaves_in_subtree, + Tree orig_tree, + HonestPruner pruner, +): + """Perform honest pruning of the tree. + + Iterates through the original tree in a BFS fashion using the pruner + and tracks at each node (orig_node): + + - the number of samples in the node + - the number of samples that would be sent to the left and right children + + Until one of three stopping conditions are met: + + 1. The orig_node is a leaf node in the original tree. + Thus we keep the node as a leaf in the pruned tree. + 2. The orig_node is a non-leaf node and the split is degenerate. + Thus we would prune the subtree, and assign orig_node as a leaf + node in the pruned tree. + 3. Stopping criterion is met based on the samples that reach the node. + These are the stopping conditions implemented in a Splitter/Pruner. + Thus we would prune the subtree, and assign orig_node as a leaf + node in the pruned tree. + + Parameters + ---------- + leaves_in_subtree : array of shape (n_nodes,), dtype=np.uint8 + Array of booleans indicating whether the node is in the subtree. + orig_tree : Tree + The original tree. + pruner : HonestPruner + The input samples to be used for computing the split of samples + in the nodes. 
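+
+    Notes
+    -----
+    A Python-level sketch of the traversal below, where
+    ``partition_by_orig_split``, ``split_is_valid`` and ``honest_sample_ids``
+    are hypothetical names used only for illustration::
+
+        pruning_stack = [(0, honest_sample_ids)]  # root node of orig_tree
+        while pruning_stack:
+            node_id, sample_ids = pruning_stack.pop()
+            # re-partition this node's honest samples using the original split
+            left_ids, right_ids = partition_by_orig_split(
+                orig_tree, node_id, sample_ids
+            )
+            if (
+                orig_tree.children_left[node_id] == TREE_LEAF     # condition 1
+                or len(left_ids) == 0 or len(right_ids) == 0      # condition 2
+                or not split_is_valid(orig_tree, node_id, sample_ids)  # condition 3
+            ):
+                leaves_in_subtree[node_id] = 1
+            else:
+                pruning_stack.append((orig_tree.children_left[node_id], left_ids))
+                pruning_stack.append((orig_tree.children_right[node_id], right_ids))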
+ """ + cdef: + # get the left child, right child and parents of every node + intp_t[:] child_l = orig_tree.children_left + intp_t[:] child_r = orig_tree.children_right + # intp_t[:] parents = np.zeros(shape=n_nodes, dtype=np.intp) + + # stack to keep track of the nodes to be pruned such that BFS is done + stack[PruningRecord] pruning_stack + PruningRecord stack_record + + intp_t node_idx + SplitRecord* split_ptr = malloc(pruner.pointer_size()) + + bint is_leaf_in_origtree + bint invalid_split + bint split_is_degenerate + + intp_t start = 0 + intp_t end = 0 + float64_t weighted_n_node_samples + + float64_t lower_bound, upper_bound + float64_t left_child_min, left_child_max, right_child_min, right_child_max, middle_value + + # find parent node ids and leaves + with nogil: + # Push the root node + pruning_stack.push({ + "node_idx": 0, + "start": 0, + "end": pruner.n_samples, + "lower_bound": -INFINITY, + "upper_bound": INFINITY, + }) + + # Note: this DFS building strategy differs from scikit-learn in that + # we check stopping conditions (and leaf candidacy) after a split occurs. + # If we don't hit a leaf, then we will add the children to the stack, but otherwise + # we will halt the split, and mark the node to be a new leaf node in the pruned tree. + while not pruning_stack.empty(): + stack_record = pruning_stack.top() + pruning_stack.pop() + start = stack_record.start + end = stack_record.end + lower_bound = stack_record.lower_bound + upper_bound = stack_record.upper_bound + + # node index of actual node within the orig_tree + node_idx = stack_record.node_idx + + # reset which samples indices are considered at this split node + pruner.node_reset(start, end, &weighted_n_node_samples) + + # partition samples into left/right child based on the + # current node split in the orig_tree + pruner.partition_samples(node_idx) + + # check end conditions + split_ptr.feature = orig_tree.nodes[node_idx].feature + invalid_split = pruner.check_node_partition_conditions( + split_ptr, + lower_bound, + upper_bound + ) + split_is_degenerate = ( + pruner.n_left_samples() == 0 or pruner.n_right_samples() == 0 + ) + is_leaf_in_origtree = child_l[node_idx] == _TREE_LEAF + if invalid_split or split_is_degenerate or is_leaf_in_origtree: + # ... and child_r[node_idx] == _TREE_LEAF: + # + # 1) if node is not degenerate, that means there are still honest-samples in + # both left/right children of the proposed split, or the node itself is a leaf + # or 2) there are still nodes to split on, but the honest-samples have been + # used up so the "parent" should be the new leaf + leaves_in_subtree[node_idx] = 1 + else: + if ( + not pruner.with_monotonic_cst or + pruner.monotonic_cst[split_ptr.feature] == 0 + ): + # Split on a feature with no monotonicity constraint + + # Current bounds must always be propagated to both children. + # If a monotonic constraint is active, bounds are used in + # node value clipping. + left_child_min = right_child_min = lower_bound + left_child_max = right_child_max = upper_bound + elif pruner.monotonic_cst[split_ptr.feature] == 1: + # Split on a feature with monotonic increase constraint + left_child_min = lower_bound + right_child_max = upper_bound + + # Lower bound for right child and upper bound for left child + # are set to the same value. + middle_value = pruner.criterion.middle_value() + right_child_min = middle_value + left_child_max = middle_value + else: # i.e. 
pruner.monotonic_cst[split.feature] == -1 + # Split on a feature with monotonic decrease constraint + right_child_min = lower_bound + left_child_max = upper_bound + + # Lower bound for left child and upper bound for right child + # are set to the same value. + middle_value = pruner.criterion.middle_value() + left_child_min = middle_value + right_child_max = middle_value + + pruning_stack.push({ + "node_idx": child_l[node_idx], + "start": pruner.start, + "end": pruner.pos, + "lower_bound": left_child_min, + "upper_bound": left_child_max, + }) + pruning_stack.push({ + "node_idx": child_r[node_idx], + "start": pruner.pos, + "end": pruner.end, + "lower_bound": right_child_min, + "upper_bound": right_child_max, + }) + + # free the memory created for the SplitRecord pointer + free(split_ptr) diff --git a/treeple/tree/honesty/meson.build b/treeple/tree/honesty/meson.build new file mode 100644 index 000000000..acb21b192 --- /dev/null +++ b/treeple/tree/honesty/meson.build @@ -0,0 +1,23 @@ +tree_extension_metadata = { + '_honest_prune': + {'sources': ['_honest_prune.pyx'], + 'override_options': ['cython_language=cpp', 'optimization=3']}, +} + +foreach ext_name, ext_dict : tree_extension_metadata + py.extension_module( + ext_name, + ext_dict.get('sources'), + dependencies: [np_dep], + override_options : ext_dict.get('override_options', []), + c_args: c_args, + cython_args: cython_c_args, + subdir: 'treeple/tree/honesty', + install: true, + ) +endforeach + + +py.install_sources( + subdir: 'treeple/tree/honesty' # Folder relative to site-packages to install to +) diff --git a/treeple/tree/meson.build b/treeple/tree/meson.build index 2febbbc6a..e8c8ecce8 100644 --- a/treeple/tree/meson.build +++ b/treeple/tree/meson.build @@ -43,7 +43,7 @@ py.install_sources( subdir: 'treeple/tree' # Folder relative to site-packages to install to ) -# TODO: comment in if we include tests subdir('tests') subdir('unsupervised') subdir('manifold') +subdir('honesty') \ No newline at end of file diff --git a/treeple/tree/tests/meson.build b/treeple/tree/tests/meson.build index 552eb7356..99cbdfa9f 100644 --- a/treeple/tree/tests/meson.build +++ b/treeple/tree/tests/meson.build @@ -3,6 +3,7 @@ python_sources = [ 'test_tree.py', 'test_utils.py', 'test_honest_tree.py', + 'test_honest_prune.py', 'test_marginal.py', 'test_all_trees.py', 'test_unsupervised_tree.py', diff --git a/treeple/tree/tests/test_honest_prune.py b/treeple/tree/tests/test_honest_prune.py new file mode 100644 index 000000000..629952821 --- /dev/null +++ b/treeple/tree/tests/test_honest_prune.py @@ -0,0 +1,72 @@ +import numpy as np + +from treeple.tree import HonestTreeClassifier + + +def test_honest_tree_pruning(): + """Test honest tree with pruning to ensure no empty leaves.""" + rng = np.random.default_rng(1234) + + n_samples = 1000 + X = rng.standard_normal(size=(n_samples, 100)) + X[n_samples // 2 :] *= -1 + y = [0] * (n_samples // 2) + [1] * (n_samples // 2) + + clf = HonestTreeClassifier(honest_method="prune", max_features="sqrt", random_state=0) + clf = clf.fit(X, y) + + nonprune_clf = HonestTreeClassifier( + honest_method="apply", max_features="sqrt", random_state=0, honest_prior="ignore" + ) + nonprune_clf = nonprune_clf.fit(X, y) + + assert ( + nonprune_clf.tree_.max_depth >= clf.tree_.max_depth + ), f"{nonprune_clf.tree_.max_depth} <= {clf.tree_.max_depth}" + # assert np.all(clf.tree_.children_left != -1) + + # Access the original and pruned trees' attributes + original_tree = nonprune_clf.tree_ + pruned_tree = clf.tree_ + + # Ensure the pruned 
tree has fewer or equal nodes + assert ( + pruned_tree.node_count < original_tree.node_count + ), "Pruned tree has more nodes than the original tree" + + # Ensure the pruned tree has no empty leaves + assert np.all(pruned_tree.value.sum(axis=(1, 2)) > 0), pruned_tree.value.sum(axis=(1, 2)) + # assert np.all(original_tree.value.sum(axis=(1,2)) > 0), original_tree.value.sum(axis=(1,2)) + assert np.all(pruned_tree.value.sum(axis=(1, 2)) > 0) > np.all( + original_tree.value.sum(axis=(1, 2)) > 0 + ) + + # test that the first three nodes are the same, since these are unlikely to be + # pruned, and should remain invariant. + # + # Note: pruning the tree will have the node_ids change since the tree is + # ordered via DFS. + for pruned_node_id in range(3): + pruned_left_child = pruned_tree.children_left[pruned_node_id] + pruned_right_child = pruned_tree.children_right[pruned_node_id] + + # Check if the pruned node exists in the original tree + assert ( + pruned_left_child in original_tree.children_left + ), "Left child node of pruned tree not found in original tree" + assert ( + pruned_right_child in original_tree.children_right + ), "Right child node of pruned tree not found in original tree" + + # Check if the node's parameters match for non-leaf nodes + if pruned_left_child != -1: + assert ( + pruned_tree.feature[pruned_node_id] == original_tree.feature[pruned_node_id] + ), "Feature does not match for node {}".format(pruned_node_id) + assert ( + pruned_tree.threshold[pruned_node_id] == original_tree.threshold[pruned_node_id] + ), "Threshold does not match for node {}".format(pruned_node_id) + assert ( + pruned_tree.weighted_n_node_samples[pruned_node_id] + == original_tree.weighted_n_node_samples[pruned_node_id] + ), "Weighted n_node samples does not match for node {}".format(pruned_node_id) diff --git a/treeple/tree/tests/test_honest_tree.py b/treeple/tree/tests/test_honest_tree.py index bdc714e55..f7542abc9 100644 --- a/treeple/tree/tests/test_honest_tree.py +++ b/treeple/tree/tests/test_honest_tree.py @@ -36,7 +36,10 @@ def test_iris(criterion, max_features, estimator): # Check consistency on dataset iris. 
clf = HonestTreeClassifier( - criterion=criterion, random_state=0, max_features=max_features, tree_estimator=estimator + criterion=criterion, + random_state=0, + max_features=max_features, + tree_estimator=estimator, ) clf.fit(iris.data, iris.target) score = accuracy_score(clf.predict(iris.data), iris.target) @@ -54,8 +57,9 @@ def test_iris(criterion, max_features, estimator): assert len(clf.structure_indices_) < len(iris.target) -def test_toy_accuracy(): - clf = HonestTreeClassifier() +@pytest.mark.parametrize("honest_method", ["apply", "prune"]) +def test_toy_accuracy(honest_method): + clf = HonestTreeClassifier(honest_method=honest_method) X = np.ones((20, 4)) X[10:] *= -1 y = [0] * 10 + [1] * 10 @@ -175,6 +179,9 @@ def test_sklearn_compatible_estimator(estimator, check): "check_class_weight_classifiers", "check_classifier_multioutput", "check_do_not_raise_errors_in_init_or_set_params", + "check_sample_weight_equivalence", + "check_sample_weight_equivalence_on_dense_data", + "check_sample_weight_equivalence_on_sparse_data", ]: pytest.skip() check(estimator) diff --git a/treeple/tree/tests/test_multiview.py b/treeple/tree/tests/test_multiview.py index 4b36c6fbd..424e3f50c 100644 --- a/treeple/tree/tests/test_multiview.py +++ b/treeple/tree/tests/test_multiview.py @@ -19,6 +19,11 @@ ] ) def test_sklearn_compatible_estimator(estimator, check): + if check.func.__name__ in [ + "check_sample_weight_equivalence_on_dense_data", + "check_sample_weight_equivalence_on_sparse_data", + ]: + pytest.skip() check(estimator) diff --git a/treeple/tree/tests/test_tree.py b/treeple/tree/tests/test_tree.py index feab02aaf..9ae39f4d1 100644 --- a/treeple/tree/tests/test_tree.py +++ b/treeple/tree/tests/test_tree.py @@ -229,6 +229,13 @@ def test_sklearn_compatible_estimator(estimator, check): estimator, (PatchObliqueDecisionTreeClassifier, ExtraObliqueDecisionTreeClassifier) ) and check.func.__name__ in ["check_fit_score_takes_y"]: pytest.skip() + + if check.func.__name__ in [ + "check_sample_weight_equivalence_on_sparse_data", + "check_sample_weight_equivalence_on_dense_data", + ]: + pytest.skip() + check(estimator) diff --git a/treeple/tree/tests/test_unsupervised_tree.py b/treeple/tree/tests/test_unsupervised_tree.py index 84e7d3d62..50fdf5f0a 100644 --- a/treeple/tree/tests/test_unsupervised_tree.py +++ b/treeple/tree/tests/test_unsupervised_tree.py @@ -107,7 +107,9 @@ def test_sklearn_compatible_transformer(estimator, check): # clustering accuracy is poor when using TwoMeans on 1 single tree "check_clustering", # sample weights do not necessarily imply a sample is not used in clustering - "check_sample_weights_invariance", + "check_sample_weight_equivalence", + "check_sample_weight_equivalence_on_dense_data", + "check_sample_weight_equivalence_on_sparse_data", # sample order is not preserved in predict "check_methods_sample_order_invariance", ]: