From b7a434fb8e2fdb2221d7586571ad38e4f1fb8828 Mon Sep 17 00:00:00 2001
From: Daniel Lacina
Date: Sat, 12 Jul 2025 09:44:07 -0500
Subject: [PATCH] refactored random forest regressor into reusable components

---
 src/ensemble/base_forest_regressor.rs   | 215 ++++++++++++++++++++++++
 src/ensemble/mod.rs                     |   1 +
 src/ensemble/random_forest_regressor.rs | 152 +++--------------
 src/tree/decision_tree_regressor.rs     |  27 ---
 src/tree/mod.rs                         |   2 +-
 5 files changed, 240 insertions(+), 157 deletions(-)
 create mode 100644 src/ensemble/base_forest_regressor.rs

diff --git a/src/ensemble/base_forest_regressor.rs b/src/ensemble/base_forest_regressor.rs
new file mode 100644
index 00000000..dc504446
--- /dev/null
+++ b/src/ensemble/base_forest_regressor.rs
@@ -0,0 +1,215 @@
+use rand::Rng;
+use std::fmt::Debug;
+
+#[cfg(feature = "serde")]
+use serde::{Deserialize, Serialize};
+
+use crate::error::{Failed, FailedError};
+use crate::linalg::basic::arrays::{Array1, Array2};
+use crate::numbers::basenum::Number;
+use crate::numbers::floatnum::FloatNumber;
+
+use crate::rand_custom::get_rng_impl;
+use crate::tree::base_tree_regressor::{BaseTreeRegressor, BaseTreeRegressorParameters, Splitter};
+
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone)]
+/// Parameters of the Base Forest Regressor.
+/// Some parameters here are passed directly into the base tree estimator.
+pub struct BaseForestRegressorParameters {
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+    pub max_depth: Option<u16>,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+    pub min_samples_leaf: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+    pub min_samples_split: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// The number of trees in the forest.
+    pub n_trees: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// The number of randomly sampled predictors to use as split candidates.
+    pub m: Option<usize>,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// Whether to keep samples used for tree generation. This is required for OOB prediction.
+    pub keep_samples: bool,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// Seed used for bootstrap sampling and feature selection for each tree.
+    pub seed: u64,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// Whether each tree is trained on a bootstrap sample. If false, every tree sees the full training set.
+    pub bootstrap: bool,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// The strategy used to choose the split at each node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+    pub splitter: Splitter,
+}
+
+impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
+    for BaseForestRegressor<TX, TY, X, Y>
+{
+    fn eq(&self, other: &Self) -> bool {
+        if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() {
+            false
+        } else {
+            self.trees
+                .iter()
+                .zip(other.trees.iter())
+                .all(|(a, b)| a == b)
+        }
+    }
+}
+
+/// Forest Regressor
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug)]
+pub struct BaseForestRegressor<
+    TX: Number + FloatNumber + PartialOrd,
+    TY: Number,
+    X: Array2<TX>,
+    Y: Array1<TY>,
+> {
+    trees: Option<Vec<BaseTreeRegressor<TX, TY, X, Y>>>,
+    samples: Option<Vec<Vec<bool>>>,
+}
+
+impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
+    BaseForestRegressor<TX, TY, X, Y>
+{
+    /// Build a forest of trees from the training set.
+    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
+    /// * `y` - the target values
+    pub fn fit(
+        x: &X,
+        y: &Y,
+        parameters: BaseForestRegressorParameters,
+    ) -> Result<BaseForestRegressor<TX, TY, X, Y>, Failed> {
+        let (n_rows, num_attributes) = x.shape();
+
+        if n_rows != y.shape() {
+            return Err(Failed::fit("Number of rows in X should = len(y)"));
+        }
+
+        let mtry = parameters
+            .m
+            .unwrap_or((num_attributes as f64).sqrt().floor() as usize);
+
+        let mut rng = get_rng_impl(Some(parameters.seed));
+        let mut trees: Vec<BaseTreeRegressor<TX, TY, X, Y>> = Vec::with_capacity(parameters.n_trees);
+
+        let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
+        if parameters.keep_samples {
+            maybe_all_samples = Some(Vec::with_capacity(parameters.n_trees));
+        }
+
+        let mut samples: Vec<usize> = vec![1; n_rows];
+
+        for _ in 0..parameters.n_trees {
+            if parameters.bootstrap {
+                samples =
+                    BaseForestRegressor::<TX, TY, X, Y>::sample_with_replacement(n_rows, &mut rng);
+            }
+
+            // keep samples if the flag is on
+            if let Some(ref mut all_samples) = maybe_all_samples {
+                all_samples.push(samples.iter().map(|x| *x != 0).collect())
+            }
+
+            let params = BaseTreeRegressorParameters {
+                max_depth: parameters.max_depth,
+                min_samples_leaf: parameters.min_samples_leaf,
+                min_samples_split: parameters.min_samples_split,
+                seed: Some(parameters.seed),
+                splitter: parameters.splitter.clone(),
+            };
+            let tree = BaseTreeRegressor::fit_weak_learner(x, y, samples.clone(), mtry, params)?;
+            trees.push(tree);
+        }
+
+        Ok(BaseForestRegressor {
+            trees: Some(trees),
+            samples: maybe_all_samples,
+        })
+    }
+
+    /// Predict regression value for `x`.
+    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
+    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
+        let mut result = Y::zeros(x.shape().0);
+
+        let (n, _) = x.shape();
+
+        for i in 0..n {
+            result.set(i, self.predict_for_row(x, i));
+        }
+
+        Ok(result)
+    }
+
+    fn predict_for_row(&self, x: &X, row: usize) -> TY {
+        let n_trees = self.trees.as_ref().unwrap().len();
+
+        let mut result = TY::zero();
+
+        for tree in self.trees.as_ref().unwrap().iter() {
+            result += tree.predict_for_row(x, row);
+        }
+
+        result / TY::from_usize(n_trees).unwrap()
+    }
+
+    /// Predict OOB regression values for `x`. `x` is expected to be equal to the dataset used in training.
+    pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
+        let (n, _) = x.shape();
+        if self.samples.is_none() {
+            Err(Failed::because(
+                FailedError::PredictFailed,
+                "Need keep_samples=true for OOB predictions.",
+            ))
+        } else if self.samples.as_ref().unwrap()[0].len() != n {
+            Err(Failed::because(
+                FailedError::PredictFailed,
+                "Prediction matrix must match matrix used in training for OOB predictions.",
+            ))
+        } else {
+            let mut result = Y::zeros(n);
+
+            for i in 0..n {
+                result.set(i, self.predict_for_row_oob(x, i));
+            }
+
+            Ok(result)
+        }
+    }
+
+    fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
+        let mut n_trees = 0;
+        let mut result = TY::zero();
+
+        for (tree, samples) in self
+            .trees
+            .as_ref()
+            .unwrap()
+            .iter()
+            .zip(self.samples.as_ref().unwrap())
+        {
+            if !samples[row] {
+                result += tree.predict_for_row(x, row);
+                n_trees += 1;
+            }
+        }
+
+        // TODO: What to do if there are no oob trees?
+        result / TY::from(n_trees).unwrap()
+    }
+
+    fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
+        let mut samples = vec![0; nrows];
+        for _ in 0..nrows {
+            let xi = rng.gen_range(0..nrows);
+            samples[xi] += 1;
+        }
+        samples
+    }
+}
diff --git a/src/ensemble/mod.rs b/src/ensemble/mod.rs
index 8cebd5c5..dc030962 100644
--- a/src/ensemble/mod.rs
+++ b/src/ensemble/mod.rs
@@ -16,6 +16,7 @@
 //!
* ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/) +mod base_forest_regressor; /// Random forest classifier pub mod random_forest_classifier; /// Random forest regressor diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs index efc63d3d..0a8a888c 100644 --- a/src/ensemble/random_forest_regressor.rs +++ b/src/ensemble/random_forest_regressor.rs @@ -43,7 +43,6 @@ //! //! -use rand::Rng; use std::default::Default; use std::fmt::Debug; @@ -51,15 +50,12 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; -use crate::error::{Failed, FailedError}; +use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters}; +use crate::error::Failed; use crate::linalg::basic::arrays::{Array1, Array2}; use crate::numbers::basenum::Number; use crate::numbers::floatnum::FloatNumber; - -use crate::rand_custom::get_rng_impl; -use crate::tree::decision_tree_regressor::{ - DecisionTreeRegressor, DecisionTreeRegressorParameters, -}; +use crate::tree::base_tree_regressor::Splitter; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] @@ -98,8 +94,7 @@ pub struct RandomForestRegressor< X: Array2, Y: Array1, > { - trees: Option>>, - samples: Option>>, + forest_regressor: Option>, } impl RandomForestRegressorParameters { @@ -159,14 +154,7 @@ impl, Y: Array1 for RandomForestRegressor { fn eq(&self, other: &Self) -> bool { - if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() { - false - } else { - self.trees - .iter() - .zip(other.trees.iter()) - .all(|(a, b)| a == b) - } + self.forest_regressor == other.forest_regressor } } @@ -176,8 +164,7 @@ impl, Y: Array1 { fn new() -> Self { Self { - trees: Option::None, - samples: Option::None, + forest_regressor: Option::None, } } @@ -397,128 +384,35 @@ impl, Y: Array1 y: &Y, parameters: RandomForestRegressorParameters, ) -> Result, Failed> { - let (n_rows, num_attributes) = x.shape(); - - if n_rows != y.shape() { - return Err(Failed::fit("Number of rows in X should = len(y)")); - } - - let mtry = parameters - .m - .unwrap_or((num_attributes as f64).sqrt().floor() as usize); - - let mut rng = get_rng_impl(Some(parameters.seed)); - let mut trees: Vec> = Vec::new(); - - let mut maybe_all_samples: Option>> = Option::None; - if parameters.keep_samples { - // TODO: use with_capacity here - maybe_all_samples = Some(Vec::new()); - } - - for _ in 0..parameters.n_trees { - let samples: Vec = - RandomForestRegressor::::sample_with_replacement(n_rows, &mut rng); - - // keep samples is flag is on - if let Some(ref mut all_samples) = maybe_all_samples { - all_samples.push(samples.iter().map(|x| *x != 0).collect()) - } - - let params = DecisionTreeRegressorParameters { - max_depth: parameters.max_depth, - min_samples_leaf: parameters.min_samples_leaf, - min_samples_split: parameters.min_samples_split, - seed: Some(parameters.seed), - }; - let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?; - trees.push(tree); - } + let regressor_params = BaseForestRegressorParameters { + max_depth: parameters.max_depth, + min_samples_leaf: parameters.min_samples_leaf, + min_samples_split: parameters.min_samples_split, + n_trees: parameters.n_trees, + m: parameters.m, + keep_samples: parameters.keep_samples, + seed: parameters.seed, + bootstrap: true, + 
+            splitter: Splitter::Best,
+        };
+
+        let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;
 
         Ok(RandomForestRegressor {
-            trees: Some(trees),
-            samples: maybe_all_samples,
+            forest_regressor: Some(forest_regressor),
         })
     }
 
     /// Predict class for `x`
     /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
     pub fn predict(&self, x: &X) -> Result<Y, Failed> {
-        let mut result = Y::zeros(x.shape().0);
-
-        let (n, _) = x.shape();
-
-        for i in 0..n {
-            result.set(i, self.predict_for_row(x, i));
-        }
-
-        Ok(result)
-    }
-
-    fn predict_for_row(&self, x: &X, row: usize) -> TY {
-        let n_trees = self.trees.as_ref().unwrap().len();
-
-        let mut result = TY::zero();
-
-        for tree in self.trees.as_ref().unwrap().iter() {
-            result += tree.predict_for_row(x, row);
-        }
-
-        result / TY::from_usize(n_trees).unwrap()
+        let forest_regressor = self.forest_regressor.as_ref().unwrap();
+        forest_regressor.predict(x)
     }
 
     /// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
     pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
-        let (n, _) = x.shape();
-        if self.samples.is_none() {
-            Err(Failed::because(
-                FailedError::PredictFailed,
-                "Need samples=true for OOB predictions.",
-            ))
-        } else if self.samples.as_ref().unwrap()[0].len() != n {
-            Err(Failed::because(
-                FailedError::PredictFailed,
-                "Prediction matrix must match matrix used in training for OOB predictions.",
-            ))
-        } else {
-            let mut result = Y::zeros(n);
-
-            for i in 0..n {
-                result.set(i, self.predict_for_row_oob(x, i));
-            }
-
-            Ok(result)
-        }
-    }
-
-    fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
-        let mut n_trees = 0;
-        let mut result = TY::zero();
-
-        for (tree, samples) in self
-            .trees
-            .as_ref()
-            .unwrap()
-            .iter()
-            .zip(self.samples.as_ref().unwrap())
-        {
-            if !samples[row] {
-                result += tree.predict_for_row(x, row);
-                n_trees += 1;
-            }
-        }
-
-        // TODO: What to do if there are no oob trees?
-        result / TY::from(n_trees).unwrap()
-    }
-
-    fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
-        let mut samples = vec![0; nrows];
-        for _ in 0..nrows {
-            let xi = rng.gen_range(0..nrows);
-            samples[xi] += 1;
-        }
-        samples
+        let forest_regressor = self.forest_regressor.as_ref().unwrap();
+        forest_regressor.predict_oob(x)
     }
 }
diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs
index 154ba2ef..86b99343 100644
--- a/src/tree/decision_tree_regressor.rs
+++ b/src/tree/decision_tree_regressor.rs
@@ -312,38 +312,11 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
         })
     }
 
-    pub(crate) fn fit_weak_learner(
-        x: &X,
-        y: &Y,
-        samples: Vec<usize>,
-        mtry: usize,
-        parameters: DecisionTreeRegressorParameters,
-    ) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
-        let tree_parameters = BaseTreeRegressorParameters {
-            max_depth: parameters.max_depth,
-            min_samples_leaf: parameters.min_samples_leaf,
-            min_samples_split: parameters.min_samples_split,
-            seed: parameters.seed,
-            splitter: Splitter::Best,
-        };
-        let tree = BaseTreeRegressor::fit_weak_learner(x, y, samples, mtry, tree_parameters)?;
-        Ok(Self {
-            tree_regressor: Some(tree),
-        })
-    }
-
     /// Predict regression value for `x`.
     /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
     pub fn predict(&self, x: &X) -> Result<Y, Failed> {
         self.tree_regressor.as_ref().unwrap().predict(x)
     }
-
-    pub(crate) fn predict_for_row(&self, x: &X, row: usize) -> TY {
-        self.tree_regressor
-            .as_ref()
-            .unwrap()
-            .predict_for_row(x, row)
-    }
 }
 
 #[cfg(test)]
diff --git a/src/tree/mod.rs b/src/tree/mod.rs
index b325e968..82937a5b 100644
--- a/src/tree/mod.rs
+++ b/src/tree/mod.rs
@@ -19,7 +19,7 @@
 //!
 //!
 
-mod base_tree_regressor;
+pub(crate) mod base_tree_regressor;
 /// Classification tree for dependent variables that take a finite number of unordered values.
 pub mod decision_tree_classifier;
 /// Regression tree for dependent variables that take continuous or ordered discrete values.
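Reviewer note, not part of the patch: the public API should be unchanged by this refactor; RandomForestRegressor::fit now simply forwards to BaseForestRegressor::fit with bootstrap: true and Splitter::Best. Below is a minimal end-to-end sketch for verifying that, under two assumptions: that DenseMatrix::from_2d_array returns a Result (as in recent smartcore versions), and that the usual with_* builders exist on RandomForestRegressorParameters (they are elided from the diff context above).

use smartcore::ensemble::random_forest_regressor::{
    RandomForestRegressor, RandomForestRegressorParameters,
};
use smartcore::linalg::basic::matrix::DenseMatrix;

fn main() {
    // Six rows of the Longley data, the same shape of input the crate's doc examples use.
    let x = DenseMatrix::from_2d_array(&[
        &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
        &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
        &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
        &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
        &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
        &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
    ])
    .unwrap();
    let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1];

    // keep_samples = true exercises the OOB path that this patch moves into
    // BaseForestRegressor; a fixed seed makes the run reproducible.
    let params = RandomForestRegressorParameters::default()
        .with_n_trees(10)
        .with_keep_samples(true)
        .with_seed(42);

    let forest = RandomForestRegressor::fit(&x, &y, params).unwrap();
    let y_hat = forest.predict(&x).unwrap(); // averages over all trees
    let y_oob = forest.predict_oob(&x).unwrap(); // averages over out-of-bag trees only
    println!("in-sample: {y_hat:?}\noob: {y_oob:?}");
}

Since the per-tree RNG consumption in fit is preserved (one sample_with_replacement call per tree, same seed handling), running this before and after the patch should print identical values for a fixed seed.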