FEA Implement pruning using honest subsample data to fit the leaves (#286)

* Honest pruning in honest tree classifier
* TST turn off sklearn multi_output tag assignment

---------

Signed-off-by: Adam Li <adam2392@gmail.com>
Co-authored-by: Haoyin Xu <haoyinxu@gmail.com>
adam2392 and PSSF23 authored Dec 12, 2024
1 parent e1c38ad commit ab12ca9
Showing 26 changed files with 931 additions and 47 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new/v0.10.rst
@@ -17,6 +17,10 @@ Changelog
``bottleneck`` library for faster computation. By `Ryan Hausen`_ (:pr:`#306`)
- |Feature| Added a sparse implementation of `treeple.stats.forest.build_colemen_forest`
that uses the `scipy.sparse` module. By `Ryan Hausen`_ (:pr:`#317`)
- |Feature| :class:`treeple.tree.HonestTreeClassifier` now has an ``honest_method`` parameter
that lets the user turn on pruning of the tree, such that there are no
empty leaf predictions. This brings the model closer to the GRF implementation in R.
By `Adam Li`_ (:pr:`#286`)


Code and Documentation Contributors
117 changes: 117 additions & 0 deletions examples/calibration/plot_honest_tree.py
@@ -0,0 +1,117 @@
"""
===========================================
Comparison of Decision Tree and Honest Tree
===========================================
This example compares the :class:`treeple.tree.HonestTreeClassifier` from the
``treeple`` library with the :class:`sklearn.tree.DecisionTreeClassifier`
from scikit-learn on the Iris dataset.
The classifiers are fitted on the same training data and their decision trees
are plotted side by side.
"""

import matplotlib.pyplot as plt
from sklearn import config_context
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree

from treeple.tree import HonestTreeClassifier

# Load the iris dataset
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.1, random_state=0)

# Initialize classifiers
max_features = 0.3

dishonest_clf = HonestTreeClassifier(
honest_method=None,
max_features=max_features,
random_state=0,
honest_prior="ignore",
)
honest_noprune_clf = HonestTreeClassifier(
honest_method="apply",
max_features=max_features,
random_state=0,
honest_prior="ignore",
)
honest_clf = HonestTreeClassifier(honest_method="prune", max_features=max_features, random_state=0)
sklearn_clf = DecisionTreeClassifier(max_features=max_features, random_state=0)

# Fit classifiers
dishonest_clf.fit(X_train, y_train)
honest_noprune_clf.fit(X_train, y_train)
honest_clf.fit(X_train, y_train)
sklearn_clf.fit(X_train, y_train)

# Plotting the trees
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(15, 5))

# .. note:: We skip parameter validation because internally the `plot_tree`
# function checks if the estimator is a DecisionTreeClassifier
# instance from scikit-learn, but the ``HonestTreeClassifier`` is
# a subclass of a forked version of the DecisionTreeClassifier.

# Plot HonestTreeClassifier tree (with pruning)
ax = axes[2]
with config_context(skip_parameter_validation=True):
plot_tree(honest_clf, filled=True, ax=ax)
ax.set_title("HonestTreeClassifier")

# Plot HonestTreeClassifier tree (without pruning)
ax = axes[1]
with config_context(skip_parameter_validation=True):
plot_tree(honest_noprune_clf, filled=False, ax=ax)
ax.set_title("HonestTreeClassifier (No pruning)")

# Plot HonestTreeClassifier tree (dishonest variant)
ax = axes[0]
with config_context(skip_parameter_validation=True):
plot_tree(dishonest_clf, filled=False, ax=ax)
ax.set_title("HonestTreeClassifier (Dishonest)")


# Plot scikit-learn DecisionTreeClassifier tree
plot_tree(sklearn_clf, filled=True, ax=axes[3])
axes[3].set_title("DecisionTreeClassifier")

plt.show()

# %%
# Discussion
# ----------
# The HonestTreeClassifier is a variant of the DecisionTreeClassifier that
# provides honest inference. The honest inference is achieved by splitting the
# dataset into two parts: the training set and the validation set. The training
# set is used to build the tree, while the validation set is used to fit the
# leaf nodes for posterior prediction. This results in calibrated posteriors
# (see :ref:`sphx_glr_auto_examples_calibration_plot_overlapping_gaussians.py`).
#
# Compared to the ``honest_method='apply'`` setting, the ``honest_method='prune'``
# setting builds a tree that will not contain empty leaves, and also leverages
# the validation set to check split conditions. Thus we see that the pruned
# honest tree is significantly smaller than the regular decision tree.
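
# %%
# As a small illustrative addition (not part of the original example), we can
# inspect how the honest split and pruning affect the fitted trees. The
# ``structure_indices_`` and ``honest_indices_`` attributes hold the samples
# used to build the tree structure and to fit the leaves, respectively, and
# ``tree_.node_count`` gives the size of each fitted tree.
print(
    "Structure / honest split sizes:",
    len(honest_clf.structure_indices_),
    len(honest_clf.honest_indices_),
)
print("Dishonest tree node count:", dishonest_clf.tree_.node_count)
print("Honest tree (no pruning) node count:", honest_noprune_clf.tree_.node_count)
print("Honest tree (pruned) node count:", honest_clf.tree_.node_count)
print("sklearn DecisionTreeClassifier node count:", sklearn_clf.tree_.node_count)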

# %%
# Evaluate predictions of the trees
# ---------------------------------
# When we do not prune, the honest tree may contain empty leaves, which would
# normally predict the prior. Since ``honest_prior='ignore'`` is used here,
# these leaves are ignored when computing the posteriors, so their predicted
# posteriors are ``np.nan``.

# this is the same as a decision tree classifier that is trained on less data
print("\nDishonest posteriors: ", dishonest_clf.predict_proba(X_val))

# this is the honest tree with empty leaves that predict the prior
print("\nHonest tree without pruning: ", honest_noprune_clf.predict_proba(X_val))

# this is the honest tree that is pruned
print("\nHonest tree with pruning: ", honest_clf.predict_proba(X_val))

# this is a regular decision tree classifier from sklearn
print("\nDTC: ", sklearn_clf.predict_proba(X_val))
2 changes: 2 additions & 0 deletions examples/calibration/plot_overlapping_gaussians.py
@@ -1,4 +1,6 @@
"""
.. _plot_overlapping_gaussians:
===================================================================
Plot honest forest calibrations on overlapping gaussian simulations
===================================================================
2 changes: 1 addition & 1 deletion treeple/__init__.py
@@ -4,7 +4,7 @@
import os
import sys

__version__ = "0.9.0dev0"
__version__ = "0.10.0dev0"
logger = logging.getLogger(__name__)


17 changes: 16 additions & 1 deletion treeple/ensemble/_honest_forest.py
@@ -270,6 +270,11 @@ class HonestForestClassifier(ForestClassifier, ForestClassifierMixin):
Fraction of training samples used for estimates in the trees. The
remaining samples will be used to learn the tree structure. A larger
fraction creates shallower trees with lower variance estimates.
honest_method : {"prune", "apply"}, default="prune"
Method for enforcing honesty. If "prune", the tree is pruned to enforce
honesty. If "apply", the tree is not pruned, but the leaf estimates are
adjusted to enforce honesty.
tree_estimator : object, default=None
Instantiated tree of type BaseDecisionTree from treeple.
@@ -410,6 +415,13 @@ class labels (multi-output problem).

_parameter_constraints: dict = {
**ForestClassifier._parameter_constraints,
**HonestTreeClassifier._parameter_constraints,
"class_weight": [
StrOptions({"balanced_subsample", "balanced"}),
dict,
list,
None,
],
}
_parameter_constraints.pop("max_samples")
_parameter_constraints["max_samples"] = [
@@ -453,6 +465,7 @@ def __init__(
max_samples=None,
honest_prior="ignore",
honest_fraction=0.5,
honest_method="apply",
tree_estimator=None,
stratify=False,
**tree_estimator_params,
@@ -475,6 +488,7 @@ def __init__(
"tree_estimator",
"honest_fraction",
"honest_prior",
"honest_method",
"stratify",
),
bootstrap=bootstrap,
@@ -498,6 +512,7 @@ def __init__(
self.ccp_alpha = ccp_alpha
self.honest_fraction = honest_fraction
self.honest_prior = honest_prior
self.honest_method = honest_method
self.tree_estimator = tree_estimator
self.stratify = stratify
self._tree_estimator_params = tree_estimator_params
@@ -730,7 +745,7 @@ def oob_samples_(self):
def __sklearn_tags__(self):
# XXX: nans should be supportable in HRF
tags = super().__sklearn_tags__()
tags.classifier_tags.multi_output = False
# tags.classifier_tags.multi_output = False
tags.input_tags.allow_nan = False
return tags
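
For context, a minimal usage sketch of the new ``honest_method`` parameter on the
forest (an illustration, not part of this diff; it assumes ``HonestForestClassifier``
is importable from the ``treeple`` top-level namespace):

from sklearn.datasets import make_classification

from treeple import HonestForestClassifier

X, y = make_classification(n_samples=200, n_features=10, random_state=0)

# "prune" removes splits that would leave a child with no honest samples,
# while the default "apply" keeps the structure and only refits the leaf values.
forest = HonestForestClassifier(
    n_estimators=10,
    honest_method="prune",
    honest_fraction=0.5,
    random_state=0,
)
forest.fit(X, y)
print(forest.predict_proba(X[:5]))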

3 changes: 3 additions & 0 deletions treeple/experimental/tests/test_sdf.py
@@ -115,6 +115,9 @@ def test_sklearn_compatible_estimator(estimator, check):
# XXX: can include this "generalization" in the future if it's useful
if check.func.__name__ in [
"check_class_weight_classifiers",
"check_sample_weight_equivalence",
"check_sample_weight_equivalence_on_dense_data",
"check_sample_weight_equivalence_on_sparse_data",
]:
pytest.skip()
check(estimator)
2 changes: 1 addition & 1 deletion treeple/meson.build
@@ -96,7 +96,7 @@ cc = meson.get_compiler('c')
# 'source_fname',
# numpy_nodepr_api)

# XXX: ENABLE WHEN DEBUGGING
# TODO XXX: ENABLE WHEN DEBUGGING
boundscheck = 'False'

scikit_learn_cython_args = [
2 changes: 1 addition & 1 deletion treeple/neighbors.py
@@ -11,7 +11,7 @@
from treeple.tree._neighbors import _compute_distance_matrix, compute_forest_similarity_matrix


class NearestNeighborsMetaEstimator(BaseEstimator, MetaEstimatorMixin):
class NearestNeighborsMetaEstimator(MetaEstimatorMixin, BaseEstimator):
"""Meta-estimator for nearest neighbors.
Uses a decision-tree, or forest model to compute distances between samples
2 changes: 1 addition & 1 deletion treeple/stats/forest.py
@@ -289,7 +289,7 @@ def build_oob_forest(
# the Histogram Gradient Boosting Tree does, where the binning thresholds
# are passed into the tree itself, thus allowing us to set the node feature
# value thresholds within the tree itself.
if est.max_bins is not None:
if hasattr(est, "max_bins") and est.max_bins is not None:
X = est._bin_data(X, is_training_data=False).astype(DTYPE)

# Assign chunk of trees to jobs
53 changes: 30 additions & 23 deletions treeple/stats/permuteforest.py
@@ -179,6 +179,11 @@ class PermutationHonestForestClassifier(HonestForestClassifier):
remaining samples will be used to learn the tree structure. A larger
fraction creates shallower trees with lower variance estimates.
honest_method : {"prune", "apply"}, default="prune"
Method for enforcing honesty. If "prune", the tree is pruned to enforce
honesty. If "apply", the tree is not pruned, but the leaf estimates are
adjusted to enforce honesty.
tree_estimator : object, default=None
Type of decision tree classifier to use. By default `None`, which
defaults to `treeple.tree.DecisionTreeClassifier`. Note
@@ -298,35 +303,37 @@ def __init__(
max_samples=None,
honest_prior="empirical",
honest_fraction=0.5,
honest_method="apply",
tree_estimator=None,
stratify=False,
permute_per_tree=False,
**tree_estimator_params,
):
super().__init__(
n_estimators,
criterion,
splitter,
max_depth,
min_samples_split,
min_samples_leaf,
min_weight_fraction_leaf,
max_features,
max_leaf_nodes,
min_impurity_decrease,
bootstrap,
oob_score,
n_jobs,
random_state,
verbose,
warm_start,
class_weight,
ccp_alpha,
max_samples,
honest_prior,
honest_fraction,
tree_estimator,
stratify,
n_estimators=n_estimators,
criterion=criterion,
splitter=splitter,
max_depth=max_depth,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
min_weight_fraction_leaf=min_weight_fraction_leaf,
max_features=max_features,
max_leaf_nodes=max_leaf_nodes,
min_impurity_decrease=min_impurity_decrease,
bootstrap=bootstrap,
oob_score=oob_score,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start,
class_weight=class_weight,
ccp_alpha=ccp_alpha,
max_samples=max_samples,
honest_prior=honest_prior,
honest_fraction=honest_fraction,
honest_method=honest_method,
tree_estimator=tree_estimator,
stratify=stratify,
**tree_estimator_params,
)
self.permute_per_tree = permute_per_tree
3 changes: 3 additions & 0 deletions treeple/tests/test_honest_forest.py
@@ -310,6 +310,9 @@ def test_sklearn_compatible_estimator(estimator, check):
# for fitting the tree's splits
if check.func.__name__ in [
"check_class_weight_classifiers",
"check_sample_weight_equivalence",
"check_sample_weight_equivalence_on_dense_data",
"check_sample_weight_equivalence_on_sparse_data",
# TODO: this is an error. Somehow a segfault is raised when fit is called first and
# then partial_fit
"check_fit_score_takes_y",
7 changes: 7 additions & 0 deletions treeple/tests/test_multiview_forest.py
@@ -18,6 +18,13 @@
]
)
def test_sklearn_compatible_estimator(estimator, check):
if check.func.__name__ in [
# sample weights do not necessarily imply a sample is not used in clustering
"check_sample_weight_equivalence",
"check_sample_weight_equivalence_on_dense_data",
"check_sample_weight_equivalence_on_sparse_data",
]:
pytest.skip()
check(estimator)


12 changes: 11 additions & 1 deletion treeple/tests/test_supervised_forest.py
@@ -200,7 +200,17 @@ def test_sklearn_compatible_estimator(estimator, check):
ObliqueRandomForestClassifier,
PatchObliqueRandomForestClassifier,
),
) and check.func.__name__ in ["check_fit_score_takes_y"]:
) and check.func.__name__ in [
"check_fit_score_takes_y",
]:
pytest.skip()

if check.func.__name__ in [
# sample weights do not necessarily imply a sample is not used in clustering
"check_sample_weight_equivalence",
"check_sample_weight_equivalence_on_dense_data",
"check_sample_weight_equivalence_on_sparse_data",
]:
pytest.skip()
check(estimator)

8 changes: 5 additions & 3 deletions treeple/tests/test_unsupervised_forest.py
@@ -33,9 +33,11 @@ def test_sklearn_compatible_estimator(estimator, check):
if check.func.__name__ in [
# Cannot apply agglomerative clustering on < 2 samples
"check_methods_subset_invariance",
# # sample weights do not necessarily imply a sample is not used in clustering
"check_sample_weights_invariance",
# # sample order is not preserved in predict
# sample weights do not necessarily imply a sample is not used in clustering
"check_sample_weight_equivalence",
"check_sample_weight_equivalence_on_dense_data",
"check_sample_weight_equivalence_on_sparse_data",
# sample order is not preserved in predict
"check_methods_sample_order_invariance",
]:
pytest.skip()
126 changes: 116 additions & 10 deletions treeple/tree/_honest_tree.py
@@ -1,15 +1,32 @@
# Adapted from: https://github.com/neurodata/honest-forests

from copy import copy, deepcopy
from numbers import Integral

import numpy as np
from sklearn.base import ClassifierMixin, MetaEstimatorMixin, _fit_context, clone
from sklearn.base import ClassifierMixin, MetaEstimatorMixin, _fit_context, clone, is_classifier
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.utils._param_validation import HasMethods, Interval, RealNotInt, StrOptions
from sklearn.utils.multiclass import _check_partial_fit_first_call, check_classification_targets
from sklearn.utils.validation import check_is_fitted, check_X_y
from sklearn.utils.validation import check_is_fitted, check_random_state, check_X_y

from .._lib.sklearn.tree import DecisionTreeClassifier
from .._lib.sklearn.tree import DecisionTreeClassifier, _criterion, _tree
from .._lib.sklearn.tree._classes import BaseDecisionTree
from .._lib.sklearn.tree._criterion import BaseCriterion
from .._lib.sklearn.tree._tree import Tree
from .honesty._honest_prune import HonestPruner, _build_pruned_tree_honesty

CRITERIA_CLF = {
"gini": _criterion.Gini,
"log_loss": _criterion.Entropy,
"entropy": _criterion.Entropy,
}
CRITERIA_REG = {
"squared_error": _criterion.MSE,
"friedman_mse": _criterion.FriedmanMSE,
"absolute_error": _criterion.MAE,
"poisson": _criterion.Poisson,
}

DOUBLE = _tree.DOUBLE


class HonestTreeClassifier(MetaEstimatorMixin, ClassifierMixin, BaseDecisionTree):
@@ -173,6 +190,13 @@ class frequency in the voting subsample.
Whether or not to stratify sample when considering structure and leaf indices.
By default False.
honest_method : {"apply", "prune"}, default="apply"
Method to use for fitting the leaf nodes. If "apply", the leaf nodes
are fit using the tree structure as is; in this case, empty leaves may occur
if there is not enough honest data. If "prune", the tree is pruned
using the honest set of data after the tree structure is built
on the structure set of data.
**tree_estimator_params : dict
Parameters to pass to the underlying base tree estimators.
These must be parameters for ``tree_estimator``.
@@ -283,9 +307,18 @@ class frequency in the voting subsample.
],
"honest_fraction": [Interval(RealNotInt, 0.0, 1.0, closed="neither")],
"honest_prior": [StrOptions({"empirical", "uniform", "ignore"})],
"honest_method": [StrOptions({"apply", "prune"}), None],
"stratify": ["boolean"],
"tree_estimator_params": ["dict"],
}
_parameter_constraints.pop("max_features")
_parameter_constraints["max_features"] = [
Interval(Integral, 1, None, closed="left"),
Interval(RealNotInt, 0.0, 1.0, closed="right"),
StrOptions({"sqrt", "log2"}),
"array-like",
None,
]

def __init__(
self,
@@ -306,6 +339,7 @@ def __init__(
honest_fraction=0.5,
honest_prior="empirical",
stratify=False,
honest_method="apply",
**tree_estimator_params,
):
self.tree_estimator = tree_estimator
@@ -326,6 +360,7 @@ def __init__(
self.honest_fraction = honest_fraction
self.honest_prior = honest_prior
self.stratify = stratify
self.honest_method = honest_method

# XXX: to enable this, we need to also reset the leaf node samples during `_set_leaf_nodes`
self.store_leaf_values = False
@@ -664,16 +699,59 @@ def _fit_leaves(self, X, y, sample_weight):
y = y_encoded
self.n_classes_ = np.array(self.n_classes_, dtype=np.intp)

# XXX: implement honest pruning
honest_method = "apply"
if honest_method == "apply":
if self.honest_method == "apply":
# Fit leaves using other subsample
honest_leaves = self.tree_.apply(X[self.honest_indices_])

# y-encoded ensures that y values match the indices of the classes
self._set_leaf_nodes(honest_leaves, y, sample_weight)
elif honest_method == "prune":
raise NotImplementedError("Pruning is not yet implemented.")
elif self.honest_method == "prune":
if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
y = np.ascontiguousarray(y, dtype=DOUBLE)

n_samples = X.shape[0]

# Build tree
criterion = self.criterion
if not isinstance(criterion, BaseCriterion):
if is_classifier(self):
criterion = CRITERIA_CLF[self.criterion](self.n_outputs_, self.n_classes_)
else:
criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples)
else:
# Make a deepcopy in case the criterion has mutable attributes that
# might be shared and modified concurrently during parallel fitting
criterion = deepcopy(criterion)

random_state = check_random_state(self.random_state)
pruner = HonestPruner(
criterion,
self.max_features_,
self.min_samples_leaf_,
self.min_weight_leaf_,
random_state,
self.monotonic_cst_,
self.tree_,
)

# build pruned tree
if is_classifier(self):
n_classes = np.atleast_1d(self.n_classes_)
pruned_tree = Tree(self.n_features_in_, n_classes, self.n_outputs_)
else:
pruned_tree = Tree(
self.n_features_in_,
# TODO: the tree shouldn't need this param
np.array([1] * self.n_outputs_, dtype=np.intp),
self.n_outputs_,
)

# get the leaves
missing_values_in_feature_mask = self._compute_missing_values_in_feature_mask(X)
_build_pruned_tree_honesty(
pruned_tree, self.tree_, pruner, X, y, sample_weight, missing_values_in_feature_mask
)
self.tree_ = pruned_tree

if self.n_outputs_ == 1:
self.n_classes_ = self.n_classes_[0]
@@ -693,12 +771,27 @@ def _set_leaf_nodes(self, leaf_ids, y, sample_weight):
"""
self.tree_.value[:, :, :] = 0

# XXX: Note this method does not make these into a proportion of the leaf
# total_n_node_samples = 0.0

# apply sample-weight to the leaf nodes
# seen_leaf_ids = set()
for leaf_id, yval, y_weight in zip(
leaf_ids, y[self.honest_indices_, :], sample_weight[self.honest_indices_]
):
# XXX: this treats the leaf node values as a sum of the leaf
self.tree_.value[leaf_id][:, yval] += y_weight

# XXX: this normalizes the leaf node values to be a proportion of the leaf
# total_n_node_samples += y_weight
# if leaf_id in seen_leaf_ids:
# self.tree_.value[leaf_id][:, yval] += y_weight
# else:
# self.tree_.value[leaf_id][:, yval] = y_weight
# seen_leaf_ids.add(leaf_id)
# for leaf_id in seen_leaf_ids:
# self.tree_.value[leaf_id] /= total_n_node_samples

def _inherit_estimator_attributes(self):
"""Initialize necessary attributes from the provided tree estimator"""
if hasattr(self.estimator_, "_inheritable_fitted_attribute"):
@@ -821,3 +914,16 @@ def predict(self, X, check_input=True):
check_is_fitted(self)
X = self._validate_X_predict(X, check_input)
return self.estimator_.predict(X, False)

@property
def feature_importances_(self):
"""Feature importances.
This is the impurity-based feature importance. The higher the value, the more
important the feature was in constructing the tree structure.
Note: this does not reflect the feature importances used for setting the
leaf node posterior estimates.
"""
# TODO: technically, the feature importances are currently computed using only the structure set
return super().feature_importances_
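
A brief illustrative sketch of the distinction documented above (not part of this
diff): after fitting, ``feature_importances_`` reflects the splits chosen on the
structure set, not the honest refit of the leaf values.

from sklearn.datasets import load_iris

from treeple.tree import HonestTreeClassifier

X, y = load_iris(return_X_y=True)
tree = HonestTreeClassifier(random_state=0).fit(X, y)
# importances are computed from the structure-set splits
print(tree.feature_importances_)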
Empty file.
62 changes: 62 additions & 0 deletions treeple/tree/honesty/_honest_prune.pxd
@@ -0,0 +1,62 @@
from ..._lib.sklearn.tree._criterion cimport Criterion
from ..._lib.sklearn.tree._partitioner cimport shift_missing_values_to_left_if_required
from ..._lib.sklearn.tree._splitter cimport SplitRecord, Splitter
from ..._lib.sklearn.tree._tree cimport Node, ParentInfo, Tree
from ..._lib.sklearn.utils._typedefs cimport float32_t, float64_t, int8_t, intp_t, uint8_t, uint32_t


# for each node, keep track of the node index and the parent index
# within the tree's node array
cdef struct PruningRecord:
intp_t node_idx
intp_t start
intp_t end
float64_t lower_bound
float64_t upper_bound


# TODO: this may break the notion of feature importances, as we don't set the node's impurity
# at the child nodes.
cdef class HonestPruner(Splitter):
cdef Tree tree # The tree to be pruned
cdef intp_t capacity # The maximum number of nodes in the pruned tree
cdef intp_t pos # The current position to split left/right children
cdef intp_t n_missing # The number of missing values in the feature currently considered
cdef uint8_t missing_go_to_left

# TODO: only supports dense for now.
cdef const float32_t[:, :] X

cdef int init(
self,
object X,
const float64_t[:, ::1] y,
const float64_t[:] sample_weight,
const uint8_t[::1] missing_values_in_feature_mask,
) except -1

# This function is not used, and should be disabled for pruners
cdef int node_split(
self,
ParentInfo* parent_record,
SplitRecord* split,
) except -1 nogil

cdef bint check_node_partition_conditions(
self,
SplitRecord* current_split,
float64_t lower_bound,
float64_t upper_bound
) noexcept nogil

cdef inline intp_t n_left_samples(
self
) noexcept nogil
cdef inline intp_t n_right_samples(
self
) noexcept nogil

cdef int partition_samples(
self,
intp_t node_idx,
) noexcept nogil
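
To make the role of the pruner more concrete, here is a rough pure-Python sketch
of the core idea (an illustration only, not the Cython implementation above; the
hypothetical helper ``honest_prune_candidates`` ignores missing values,
min-sample checks, and the cascading collapse the real pruner performs):

import numpy as np

def honest_prune_candidates(tree, X_honest):
    """Mark splits that the honest subsample cannot support.

    A split is a pruning candidate when all honest samples reaching it fall
    on one side, so the other child would become an empty leaf. ``tree`` is
    assumed to be an sklearn-style ``Tree`` built on the structure set, and
    ``X_honest`` the honest-subsample feature matrix.
    """
    n_nodes = tree.node_count
    counts = np.zeros(n_nodes, dtype=np.intp)

    # route each honest sample down the structure-built tree
    for x in X_honest:
        node = 0
        while tree.children_left[node] != -1:
            counts[node] += 1
            if x[tree.feature[node]] <= tree.threshold[node]:
                node = tree.children_left[node]
            else:
                node = tree.children_right[node]
        counts[node] += 1

    # a split becomes a leaf if either child receives no honest samples
    prune = np.zeros(n_nodes, dtype=bool)
    for node in range(n_nodes):
        left, right = tree.children_left[node], tree.children_right[node]
        if left != -1 and (counts[left] == 0 or counts[right] == 0):
            prune[node] = True
    return prune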
429 changes: 429 additions & 0 deletions treeple/tree/honesty/_honest_prune.pyx

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions treeple/tree/honesty/meson.build
@@ -0,0 +1,23 @@
tree_extension_metadata = {
'_honest_prune':
{'sources': ['_honest_prune.pyx'],
'override_options': ['cython_language=cpp', 'optimization=3']},
}

foreach ext_name, ext_dict : tree_extension_metadata
py.extension_module(
ext_name,
ext_dict.get('sources'),
dependencies: [np_dep],
override_options : ext_dict.get('override_options', []),
c_args: c_args,
cython_args: cython_c_args,
subdir: 'treeple/tree/honesty',
install: true,
)
endforeach


py.install_sources(
subdir: 'treeple/tree/honesty' # Folder relative to site-packages to install to
)
2 changes: 1 addition & 1 deletion treeple/tree/meson.build
@@ -43,7 +43,7 @@ py.install_sources(
subdir: 'treeple/tree' # Folder relative to site-packages to install to
)

# TODO: comment in if we include tests
subdir('tests')
subdir('unsupervised')
subdir('manifold')
subdir('honesty')
1 change: 1 addition & 0 deletions treeple/tree/tests/meson.build
@@ -3,6 +3,7 @@ python_sources = [
'test_tree.py',
'test_utils.py',
'test_honest_tree.py',
'test_honest_prune.py',
'test_marginal.py',
'test_all_trees.py',
'test_unsupervised_tree.py',
72 changes: 72 additions & 0 deletions treeple/tree/tests/test_honest_prune.py
@@ -0,0 +1,72 @@
import numpy as np

from treeple.tree import HonestTreeClassifier


def test_honest_tree_pruning():
"""Test honest tree with pruning to ensure no empty leaves."""
rng = np.random.default_rng(1234)

n_samples = 1000
X = rng.standard_normal(size=(n_samples, 100))
X[n_samples // 2 :] *= -1
y = [0] * (n_samples // 2) + [1] * (n_samples // 2)

clf = HonestTreeClassifier(honest_method="prune", max_features="sqrt", random_state=0)
clf = clf.fit(X, y)

nonprune_clf = HonestTreeClassifier(
honest_method="apply", max_features="sqrt", random_state=0, honest_prior="ignore"
)
nonprune_clf = nonprune_clf.fit(X, y)

assert (
nonprune_clf.tree_.max_depth >= clf.tree_.max_depth
), f"{nonprune_clf.tree_.max_depth} <= {clf.tree_.max_depth}"
# assert np.all(clf.tree_.children_left != -1)

# Access the original and pruned trees' attributes
original_tree = nonprune_clf.tree_
pruned_tree = clf.tree_

# Ensure the pruned tree has fewer or equal nodes
assert (
pruned_tree.node_count < original_tree.node_count
), "Pruned tree has more nodes than the original tree"

# Ensure the pruned tree has no empty leaves
assert np.all(pruned_tree.value.sum(axis=(1, 2)) > 0), pruned_tree.value.sum(axis=(1, 2))
# assert np.all(original_tree.value.sum(axis=(1,2)) > 0), original_tree.value.sum(axis=(1,2))
assert np.all(pruned_tree.value.sum(axis=(1, 2)) > 0) > np.all(
original_tree.value.sum(axis=(1, 2)) > 0
)

# test that the first three nodes are the same, since these are unlikely to be
# pruned, and should remain invariant.
#
# Note: pruning the tree will have the node_ids change since the tree is
# ordered via DFS.
for pruned_node_id in range(3):
pruned_left_child = pruned_tree.children_left[pruned_node_id]
pruned_right_child = pruned_tree.children_right[pruned_node_id]

# Check if the pruned node exists in the original tree
assert (
pruned_left_child in original_tree.children_left
), "Left child node of pruned tree not found in original tree"
assert (
pruned_right_child in original_tree.children_right
), "Right child node of pruned tree not found in original tree"

# Check if the node's parameters match for non-leaf nodes
if pruned_left_child != -1:
assert (
pruned_tree.feature[pruned_node_id] == original_tree.feature[pruned_node_id]
), "Feature does not match for node {}".format(pruned_node_id)
assert (
pruned_tree.threshold[pruned_node_id] == original_tree.threshold[pruned_node_id]
), "Threshold does not match for node {}".format(pruned_node_id)
assert (
pruned_tree.weighted_n_node_samples[pruned_node_id]
== original_tree.weighted_n_node_samples[pruned_node_id]
), "Weighted n_node samples does not match for node {}".format(pruned_node_id)
13 changes: 10 additions & 3 deletions treeple/tree/tests/test_honest_tree.py
@@ -36,7 +36,10 @@
def test_iris(criterion, max_features, estimator):
# Check consistency on dataset iris.
clf = HonestTreeClassifier(
criterion=criterion, random_state=0, max_features=max_features, tree_estimator=estimator
criterion=criterion,
random_state=0,
max_features=max_features,
tree_estimator=estimator,
)
clf.fit(iris.data, iris.target)
score = accuracy_score(clf.predict(iris.data), iris.target)
@@ -54,8 +57,9 @@ def test_iris(criterion, max_features, estimator):
assert len(clf.structure_indices_) < len(iris.target)


def test_toy_accuracy():
clf = HonestTreeClassifier()
@pytest.mark.parametrize("honest_method", ["apply", "prune"])
def test_toy_accuracy(honest_method):
clf = HonestTreeClassifier(honest_method=honest_method)
X = np.ones((20, 4))
X[10:] *= -1
y = [0] * 10 + [1] * 10
@@ -175,6 +179,9 @@ def test_sklearn_compatible_estimator(estimator, check):
"check_class_weight_classifiers",
"check_classifier_multioutput",
"check_do_not_raise_errors_in_init_or_set_params",
"check_sample_weight_equivalence",
"check_sample_weight_equivalence_on_dense_data",
"check_sample_weight_equivalence_on_sparse_data",
]:
pytest.skip()
check(estimator)
5 changes: 5 additions & 0 deletions treeple/tree/tests/test_multiview.py
@@ -19,6 +19,11 @@
]
)
def test_sklearn_compatible_estimator(estimator, check):
if check.func.__name__ in [
"check_sample_weight_equivalence_on_dense_data",
"check_sample_weight_equivalence_on_sparse_data",
]:
pytest.skip()
check(estimator)


7 changes: 7 additions & 0 deletions treeple/tree/tests/test_tree.py
@@ -229,6 +229,13 @@ def test_sklearn_compatible_estimator(estimator, check):
estimator, (PatchObliqueDecisionTreeClassifier, ExtraObliqueDecisionTreeClassifier)
) and check.func.__name__ in ["check_fit_score_takes_y"]:
pytest.skip()

if check.func.__name__ in [
"check_sample_weight_equivalence_on_sparse_data",
"check_sample_weight_equivalence_on_dense_data",
]:
pytest.skip()

check(estimator)


4 changes: 3 additions & 1 deletion treeple/tree/tests/test_unsupervised_tree.py
@@ -107,7 +107,9 @@ def test_sklearn_compatible_transformer(estimator, check):
# clustering accuracy is poor when using TwoMeans on 1 single tree
"check_clustering",
# sample weights do not necessarily imply a sample is not used in clustering
"check_sample_weights_invariance",
"check_sample_weight_equivalence",
"check_sample_weight_equivalence_on_dense_data",
"check_sample_weight_equivalence_on_sparse_data",
# sample order is not preserved in predict
"check_methods_sample_order_invariance",
]:
