318 changes: 318 additions & 0 deletions src/ensemble/extra_trees_regressor.rs
@@ -0,0 +1,318 @@
//! # Extra Trees Regressor
//! An Extra-Trees (Extremely Randomized Trees) regressor is an ensemble learning method that fits multiple randomized
//! decision trees on the dataset and averages their predictions to improve accuracy and control over-fitting.
//!
//! It is similar to a standard Random Forest, but introduces more randomness in the way splits are chosen, which can
//! reduce the variance of the model and often make the training process faster.
//!
//! The two key differences from a standard Random Forest are:
//! 1. It uses the whole original dataset to build each tree instead of bootstrap samples.
//! 2. When splitting a node, it draws a random split point for each feature, rather than searching for the optimal one (sketched below).
//!
//! See [ensemble models](../index.html) for more details.
//!
//! In general, a larger number of estimators improves the performance of the algorithm at the cost of increased training time.
//! The random sample of _m_ predictors is typically set to \\(\sqrt{p}\\), where _p_ is the size of the full predictor set.
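//!
//! As a rough sketch of point 2 above, a random split on a feature amounts to drawing a
//! threshold uniformly between the feature's minimum and maximum at the current node,
//! instead of scanning every candidate threshold for the best one. The helper below is a
//! hypothetical illustration, not part of this crate's API:
//!
//! ```
//! // Hypothetical sketch of Extra-Trees split selection; `values` is one feature
//! // column at the current node and `rand_unit` is drawn uniformly from [0, 1).
//! fn random_threshold(values: &[f64], rand_unit: f64) -> f64 {
//!     let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
//!     let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
//!     min + rand_unit * (max - min)
//! }
//!
//! let t = random_threshold(&[0.5, 2.0, 1.5], 0.3);
//! assert!((0.5..=2.0).contains(&t));
//! ```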
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::ensemble::extra_trees_regressor::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
//! let x = DenseMatrix::from_2d_array(&[
//!     &[234.289, 235.6, 159., 107.608, 1947., 60.323],
//!     &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
//!     &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
//!     &[284.599, 335.1, 165., 110.929, 1950., 61.187],
//!     &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
//!     &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
//!     &[365.385, 187., 354.7, 115.094, 1953., 64.989],
//!     &[363.112, 357.8, 335., 116.219, 1954., 63.761],
//!     &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
//!     &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
//!     &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
//!     &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
//!     &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
//!     &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//!     &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//!     &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]).unwrap();
//! let y = vec![
//!     83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
//!     104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
//! ];
//!
//! let regressor = ExtraTreesRegressor::fit(&x, &y, Default::default()).unwrap();
//!
//! let y_hat = regressor.predict(&x).unwrap(); // use the same data for prediction
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>

use std::default::Default;
use std::fmt::Debug;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::api::{Predictor, SupervisedEstimator};
use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::tree::base_tree_regressor::Splitter;

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Extra Trees Regressor.
/// Some of these parameters are passed directly to the base estimator.
pub struct ExtraTreesRegressorParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub max_depth: Option<u16>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_leaf: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_split: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The number of trees in the forest.
    pub n_trees: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The number of predictors to sample at random as split candidates.
    pub m: Option<usize>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Whether to keep the samples used for tree generation. This is required for OOB prediction.
    pub keep_samples: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Seed used for feature selection and random split selection for each tree.
    pub seed: u64,
}

/// Extra Trees Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct ExtraTreesRegressor<
    TX: Number + FloatNumber + PartialOrd,
    TY: Number,
    X: Array2<TX>,
    Y: Array1<TY>,
> {
    forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
}

impl ExtraTreesRegressorParameters {
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_max_depth(mut self, max_depth: u16) -> Self {
        self.max_depth = Some(max_depth);
        self
    }
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }
    /// The number of trees in the forest.
    pub fn with_n_trees(mut self, n_trees: usize) -> Self {
        self.n_trees = n_trees;
        self
    }
    /// The number of predictors to sample at random as split candidates.
    pub fn with_m(mut self, m: usize) -> Self {
        self.m = Some(m);
        self
    }

    /// Whether to keep the samples used for tree generation. This is required for OOB prediction.
    pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
        self.keep_samples = keep_samples;
        self
    }

    /// Seed used for feature selection and random split selection for each tree.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }
}
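
// A typical configuration chains the builder methods above; an illustrative sketch
// (the parameter values are arbitrary examples, not recommendations):
//
//     let params = ExtraTreesRegressorParameters::default()
//         .with_n_trees(100)
//         .with_max_depth(8)
//         .with_seed(42);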
impl Default for ExtraTreesRegressorParameters {
    fn default() -> Self {
        ExtraTreesRegressorParameters {
            max_depth: None,
            min_samples_leaf: 1,
            min_samples_split: 2,
            n_trees: 10,
            m: None,
            keep_samples: false,
            seed: 0,
        }
    }
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    SupervisedEstimator<X, Y, ExtraTreesRegressorParameters> for ExtraTreesRegressor<TX, TY, X, Y>
{
    fn new() -> Self {
        Self {
            forest_regressor: None,
        }
    }

    fn fit(x: &X, y: &Y, parameters: ExtraTreesRegressorParameters) -> Result<Self, Failed> {
        ExtraTreesRegressor::fit(x, y, parameters)
    }
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    Predictor<X, Y> for ExtraTreesRegressor<TX, TY, X, Y>
{
    fn predict(&self, x: &X) -> Result<Y, Failed> {
        self.predict(x)
    }
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    ExtraTreesRegressor<TX, TY, X, Y>
{
    /// Build a forest of trees from the training set.
    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
    /// * `y` - the target values
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: ExtraTreesRegressorParameters,
    ) -> Result<ExtraTreesRegressor<TX, TY, X, Y>, Failed> {
        let regressor_params = BaseForestRegressorParameters {
            max_depth: parameters.max_depth,
            min_samples_leaf: parameters.min_samples_leaf,
            min_samples_split: parameters.min_samples_split,
            n_trees: parameters.n_trees,
            m: parameters.m,
            keep_samples: parameters.keep_samples,
            seed: parameters.seed,
            // Extra Trees trains each tree on the full dataset (no bootstrap)
            // and picks split thresholds at random.
            bootstrap: false,
            splitter: Splitter::Random,
        };
        let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;

        Ok(ExtraTreesRegressor {
            forest_regressor: Some(forest_regressor),
        })
    }

    /// Predict target values for `x`.
    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let forest_regressor = self.forest_regressor.as_ref().unwrap();
        forest_regressor.predict(x)
    }

    /// Predict out-of-bag (OOB) target values for `x`. `x` is expected to be equal to the dataset used in training.
    pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
        let forest_regressor = self.forest_regressor.as_ref().unwrap();
        forest_regressor.predict_oob(x)
    }
}
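
// Illustrative sketch (assuming `x` and `y` as in the module-level example):
// OOB prediction needs the per-tree samples, so enable `keep_samples` when fitting.
//
//     let params = ExtraTreesRegressorParameters::default().with_keep_samples(true);
//     let model = ExtraTreesRegressor::fit(&x, &y, params)?;
//     let y_oob = model.predict_oob(&x)?; // `x` must be the training matrix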

#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;
    use crate::metrics::mean_squared_error;

    #[test]
    fn test_extra_trees_regressor_fit_predict() {
        // Use a simpler, more predictable dataset for unit testing.
        let x = DenseMatrix::from_2d_array(&[
            &[1., 2.],
            &[3., 4.],
            &[5., 6.],
            &[7., 8.],
            &[9., 10.],
            &[11., 12.],
            &[13., 14.],
            &[15., 16.],
        ])
        .unwrap();
        let y = vec![1., 2., 3., 4., 5., 6., 7., 8.];

        let parameters = ExtraTreesRegressorParameters::default()
            .with_n_trees(100)
            .with_seed(42);

        let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
        let y_hat = regressor.predict(&x).unwrap();

        assert_eq!(y_hat.len(), y.len());
        // A basic check to ensure the model is learning something.
        // The error should be significantly less than the variance of y.
        let mse = mean_squared_error(&y, &y_hat);
        // With this simple dataset, the error should be very low.
        assert!(mse < 1.0);
    }

    #[test]
    fn test_fit_predict_higher_dims() {
        // Dataset with 10 features, but y is only dependent on the 3rd feature (index 2).
        let x = DenseMatrix::from_2d_array(&[
            // The 3rd column is the important one. The rest are noise.
            &[0., 0., 10., 5., 8., 1., 4., 9., 2., 7.],
            &[0., 0., 20., 1., 2., 3., 4., 5., 6., 7.],
            &[0., 0., 30., 7., 6., 5., 4., 3., 2., 1.],
            &[0., 0., 40., 9., 2., 4., 6., 8., 1., 3.],
            &[0., 0., 55., 3., 1., 8., 6., 4., 2., 9.],
            &[0., 0., 65., 2., 4., 7., 5., 3., 1., 8.],
        ])
        .unwrap();
        let y = vec![10., 20., 30., 40., 55., 65.];

        let parameters = ExtraTreesRegressorParameters::default()
            .with_n_trees(100)
            .with_seed(42);

        let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
        let y_hat = regressor.predict(&x).unwrap();

        assert_eq!(y_hat.len(), y.len());

        let mse = mean_squared_error(&y, &y_hat);

        // The model should be able to learn this simple relationship perfectly,
        // ignoring the noise features. The MSE should be very low.
        assert!(mse < 1.0);
    }

    #[test]
    fn test_reproducibility() {
        let x = DenseMatrix::from_2d_array(&[
            &[1., 2.],
            &[3., 4.],
            &[5., 6.],
            &[7., 8.],
            &[9., 10.],
            &[11., 12.],
        ])
        .unwrap();
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let params = ExtraTreesRegressorParameters::default().with_seed(42);

        let regressor1 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
        let y_hat1 = regressor1.predict(&x).unwrap();

        let regressor2 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
        let y_hat2 = regressor2.predict(&x).unwrap();

        assert_eq!(y_hat1, y_hat2);
    }
}
1 change: 1 addition & 0 deletions src/ensemble/mod.rs
@@ -17,6 +17,7 @@
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/)

mod base_forest_regressor;
/// Extra trees regressor
pub mod extra_trees_regressor;
/// Random forest classifier
pub mod random_forest_classifier;
/// Random forest regressor