FEA Implement pruning using honest subsample data to fit the leaves (#286)

* Honest pruning in honest tree classifier
* TST turn off sklearn multi_output tag assignment

Signed-off-by: Adam Li <adam2392@gmail.com>
Co-authored-by: Haoyin Xu <haoyinxu@gmail.com>
Showing 26 changed files with 931 additions and 47 deletions.
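Before the per-file diffs, here is a minimal usage sketch of the feature this commit adds: an honest tree whose leaves are fit on the held-out honest subsample and pruned so that no leaf is left empty. Only ``HonestTreeClassifier`` and ``honest_method='prune'`` are taken from the diff; the dataset and the remaining parameters are arbitrary choices for illustration.

# Minimal usage sketch of the new pruning mode; dataset and parameters other
# than honest_method are arbitrary choices for illustration.
from sklearn.datasets import load_iris
from treeple.tree import HonestTreeClassifier

X, y = load_iris(return_X_y=True)

# "prune" builds the tree on the structure subsample, then prunes away splits
# whose leaves would receive no honest (held-out) samples.
clf = HonestTreeClassifier(honest_method="prune", random_state=0)
clf.fit(X, y)
print(clf.predict_proba(X[:5]))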
@@ -0,0 +1,117 @@
"""
===========================================
Comparison of Decision Tree and Honest Tree
===========================================

This example compares the :class:`treeple.tree.HonestTreeClassifier` from the
``treeple`` library with the :class:`sklearn.tree.DecisionTreeClassifier`
from scikit-learn on the Iris dataset.

Both classifiers are fitted on the same dataset and their decision trees
are plotted side by side.
"""

import matplotlib.pyplot as plt
from sklearn import config_context
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree

from treeple.tree import HonestTreeClassifier

# Load the iris dataset
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.1, random_state=0)

# Initialize classifiers
max_features = 0.3

dishonest_clf = HonestTreeClassifier(
    honest_method=None,
    max_features=max_features,
    random_state=0,
    honest_prior="ignore",
)
honest_noprune_clf = HonestTreeClassifier(
    honest_method="apply",
    max_features=max_features,
    random_state=0,
    honest_prior="ignore",
)
honest_clf = HonestTreeClassifier(honest_method="prune", max_features=max_features, random_state=0)
sklearn_clf = DecisionTreeClassifier(max_features=max_features, random_state=0)

# Fit classifiers
dishonest_clf.fit(X_train, y_train)
honest_noprune_clf.fit(X_train, y_train)
honest_clf.fit(X_train, y_train)
sklearn_clf.fit(X_train, y_train)

# Plotting the trees
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(15, 5))

# .. note:: We skip parameter validation because internally the `plot_tree`
#           function checks if the estimator is a DecisionTreeClassifier
#           instance from scikit-learn, but the ``HonestTreeClassifier`` is
#           a subclass of a forked version of the DecisionTreeClassifier.

# Plot the pruned HonestTreeClassifier tree
ax = axes[2]
with config_context(skip_parameter_validation=True):
    plot_tree(honest_clf, filled=True, ax=ax)
ax.set_title("HonestTreeClassifier")

# Plot the unpruned HonestTreeClassifier tree
ax = axes[1]
with config_context(skip_parameter_validation=True):
    plot_tree(honest_noprune_clf, filled=False, ax=ax)
ax.set_title("HonestTreeClassifier (No pruning)")

# Plot the dishonest HonestTreeClassifier tree
ax = axes[0]
with config_context(skip_parameter_validation=True):
    plot_tree(dishonest_clf, filled=False, ax=ax)
ax.set_title("HonestTreeClassifier (Dishonest)")

# Plot scikit-learn DecisionTreeClassifier tree
plot_tree(sklearn_clf, filled=True, ax=axes[3])
axes[3].set_title("DecisionTreeClassifier")

plt.show()

# %%
# Discussion
# ----------
# The HonestTreeClassifier is a variant of the DecisionTreeClassifier that
# provides honest inference. Honesty is achieved by splitting the dataset into
# two parts: one part is used to build the tree structure, while the held-out
# honest part is used to fit the leaf nodes for posterior prediction. This
# results in calibrated posteriors
# (see :ref:`sphx_glr_auto_examples_calibration_plot_overlapping_gaussians.py`).
#
# Compared to the ``honest_method='apply'`` setting, ``honest_method='prune'``
# builds a tree that will not contain empty leaves, and also leverages the
# held-out honest samples to check split conditions. Thus we see that the
# pruned honest tree is significantly smaller than the regular decision tree.

# %%
# Evaluate predictions of the trees
# ---------------------------------
# When we do not prune, the honest tree will have empty leaves that would
# normally predict the prior. Here, ``honest_prior='ignore'`` is used to
# ignore those leaves when computing the posteriors, which results in
# posteriors that are ``np.nan``.

print("\nDishonest posteriors: ", dishonest_clf.predict_proba(X_val)) | ||
|
||
# this is the honest tree with empty leaves that predict the prior | ||
print("\nHonest tree without pruning: ", honest_noprune_clf.predict_proba(X_val)) | ||
|
||
# this is the honest tree that is pruned | ||
print("\nHonest tree with pruning: ", honest_clf.predict_proba(X_val)) | ||
|
||
# this is a regular decision tree classifier from sklearn | ||
print("\nDTC: ", sklearn_clf.predict_proba(X_val)) |
@@ -0,0 +1,62 @@
from ..._lib.sklearn.tree._criterion cimport Criterion
from ..._lib.sklearn.tree._partitioner cimport shift_missing_values_to_left_if_required
from ..._lib.sklearn.tree._splitter cimport SplitRecord, Splitter
from ..._lib.sklearn.tree._tree cimport Node, ParentInfo, Tree
from ..._lib.sklearn.utils._typedefs cimport float32_t, float64_t, int8_t, intp_t, uint8_t, uint32_t

# for each node, keep track of its index within the tree's node array,
# the range of honest samples (start, end) that reach it, and the
# lower/upper bounds passed down for the split checks
cdef struct PruningRecord:
    intp_t node_idx
    intp_t start
    intp_t end
    float64_t lower_bound
    float64_t upper_bound


# TODO: this may break the notion of feature importances, as we don't set
# the node's impurity at the child nodes.
cdef class HonestPruner(Splitter):
    cdef Tree tree          # The tree to be pruned
    cdef intp_t capacity    # The maximum number of nodes in the pruned tree
    cdef intp_t pos         # The current position to split left/right children
    cdef intp_t n_missing   # The number of missing values in the feature currently considered
    cdef uint8_t missing_go_to_left

    # TODO: only supports dense X for now.
    cdef const float32_t[:, :] X

    cdef int init(
        self,
        object X,
        const float64_t[:, ::1] y,
        const float64_t[:] sample_weight,
        const uint8_t[::1] missing_values_in_feature_mask,
    ) except -1

    # This function is not used, and should be disabled for pruners
    cdef int node_split(
        self,
        ParentInfo* parent_record,
        SplitRecord* split,
    ) except -1 nogil

    # Check whether ``current_split`` still satisfies the partition conditions
    cdef bint check_node_partition_conditions(
        self,
        SplitRecord* current_split,
        float64_t lower_bound,
        float64_t upper_bound
    ) noexcept nogil

    # Number of honest samples sent to the left/right child of the current split
    cdef inline intp_t n_left_samples(
        self
    ) noexcept nogil
    cdef inline intp_t n_right_samples(
        self
    ) noexcept nogil

    # Partition the honest samples for the node at ``node_idx``
    cdef int partition_samples(
        self,
        intp_t node_idx,
    ) noexcept nogil
Large diffs are not rendered by default.
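That unrendered diff presumably contains the `_honest_prune.pyx` implementation compiled by the build file below. As a rough, pure-Python illustration of the pruning idea declared in the header above (not the Cython implementation; the helper name and the "empty child" rule are assumptions of this sketch), a depth-first pass over a fitted tree can collapse any split that would leave a child with no honest samples:

# Rough pure-Python illustration of honest pruning; not the Cython
# implementation. The helper name and the "empty child" rule are assumptions
# made for this sketch.
import numpy as np

def prune_empty_honest_leaves(tree, X_honest):
    """Return a boolean mask of nodes that become leaves after honest pruning.

    ``tree`` is a fitted sklearn-style ``tree_`` object and ``X_honest`` is the
    held-out honest subsample routed through the existing splits.
    """
    children_left = tree.children_left.copy()
    children_right = tree.children_right.copy()
    feature, threshold = tree.feature, tree.threshold

    is_leaf = np.zeros(tree.node_count, dtype=bool)
    # Stack of (node index, honest sample indices reaching that node),
    # mirroring the (node_idx, start, end) bookkeeping of PruningRecord.
    stack = [(0, np.arange(len(X_honest)))]
    while stack:
        node, idx = stack.pop()
        if children_left[node] == -1:  # already a leaf in the fitted tree
            is_leaf[node] = True
            continue
        goes_left = X_honest[idx, feature[node]] <= threshold[node]
        left_idx, right_idx = idx[goes_left], idx[~goes_left]
        if len(left_idx) == 0 or len(right_idx) == 0:
            # A child would receive no honest samples: collapse to a leaf.
            is_leaf[node] = True
            children_left[node] = children_right[node] = -1
            continue
        stack.append((children_left[node], left_idx))
        stack.append((children_right[node], right_idx))
    return is_leaf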
@@ -0,0 +1,23 @@
tree_extension_metadata = {
  '_honest_prune':
    {'sources': ['_honest_prune.pyx'],
     'override_options': ['cython_language=cpp', 'optimization=3']},
}

foreach ext_name, ext_dict : tree_extension_metadata
  py.extension_module(
    ext_name,
    ext_dict.get('sources'),
    dependencies: [np_dep],
    override_options : ext_dict.get('override_options', []),
    c_args: c_args,
    cython_args: cython_c_args,
    subdir: 'treeple/tree/honesty',
    install: true,
  )
endforeach


py.install_sources(
  subdir: 'treeple/tree/honesty'  # Folder relative to site-packages to install to
)
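Once meson builds and installs the extension, the compiled module should be importable under the package path implied by ``subdir`` above; the snippet below is only a sanity check based on that assumption.

# Post-build sanity check; the import path is inferred from the meson subdir
# ('treeple/tree/honesty') and the extension name ('_honest_prune').
import importlib

mod = importlib.import_module("treeple.tree.honesty._honest_prune")
print(mod.__file__)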
@@ -0,0 +1,72 @@
import numpy as np

from treeple.tree import HonestTreeClassifier


def test_honest_tree_pruning():
    """Test honest tree with pruning to ensure no empty leaves."""
    rng = np.random.default_rng(1234)

    n_samples = 1000
    X = rng.standard_normal(size=(n_samples, 100))
    X[n_samples // 2 :] *= -1
    y = [0] * (n_samples // 2) + [1] * (n_samples // 2)

    clf = HonestTreeClassifier(honest_method="prune", max_features="sqrt", random_state=0)
    clf = clf.fit(X, y)

    nonprune_clf = HonestTreeClassifier(
        honest_method="apply", max_features="sqrt", random_state=0, honest_prior="ignore"
    )
    nonprune_clf = nonprune_clf.fit(X, y)

    assert (
        nonprune_clf.tree_.max_depth >= clf.tree_.max_depth
    ), f"{nonprune_clf.tree_.max_depth} <= {clf.tree_.max_depth}"
    # assert np.all(clf.tree_.children_left != -1)

    # Access the original and pruned trees' attributes
    original_tree = nonprune_clf.tree_
    pruned_tree = clf.tree_

    # Ensure the pruned tree has fewer nodes than the unpruned tree
    assert (
        pruned_tree.node_count < original_tree.node_count
    ), "Pruned tree has more nodes than the original tree"

    # Ensure the pruned tree has no empty leaves, while the unpruned honest
    # tree contains at least one empty node (True > False below)
    assert np.all(pruned_tree.value.sum(axis=(1, 2)) > 0), pruned_tree.value.sum(axis=(1, 2))
    # assert np.all(original_tree.value.sum(axis=(1,2)) > 0), original_tree.value.sum(axis=(1,2))
    assert np.all(pruned_tree.value.sum(axis=(1, 2)) > 0) > np.all(
        original_tree.value.sum(axis=(1, 2)) > 0
    )

    # test that the first three nodes are the same, since these are unlikely to be
    # pruned, and should remain invariant.
    #
    # Note: pruning the tree will have the node_ids change since the tree is
    # ordered via DFS.
    for pruned_node_id in range(3):
        pruned_left_child = pruned_tree.children_left[pruned_node_id]
        pruned_right_child = pruned_tree.children_right[pruned_node_id]

        # Check if the pruned node exists in the original tree
        assert (
            pruned_left_child in original_tree.children_left
        ), "Left child node of pruned tree not found in original tree"
        assert (
            pruned_right_child in original_tree.children_right
        ), "Right child node of pruned tree not found in original tree"

        # Check if the node's parameters match for non-leaf nodes
        if pruned_left_child != -1:
            assert (
                pruned_tree.feature[pruned_node_id] == original_tree.feature[pruned_node_id]
            ), "Feature does not match for node {}".format(pruned_node_id)
            assert (
                pruned_tree.threshold[pruned_node_id] == original_tree.threshold[pruned_node_id]
            ), "Threshold does not match for node {}".format(pruned_node_id)
            assert (
                pruned_tree.weighted_n_node_samples[pruned_node_id]
                == original_tree.weighted_n_node_samples[pruned_node_id]
            ), "Weighted n_node samples does not match for node {}".format(pruned_node_id)