Decision Trees

[Figure: Decision Tree Example]

Decision Trees (DTs) are a non-parametric supervised learning method used for classification and regression. The goal is to create a model that predicts the value of a target variable by learning simple decision rules inferred from the data. DTs are highly interpretable, capable of achieving high accuracy for many tasks while requiring little data preparation.

Creating a decision tree – Recursive Binary Splitting

Growing a tree involves repeatedly splitting the data into subsets to minimize some cost function. At each step, every feature is considered, and candidate split points are evaluated with the cost function. The split with the lowest cost is selected. The process repeats until a stopping criterion is met (discussed later). The algorithm is recursive because the groups formed by each split can be subdivided using the same strategy.
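The procedure above can be sketched in a few lines of NumPy. This is a minimal illustration, not the implementation used by any particular library: the function names `sse`, `best_split`, and `grow`, the dictionary-based tree, and the `max_depth` / `min_samples` defaults are all assumptions made for the example, and the sum of squared errors is used as the cost.

```python
import numpy as np

def sse(y):
    """Sum of squared errors around the mean (one possible cost function)."""
    return float(np.sum((y - y.mean()) ** 2)) if len(y) else 0.0

def best_split(X, y, cost=sse):
    """Try every feature and every observed value as a threshold.

    Returns (cost, feature_index, threshold) for the cheapest split,
    or None if no feature has more than one distinct value.
    """
    best = None
    for j in range(X.shape[1]):
        # Skip the largest value: splitting there would leave the right side empty.
        for t in np.unique(X[:, j])[:-1]:
            left = X[:, j] <= t
            c = cost(y[left]) + cost(y[~left])
            if best is None or c < best[0]:
                best = (c, j, float(t))
    return best

def grow(X, y, depth=0, max_depth=3, min_samples=2):
    """Recursively split the data; leaves store the mean target value."""
    pure = bool(np.all(y == y[0]))
    if depth >= max_depth or len(y) < min_samples or pure:
        return {"value": float(y.mean())}
    split = best_split(X, y)
    if split is None:  # every feature is constant, nothing left to split on
        return {"value": float(y.mean())}
    _, j, t = split
    left = X[:, j] <= t
    return {
        "feature": j,
        "threshold": t,
        "left": grow(X[left], y[left], depth + 1, max_depth, min_samples),
        "right": grow(X[~left], y[~left], depth + 1, max_depth, min_samples),
    }

# Toy data, made up for this example.
X = np.array([[1.0], [2.0], [10.0], [11.0]])
y = np.array([1.0, 1.0, 5.0, 5.0])
print(grow(X, y))
```

The exhaustive search over all features and thresholds is what makes each step greedy: the split is the best one available right now, with no lookahead.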

Cost of a split

The cost of a split measures how good it is to split at a specific feature value. For regression, cost functions like the sum of squared errors or the standard deviation are used.

For classification, the Gini index is used:

$$G = 1 - \sum_{i \in J} p_i^2$$

Where J is the set of all classes, and p_i is the fraction of items belonging to class i. Ideally, a split has a Gini value of zero, which means that each resulting group contains only one class. With two classes, the worst Gini impurity is 0.5, which occurs when the classes in a group are split 50-50. With more classes, the maximum is 1 - 1/|J|, reached when all classes are equally represented.
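As a quick check of these numbers, here is a small sketch that computes the Gini impurity of a group, assuming the usual definition of one minus the sum of squared class fractions. The helper names `gini_impurity` and `split_cost` are made up for this example, and weighting each group by its size when scoring a split is a common convention rather than something stated above.

```python
from collections import Counter

def gini_impurity(labels):
    """1 - sum(p_i^2), where p_i is the fraction of items in class i."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())

def split_cost(left, right):
    """Size-weighted average of the impurities of the two groups."""
    n = len(left) + len(right)
    return len(left) / n * gini_impurity(left) + len(right) / n * gini_impurity(right)

print(gini_impurity(["a", "a", "a"]))         # pure group
print(gini_impurity(["a", "b"]))              # 50-50 split of two classes
print(split_cost(["a", "a"], ["b", "b"]))     # perfect split
print(split_cost(["a", "b"], ["a", "b"]))     # uninformative split
```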

When should you stop splitting?

Now you might ask when to stop growing the tree. This is an important question because if we kept splitting, the decision tree would get huge quite fast. Such complex trees are slow and tend to overfit. Therefore, we set a predefined stopping criterion to halt the construction of the decision tree.

The two most common stopping methods are:

  • Minimum count of training examples assigned to a leaf node, e.g., if there are fewer than 10 training points, stop splitting.
  • Maximum depth (maximum length from root to leaf)

A larger tree might perform better but is also more prone to overfitting. Setting the minimum count too high or the maximum depth too low can stop training too early and result in poor performance.
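In Scikit-Learn, these two criteria correspond to the `min_samples_leaf` and `max_depth` parameters of `DecisionTreeClassifier`. The sketch below compares an unrestricted tree with a restricted one on the Iris data. The particular values (a 30% test split, `max_depth=3`, `min_samples_leaf=10`) are arbitrary choices for illustration, not recommendations.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y
)

# No stopping criterion: the tree grows until its leaves are pure.
unrestricted = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# Stop at depth 3 and never create a leaf with fewer than 10 training points.
limited = DecisionTreeClassifier(
    max_depth=3, min_samples_leaf=10, random_state=0
).fit(X_train, y_train)

for name, model in [("unrestricted", unrestricted), ("limited", limited)]:
    print(
        f"{name}: depth={model.get_depth()}, leaves={model.get_n_leaves()}, "
        f"test accuracy={model.score(X_test, y_test):.3f}"
    )
```

Comparing accuracy on held-out data, rather than on the training set, is what shows whether the extra depth is helping or just memorizing the training points.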

Pruning

Pruning is a technique that reduces the size of a decision tree by removing sections that contribute little to its predictions. Pruning reduces the complexity of the final model and can therefore improve predictive accuracy on unseen data by reducing overfitting.

There are multiple pruning techniques available. In this article, we'll focus on two:

  • Reduced error pruning
  • Cost complexity pruning

Reduced error pruning

One of the simplest forms of pruning is reduced error pruning. Starting at the leaves, each node is replaced with a leaf that predicts its most popular class. If the loss, typically measured on a held-out validation set, does not get worse, the change is kept; otherwise it is reverted. While a somewhat naive approach to pruning, reduced error pruning has the advantage of speed and simplicity.
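The idea can be sketched on a tiny hand-built tree. Everything here is hypothetical and chosen for illustration: the tree is a nested dictionary in which internal nodes carry a `majority` key holding their most popular training class, accuracy on a validation set stands in for the loss function, and the names `predict`, `accuracy`, and `reduced_error_prune` are made up for this example.

```python
def predict(node, x):
    """Walk from the given node down to a leaf and return its label."""
    while "label" not in node:
        branch = "left" if x[node["feature"]] <= node["threshold"] else "right"
        node = node[branch]
    return node["label"]

def accuracy(tree, X_val, y_val):
    return sum(predict(tree, x) == y for x, y in zip(X_val, y_val)) / len(y_val)

def reduced_error_prune(node, root, X_val, y_val):
    """Bottom-up: prune the children first, then try collapsing this node."""
    if "label" in node:
        return
    reduced_error_prune(node["left"], root, X_val, y_val)
    reduced_error_prune(node["right"], root, X_val, y_val)

    before = accuracy(root, X_val, y_val)
    saved = dict(node)                 # shallow copy so the change can be undone
    node.clear()
    node["label"] = saved["majority"]  # replace the subtree with a single leaf
    if accuracy(root, X_val, y_val) < before:
        node.clear()
        node.update(saved)             # pruning hurt validation accuracy: revert

# Hand-built tree whose right subtree has an extra, noisy split.
tree = {
    "feature": 0, "threshold": 5, "majority": "a",
    "left": {"label": "a"},
    "right": {
        "feature": 1, "threshold": 0.5, "majority": "b",
        "left": {"label": "b"},
        "right": {"label": "a"},
    },
}
X_val = [[1, 0], [2, 1], [8, 0], [9, 1]]
y_val = ["a", "a", "b", "b"]

reduced_error_prune(tree, tree, X_val, y_val)
print(tree)
```

In this example the noisy split in the right subtree is collapsed into a single leaf, while the root split is kept because removing it would lower validation accuracy.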

Cost complexity pruning

Cost complexity pruning, also known as weakest link pruning, is a more sophisticated pruning method. It creates a series of trees T0 to Tn where T0 is the initial tree, and Tn is the root alone. The tree at step i is created by removing a subtree from tree i-1 and replacing it with a leaf node.
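Scikit-Learn exposes this kind of pruning through the `ccp_alpha` parameter and the `cost_complexity_pruning_path` method of its tree estimators. The sketch below computes the sequence of effective alphas for the Iris data and fits one tree per alpha, going from the full tree toward a single node. The train/test split and the clipping of alphas at zero (a guard against possible floating-point round-off) are assumptions added for this example.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Effective alphas at which the next weakest subtree is pruned away.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(
    X_train, y_train
)
# ccp_alpha must be non-negative, so clip any tiny negative round-off values.
alphas = [max(float(a), 0.0) for a in path.ccp_alphas]

# One tree per alpha: larger alpha means more aggressive pruning.
trees = [
    DecisionTreeClassifier(random_state=0, ccp_alpha=a).fit(X_train, y_train)
    for a in alphas
]

for a, t in zip(alphas, trees):
    print(
        f"alpha={a:.4f}  leaves={t.get_n_leaves()}  "
        f"test accuracy={t.score(X_test, y_test):.3f}"
    )
```

A typical way to choose among the trees is to pick the alpha with the best score on validation data or under cross-validation.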

For more information, check out:

Decision trees for both classification and regression are super easy to use in Scikit-Learn.

To load the Iris dataset, create a decision tree object, and train it on the Iris data, the following code can be used:

from sklearn.datasets import load_iris
from sklearn import tree

iris = load_iris()
X = iris.data
y = iris.target

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)
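After fitting, the classifier can be used for prediction. The following sketch repeats the setup so that it runs on its own and then calls `predict` and `predict_proba`. The two sample rows are simply taken from the training data for illustration, and `random_state=0` is added only to make the run repeatable.

```python
from sklearn.datasets import load_iris
from sklearn import tree

iris = load_iris()
X, y = iris.data, iris.target

clf = tree.DecisionTreeClassifier(random_state=0).fit(X, y)

samples = X[[0, 100]]                    # one row from the start, one from the end
predicted = clf.predict(samples)         # class indices
proba = clf.predict_proba(samples)       # per-class fractions in the reached leaf

for label, p in zip(predicted, proba):
    print(iris.target_names[label], p)
```

Note that predictions on the training rows say little about generalization; for that, evaluate on data the tree has not seen.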

Once trained, you can plot the tree with the plot_tree function:

tree.plot_tree(clf)

[Figure: Plot Tree Output]

The tree can also be exported in Graphviz format using the export_graphviz function. The Graphviz Python wrapper can be installed using conda or pip.

pip install graphviz
or
conda install python-graphviz

Below is an example graphviz export of the above tree.

import graphviz 
dot_data = tree.export_graphviz(clf, out_file=None, 
                     feature_names=iris.feature_names,  
                     class_names=iris.target_names,  
                     filled=True, rounded=True,  
                     special_characters=True)
graph = graphviz.Source(dot_data) 
graph.render("iris") 

[Figure: Iris Decision Tree]

[Figure: Iris Decision Surface]

Alternatively, the tree can be exported in textual format with the export_text function.

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text
iris = load_iris()
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(iris.data, iris.target)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)

Output:

|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) >  0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- class: 1
|   |--- petal width (cm) >  1.75
|   |   |--- class: 2

For regression, use a DecisionTreeRegressor instead of the DecisionTreeClassifier.

from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
regressor = DecisionTreeRegressor(random_state=0)
regressor = regressor.fit(X, y)
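To get a feel for how well such a regressor generalizes, it can be evaluated on held-out data. The sketch below is one way to do that; the train/test split, `max_depth=3`, and the choice of mean squared error and R² as metrics are assumptions made for this example.

```python
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# A shallow tree; each leaf predicts the mean target of its training points.
regressor = DecisionTreeRegressor(max_depth=3, random_state=0)
regressor.fit(X_train, y_train)

y_pred = regressor.predict(X_test)
print("MSE:", mean_squared_error(y_test, y_pred))
print("R^2:", regressor.score(X_test, y_test))
```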

Code

Credit / Other resources