diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index c94e246bb..cda174240 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -49,7 +49,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9", "3.10"]
+        python-version: ["3.9", "3.10", "3.11"]
         os: ["ubuntu-latest"]
 
     steps:
@@ -103,7 +103,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9", "3.10"]
+        python-version: ["3.9", "3.10", "3.11"]
         os: ["ubuntu-latest"]
 
     steps:
@@ -143,7 +143,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9", "3.10"]
+        python-version: ["3.9", "3.10", "3.11"]
         os: ["ubuntu-latest"]
 
     steps:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2a357ade9..9a73e9293 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+# 2.4.0
+## Improvements
+- Replace random forest from pyrfr with random forest from sklearn (#1246)
+
 # 2.3.1
 ## Bugfixes
 
diff --git a/README.md b/README.md
index e2f0e2b29..8fda28491 100644
--- a/README.md
+++ b/README.md
@@ -47,11 +47,6 @@ conda create -n SMAC python=3.10
 conda activate SMAC
 ```
 
-Install swig:
-```
-conda install gxx_linux-64 gcc_linux-64 swig
-```
-
 Install SMAC via PyPI:
 ```
 pip install smac
@@ -63,6 +58,20 @@ git clone https://github.com/automl/SMAC3.git && cd SMAC3
 make install-dev
 ```
 
+## Running SMAC with pyrfr
+Starting from 2.4.0, SMAC uses the random forest from [sklearn](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html)
+instead of the random forest from [pyrfr](https://pypi.org/project/pyrfr/) as the default surrogate model for HPO tasks.
+However, you can still use the old pyrfr surrogate model via `smac.facade.old.HyperparameterOptimizationRFRFacade`
+and `smac.facade.old.MultiFidelityRFRFacade`.
+
+To work with pyrfr, you first need to install gcc, gxx, and swig:
+```
+conda install gxx_linux-64 gcc_linux-64 swig
+```
+Then install SMAC with the pyrfr option:
+```
+pip install smac[pyrfr]
+```
 
 ## Minimal Example
 
diff --git a/setup.py b/setup.py
index 6ac234efe..1eca8d910 100644
--- a/setup.py
+++ b/setup.py
@@ -21,6 +21,9 @@ def read_file(filepath: str) -> str:
 
 extras_require = {
+    "pyrfr": [
+        "pyrfr>=0.9.0",
+    ],
     "dev": [
         "setuptools",
         "types-setuptools",
@@ -79,8 +82,7 @@ def read_file(filepath: str) -> str:
         "pynisher>=1.0.0",
         "ConfigSpace>=1.0.0",
         "joblib",
-        "scikit-learn>=1.1.2",
-        "pyrfr>=0.9.0",
+        "scikit-learn>=1.6.1",
         "dask[distributed]",
         "dask_jobqueue>=0.8.2",
         "emcee>=3.0.0",
diff --git a/smac/facade/old/__init__.py b/smac/facade/old/__init__.py
new file mode 100644
index 000000000..7f73b5e78
--- /dev/null
+++ b/smac/facade/old/__init__.py
@@ -0,0 +1,9 @@
+from smac.facade.old.hyperparameter_optimization_facade_pyrfr import (
+    HyperparameterOptimizationRFRFacade,
+)
+from smac.facade.old.multi_fidelity_facade_pyrfr import MultiFidelityRFRFacade
+
+__all__ = [
+    "HyperparameterOptimizationRFRFacade",
+    "MultiFidelityRFRFacade",
+]
diff --git a/smac/facade/old/hyperparameter_optimization_facade_pyrfr.py b/smac/facade/old/hyperparameter_optimization_facade_pyrfr.py
new file mode 100644
index 000000000..f833822de
--- /dev/null
+++ b/smac/facade/old/hyperparameter_optimization_facade_pyrfr.py
@@ -0,0 +1,50 @@
+from __future__ import annotations
+
+from smac.facade.hyperparameter_optimization_facade import (
+    HyperparameterOptimizationFacade,
+)
+from smac.model.random_forest.pyrfr.random_forest_pyrfr import PyrfrRandomForest
+from smac.scenario import Scenario
+
+
+class
HyperparameterOptimizationRFRFacade(HyperparameterOptimizationFacade): + @staticmethod + def get_model( # type: ignore + scenario: Scenario, + *, + n_trees: int = 10, + ratio_features: float = 1.0, + min_samples_split: int = 2, + min_samples_leaf: int = 1, + max_depth: int = 2**20, + bootstrapping: bool = True, + ) -> PyrfrRandomForest: + """Returns a random forest as surrogate model. + + Parameters + ---------- + n_trees : int, defaults to 10 + The number of trees in the random forest. + ratio_features : float, defaults to 5.0 / 6.0 + The ratio of features that are considered for splitting. + min_samples_split : int, defaults to 3 + The minimum number of data points to perform a split. + min_samples_leaf : int, defaults to 3 + The minimum number of data points in a leaf. + max_depth : int, defaults to 20 + The maximum depth of a single tree. + bootstrapping : bool, defaults to True + Enables bootstrapping. + """ + return PyrfrRandomForest( + log_y=True, + n_trees=n_trees, + bootstrapping=bootstrapping, + ratio_features=ratio_features, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, + max_depth=max_depth, + configspace=scenario.configspace, + instance_features=scenario.instance_features, + seed=scenario.seed, + ) diff --git a/smac/facade/old/multi_fidelity_facade_pyrfr.py b/smac/facade/old/multi_fidelity_facade_pyrfr.py new file mode 100644 index 000000000..eef3a4629 --- /dev/null +++ b/smac/facade/old/multi_fidelity_facade_pyrfr.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from smac.facade.multi_fidelity_facade import MultiFidelityFacade +from smac.facade.old.hyperparameter_optimization_facade_pyrfr import ( + HyperparameterOptimizationRFRFacade, +) + +__copyright__ = "Copyright 2022, automl.org" +__license__ = "3-clause BSD" + + +class MultiFidelityRFRFacade(MultiFidelityFacade, HyperparameterOptimizationRFRFacade): + pass diff --git a/smac/model/random_forest/pyrfr/__init__.py b/smac/model/random_forest/pyrfr/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/smac/model/random_forest/pyrfr/random_forest_pyrfr.py b/smac/model/random_forest/pyrfr/random_forest_pyrfr.py new file mode 100644 index 000000000..6efa21f5b --- /dev/null +++ b/smac/model/random_forest/pyrfr/random_forest_pyrfr.py @@ -0,0 +1,311 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +from ConfigSpace import ConfigurationSpace + +try: + from pyrfr import regression + from pyrfr.regression import binary_rss_forest as BinaryForest + from pyrfr.regression import default_data_container as DataContainer +except ImportError as e: + import warnings + + warnings.warn( + "You are using SMAC RandomForest with pyrfr." + "However, the pyrfr package is not installed. " + "Please install pyrfr with the following commands:" + "conda install gxx_linux-64 gcc_linux-64 swig" + "pip install pyrfr>=0.9.0" + ) + raise e + + +from smac.constants import N_TREES, VERY_SMALL_NUMBER +from smac.model.random_forest import AbstractRandomForest + +__copyright__ = "Copyright 2022, automl.org" +__license__ = "3-clause BSD" + + +class PyrfrRandomForest(AbstractRandomForest): + """Random forest that takes instance features into account. + + Parameters + ---------- + n_trees : int, defaults to `N_TREES` + The number of trees in the random forest. + n_points_per_tree : int, defaults to -1 + Number of points per tree. If the value is smaller than 0, the number of samples will be used. 
+ ratio_features : float, defaults to 5.0 / 6.0 + The ratio of features that are considered for splitting. + min_samples_split : int, defaults to 3 + The minimum number of data points to perform a split. + min_samples_leaf : int, defaults to 3 + The minimum number of data points in a leaf. + max_depth : int, defaults to 2**20 + The maximum depth of a single tree. + eps_purity : float, defaults to 1e-8 + The minimum difference between two target values to be considered. + max_nodes : int, defaults to 2**20 + The maximum total number of nodes in a tree. + bootstrapping : bool, defaults to True + Enables bootstrapping. + log_y: bool, defaults to False + The y values (passed to this random forest) are expected to be log(y) transformed. + This will be considered during predicting. + instance_features : dict[str, list[int | float]] | None, defaults to None + Features (list of int or floats) of the instances (str). The features are incorporated into the X data, + on which the model is trained on. + pca_components : float, defaults to 7 + Number of components to keep when using PCA to reduce dimensionality of instance features. + seed : int + """ + + def __init__( + self, + configspace: ConfigurationSpace, + n_trees: int = N_TREES, + n_points_per_tree: int = -1, + ratio_features: float = 5.0 / 6.0, + min_samples_split: int = 3, + min_samples_leaf: int = 3, + max_depth: int = 2**20, + eps_purity: float = 1e-8, + max_nodes: int = 2**20, + bootstrapping: bool = True, + log_y: bool = False, + instance_features: dict[str, list[int | float]] | None = None, + pca_components: int | None = 7, + seed: int = 0, + ) -> None: + super().__init__( + configspace=configspace, + instance_features=instance_features, + pca_components=pca_components, + seed=seed, + ) + + max_features = 0 if ratio_features > 1.0 else max(1, int(len(self._types) * ratio_features)) + + self._rf_opts = regression.forest_opts() + self._rf_opts.num_trees = n_trees + self._rf_opts.do_bootstrapping = bootstrapping + self._rf_opts.tree_opts.max_features = max_features + self._rf_opts.tree_opts.min_samples_to_split = min_samples_split + self._rf_opts.tree_opts.min_samples_in_leaf = min_samples_leaf + self._rf_opts.tree_opts.max_depth = max_depth + self._rf_opts.tree_opts.epsilon_purity = eps_purity + self._rf_opts.tree_opts.max_num_nodes = max_nodes + self._rf_opts.compute_law_of_total_variance = False + self._rf: BinaryForest | None = None + self._log_y = log_y + + # Case to `int` incase we get an `np.integer` type + self._rng = regression.default_random_engine(int(seed)) + + self._n_trees = n_trees + self._n_points_per_tree = n_points_per_tree + self._ratio_features = ratio_features + self._min_samples_split = min_samples_split + self._min_samples_leaf = min_samples_leaf + self._max_depth = max_depth + self._eps_purity = eps_purity + self._max_nodes = max_nodes + self._bootstrapping = bootstrapping + + # This list well be read out by save_iteration() in the solver + # self._hypers = [ + # n_trees, + # max_nodes, + # bootstrapping, + # n_points_per_tree, + # ratio_features, + # min_samples_split, + # min_samples_leaf, + # max_depth, + # eps_purity, + # self._seed, + # ] + + @property + def meta(self) -> dict[str, Any]: # noqa: D102 + meta = super().meta + meta.update( + { + "n_trees": self._n_trees, + "n_points_per_tree": self._n_points_per_tree, + "ratio_features": self._ratio_features, + "min_samples_split": self._min_samples_split, + "min_samples_leaf": self._min_samples_leaf, + "max_depth": self._max_depth, + "eps_purity": 
self._eps_purity, + "max_nodes": self._max_nodes, + "bootstrapping": self._bootstrapping, + "pca_components": self._pca_components, + } + ) + + return meta + + def _train(self, X: np.ndarray, y: np.ndarray) -> PyrfrRandomForest: + X = self._impute_inactive(X) + y = y.flatten() + + # self.X = X + # self.y = y.flatten() + + if self._n_points_per_tree <= 0: + self._rf_opts.num_data_points_per_tree = X.shape[0] + else: + self._rf_opts.num_data_points_per_tree = self._n_points_per_tree + + self._rf = regression.binary_rss_forest() + self._rf.options = self._rf_opts + + data = self._init_data_container(X, y) + self._rf.fit(data, rng=self._rng) + + return self + + def _init_data_container(self, X: np.ndarray, y: np.ndarray) -> DataContainer: + """Fills a pyrfr default data container s.t. the forest knows categoricals and bounds for continous data. + + Parameters + ---------- + X : np.ndarray [#samples, #hyperparameter + #features] + Input data points. + Y : np.ndarray [#samples, #objectives] + The corresponding target values. + + Returns + ------- + data : DataContainer + The filled data container that pyrfr can interpret. + """ + # Retrieve the types and the bounds from the ConfigSpace + data = regression.default_data_container(X.shape[1]) + + for i, (mn, mx) in enumerate(self._bounds): + if np.isnan(mx): + data.set_type_of_feature(i, mn) + else: + data.set_bounds_of_feature(i, mn, mx) + + for row_X, row_y in zip(X, y): + data.add_data_point(row_X, row_y) + + return data + + def _predict( + self, + X: np.ndarray, + covariance_type: str | None = "diagonal", + ) -> tuple[np.ndarray, np.ndarray | None]: + if len(X.shape) != 2: + raise ValueError("Expected 2d array, got %dd array!" % len(X.shape)) + + if X.shape[1] != len(self._types): + raise ValueError("Rows in X should have %d entries but have %d!" % (len(self._types), X.shape[1])) + + if covariance_type != "diagonal": + raise ValueError("`covariance_type` can only take `diagonal` for this model.") + + assert self._rf is not None + X = self._impute_inactive(X) + + if self._log_y: + all_preds = [] + third_dimension = 0 + + # Gather data in a list of 2d arrays and get statistics about the required size of the 3d array + for row_X in X: + preds_per_tree = self._rf.all_leaf_values(row_X) + all_preds.append(preds_per_tree) + max_num_leaf_data = max(map(len, preds_per_tree)) + third_dimension = max(max_num_leaf_data, third_dimension) + + # Transform list of 2d arrays into a 3d array + preds_as_array = np.zeros((X.shape[0], self._rf_opts.num_trees, third_dimension)) * np.nan + for i, preds_per_tree in enumerate(all_preds): + for j, pred in enumerate(preds_per_tree): + preds_as_array[i, j, : len(pred)] = pred + + # Do all necessary computation with vectorized functions + preds_as_array = np.log(np.nanmean(np.exp(preds_as_array), axis=2) + VERY_SMALL_NUMBER) + + # Compute the mean and the variance across the different trees + means = preds_as_array.mean(axis=1) + vars_ = preds_as_array.var(axis=1) + else: + means, vars_ = [], [] + for row_X in X: + mean_, var = self._rf.predict_mean_var(row_X) + means.append(mean_) + vars_.append(var) + + means = np.array(means) + vars_ = np.array(vars_) + + return means.reshape((-1, 1)), vars_.reshape((-1, 1)) + + def predict_marginalized(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Predicts mean and variance marginalized over all instances. + + Note + ---- + The method is random forest specific and follows the SMAC2 implementation. 
It requires + no distribution assumption to marginalize the uncertainty estimates. + + Parameters + ---------- + X : np.ndarray [#samples, #hyperparameter + #features] + Input data points. + + Returns + ------- + means : np.ndarray [#samples, 1] + The predictive mean. + vars : np.ndarray [#samples, 1] + The predictive variance. + """ + if self._n_features == 0: + mean_, var = self.predict(X) + assert var is not None + + var[var < self._var_threshold] = self._var_threshold + var[np.isnan(var)] = self._var_threshold + + return mean_, var + + assert self._instance_features is not None + + if len(X.shape) != 2: + raise ValueError("Expected 2d array, got %dd array!" % len(X.shape)) + + if X.shape[1] != len(self._bounds): + raise ValueError("Rows in X should have %d entries but have %d!" % (len(self._bounds), X.shape[1])) + + assert self._rf is not None + X = self._impute_inactive(X) + + X_feat = list(self._instance_features.values()) + dat_ = self._rf.predict_marginalized_over_instances_batch(X, X_feat, self._log_y) + dat_ = np.array(dat_) + + # 3. compute statistics across trees + mean_ = dat_.mean(axis=1) + var = dat_.var(axis=1) + + if var is None: + raise RuntimeError("The variance must not be none.") + + var[var < self._var_threshold] = self._var_threshold + + if len(mean_.shape) == 1: + mean_ = mean_.reshape((-1, 1)) + if len(var.shape) == 1: + var = var.reshape((-1, 1)) + + return mean_, var diff --git a/smac/model/random_forest/random_forest.py b/smac/model/random_forest/random_forest.py index 3500d03ce..e60ca088e 100644 --- a/smac/model/random_forest/random_forest.py +++ b/smac/model/random_forest/random_forest.py @@ -1,20 +1,640 @@ from __future__ import annotations -from typing import Any +from typing import Any, Callable, Iterable, Tuple + +import threading +from itertools import product import numpy as np from ConfigSpace import ConfigurationSpace -from pyrfr import regression -from pyrfr.regression import binary_rss_forest as BinaryForest -from pyrfr.regression import default_data_container as DataContainer - -from smac.constants import N_TREES, VERY_SMALL_NUMBER +from scipy.sparse import issparse +from sklearn.ensemble._base import _partition_estimators +from sklearn.ensemble._forest import ForestRegressor +from sklearn.tree import DecisionTreeRegressor +from sklearn.tree._tree import DTYPE +from sklearn.utils.parallel import Parallel, delayed +from sklearn.utils.validation import check_is_fitted, validate_data + +from smac.constants import N_TREES from smac.model.random_forest import AbstractRandomForest __copyright__ = "Copyright 2025, Leibniz University Hanover, Institute of AI" __license__ = "3-clause BSD" +def estimator_predict(predict: Callable, X: np.ndarray, results: np.ndarray, tree_idx: int) -> None: + """ + Collect predictions from a single estimator. + + Parameters + ---------- + predict: Callable + the prediction function, in this scenario, it is the prediction function of each tree + X: np.ndarray [#samples, #hyperparameter] + input features + results: np.ndarray [#samples, #estimators] + output values from all the predictors + tree_idx: + estimator index + """ + prediction = predict(X, check_input=False) + results[:, tree_idx] = prediction # Populate the corresponding column + + +def accumulate_predict_over_instances( + predict: Callable, + X: np.ndarray, + X_instance_feat: np.ndarray, + results: np.ndarray, + tree_idx: int, + n_instances: int, + lock: threading.Lock, +) -> None: + """ + Collect predictions from a single estimator. 
However, we sum the results from all instances + + Parameters + ---------- + predict: Callable + the prediction function, in this scenario, it is the prediction function of each tree + X: np.ndarray [#samples, #hyperparameter] + Input data points. + X_instance_feat: np.ndarray [#instance, #features], + Features (np.ndarray) of the instances (str). The features are incorporated into the X data, + on which the model is trained on. + + results: np.ndarray [#samples, #estimators] + output values from all the predictors + tree_idx: int + tree index + n_instances: int + number of instance + lock: threading.Lock + threading lock + """ + X_instance_feat_ = np.tile(X_instance_feat[None, :], (len(X), 1)) + prediction = predict(np.concatenate([X, X_instance_feat_], axis=1), check_input=False) + with lock: + results[:, tree_idx,] += ( + prediction / n_instances + ) + + +class EPMRandomForest(ForestRegressor): + def __init__( + self, + n_estimators: int = 100, + *, + log_y: bool = False, + cross_trees_variance: bool = False, + criterion: str = "squared_error", + splitter: str = "random", + max_depth: int | None = None, + min_samples_split: int = 2, + min_samples_leaf: int = 1, + min_weight_fraction_leaf: float = 0.0, + max_features: float = 1.0, + max_leaf_nodes: int | None = None, + min_impurity_decrease: float = 0.0, + bootstrap: bool = False, + oob_score: bool = False, + n_jobs: int | None = None, + random_state: int | None = None, + verbose: int = 0, + warm_start: bool = False, + ccp_alpha: float = 0.0, + max_samples: int | float | None = None, + monotonic_cst: Iterable | None = None, + ) -> None: + """A decision tree regressor. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + n_estimators : int, default=100 + The number of trees in the forest. + + .. versionchanged:: 0.22 + The default value of ``n_estimators`` changed from 10 to 100 + in 0.22. + + criterion : {"squared_error", "absolute_error", "friedman_mse", "poisson"}, \ + default="squared_error" + The function to measure the quality of a split. Supported criteria + are "squared_error" for the mean squared error, which is equal to + variance reduction as feature selection criterion and minimizes the L2 + loss using the mean of each terminal node, "friedman_mse", which uses + mean squared error with Friedman's improvement score for potential + splits, "absolute_error" for the mean absolute error, which minimizes + the L1 loss using the median of each terminal node, and "poisson" which + uses reduction in Poisson deviance to find splits. + Training using "absolute_error" is significantly slower + than when using "squared_error". + + .. versionadded:: 0.18 + Mean Absolute Error (MAE) criterion. + + .. versionadded:: 1.0 + Poisson criterion. + + max_depth : int, default=None + The maximum depth of the tree. If None, then nodes are expanded until + all leaves are pure or until all leaves contain less than + min_samples_split samples. + + min_samples_split : int or float, default=2 + The minimum number of samples required to split an internal node: + + - If int, then consider `min_samples_split` as the minimum number. + - If float, then `min_samples_split` is a fraction and + `ceil(min_samples_split * n_samples)` are the minimum + number of samples for each split. + + .. versionchanged:: 0.18 + Added float values for fractions. + + min_samples_leaf : int or float, default=1 + The minimum number of samples required to be at a leaf node. 
+ A split point at any depth will only be considered if it leaves at + least ``min_samples_leaf`` training samples in each of the left and + right branches. This may have the effect of smoothing the model, + especially in regression. + + - If int, then consider `min_samples_leaf` as the minimum number. + - If float, then `min_samples_leaf` is a fraction and + `ceil(min_samples_leaf * n_samples)` are the minimum + number of samples for each node. + + .. versionchanged:: 0.18 + Added float values for fractions. + + min_weight_fraction_leaf : float, default=0.0 + The minimum weighted fraction of the sum total of weights (of all + the input samples) required to be at a leaf node. Samples have + equal weight when sample_weight is not provided. + + max_features : {"sqrt", "log2", None}, int or float, default=1.0 + The number of features to consider when looking for the best split: + + - If int, then consider `max_features` features at each split. + - If float, then `max_features` is a fraction and + `max(1, int(max_features * n_features_in_))` features are considered at each + split. + - If "auto", then `max_features=n_features`. + - If "sqrt", then `max_features=sqrt(n_features)`. + - If "log2", then `max_features=log2(n_features)`. + - If None or 1.0, then `max_features=n_features`. + + .. note:: + The default of 1.0 is equivalent to bagged trees and more + randomness can be achieved by setting smaller values, e.g. 0.3. + + .. versionchanged:: 1.1 + The default of `max_features` changed from `"auto"` to 1.0. + + .. deprecated:: 1.1 + The `"auto"` option was deprecated in 1.1 and will be removed + in 1.3. + + Note: the search for a split does not stop until at least one + valid partition of the node samples is found, even if it requires to + effectively inspect more than ``max_features`` features. + + max_leaf_nodes : int, default=None + Grow trees with ``max_leaf_nodes`` in best-first fashion. + Best nodes are defined as relative reduction in impurity. + If None then unlimited number of leaf nodes. + + min_impurity_decrease : float, default=0.0 + A node will be split if this split induces a decrease of the impurity + greater than or equal to this value. + + The weighted impurity decrease equation is the following:: + + N_t / N * (impurity - N_t_R / N_t * right_impurity + - N_t_L / N_t * left_impurity) + + where ``N`` is the total number of samples, ``N_t`` is the number of + samples at the current node, ``N_t_L`` is the number of samples in the + left child, and ``N_t_R`` is the number of samples in the right child. + + ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum, + if ``sample_weight`` is passed. + + .. versionadded:: 0.19 + + bootstrap : bool, default=True + Whether bootstrap samples are used when building trees. If False, the + whole dataset is used to build each tree. + + oob_score : bool, default=False + Whether to use out-of-bag samples to estimate the generalization score. + Only available if bootstrap=True. + + n_jobs : int, default=None + The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`, + :meth:`decision_path` and :meth:`apply` are all parallelized over the + trees. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` + context. ``-1`` means using all processors. See :term:`Glossary + ` for more details. 
+ + random_state : int, RandomState instance or None, default=None + Controls both the randomness of the bootstrapping of the samples used + when building trees (if ``bootstrap=True``) and the sampling of the + features to consider when looking for the best split at each node + (if ``max_features < n_features``). + See :term:`Glossary ` for details. + + verbose : int, default=0 + Controls the verbosity when fitting and predicting. + + warm_start : bool, default=False + When set to ``True``, reuse the solution of the previous call to fit + and add more estimators to the ensemble, otherwise, just fit a whole + new forest. See :term:`Glossary ` and + :ref:`gradient_boosting_warm_start` for details. + + ccp_alpha : non-negative float, default=0.0 + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. versionadded:: 0.22 + + max_samples : int or float, default=None + If bootstrap is True, the number of samples to draw from X + to train each base estimator. + + - If None (default), then draw `X.shape[0]` samples. + - If int, then draw `max_samples` samples. + - If float, then draw `max_samples * X.shape[0]` samples. Thus, + `max_samples` should be in the interval `(0.0, 1.0]`. + + .. versionadded:: 0.22 + + criterion : {"squared_error", "friedman_mse", "absolute_error", \ + "poisson"}, default="squared_error" + The function to measure the quality of a split. Supported criteria + are "squared_error" for the mean squared error, which is equal to + variance reduction as feature selection criterion and minimizes the L2 + loss using the mean of each terminal node, "friedman_mse", which uses + mean squared error with Friedman's improvement score for potential + splits, "absolute_error" for the mean absolute error, which minimizes + the L1 loss using the median of each terminal node, and "poisson" which + uses reduction in Poisson deviance to find splits. + + .. versionadded:: 0.18 + Mean Absolute Error (MAE) criterion. + + .. versionadded:: 0.24 + Poisson deviance criterion. + + splitter : {"best", "random"}, default="best" + The strategy used to choose the split at each node. Supported + strategies are "best" to choose the best split and "random" to choose + the best random split. + + max_depth : int, default=None + The maximum depth of the tree. If None, then nodes are expanded until + all leaves are pure or until all leaves contain less than + min_samples_split samples. + + min_samples_split : int or float, default=2 + The minimum number of samples required to split an internal node: + + - If int, then consider `min_samples_split` as the minimum number. + - If float, then `min_samples_split` is a fraction and + `ceil(min_samples_split * n_samples)` are the minimum + number of samples for each split. + + .. versionchanged:: 0.18 + Added float values for fractions. + + min_samples_leaf : int or float, default=1 + The minimum number of samples required to be at a leaf node. + A split point at any depth will only be considered if it leaves at + least ``min_samples_leaf`` training samples in each of the left and + right branches. This may have the effect of smoothing the model, + especially in regression. + + - If int, then consider `min_samples_leaf` as the minimum number. 
+ - If float, then `min_samples_leaf` is a fraction and + `ceil(min_samples_leaf * n_samples)` are the minimum + number of samples for each node. + + .. versionchanged:: 0.18 + Added float values for fractions. + + min_weight_fraction_leaf : float, default=0.0 + The minimum weighted fraction of the sum total of weights (of all + the input samples) required to be at a leaf node. Samples have + equal weight when sample_weight is not provided. + + max_features : int, float or {"auto", "sqrt", "log2"}, default=None + The number of features to consider when looking for the best split: + + - If int, then consider `max_features` features at each split. + - If float, then `max_features` is a fraction and + `max(1, int(max_features * n_features_in_))` features are considered at each + split. + - If "auto", then `max_features=n_features`. + - If "sqrt", then `max_features=sqrt(n_features)`. + - If "log2", then `max_features=log2(n_features)`. + - If None, then `max_features=n_features`. + + .. deprecated:: 1.1 + The `"auto"` option was deprecated in 1.1 and will be removed + in 1.3. + + Note: the search for a split does not stop until at least one + valid partition of the node samples is found, even if it requires to + effectively inspect more than ``max_features`` features. + + random_state : int, RandomState instance or None, default=None + Controls the randomness of the estimator. The features are always + randomly permuted at each split, even if ``splitter`` is set to + ``"best"``. When ``max_features < n_features``, the algorithm will + select ``max_features`` at random at each split before finding the best + split among them. But the best found split may vary across different + runs, even if ``max_features=n_features``. That is the case, if the + improvement of the criterion is identical for several splits and one + split has to be selected at random. To obtain a deterministic behaviour + during fitting, ``random_state`` has to be fixed to an integer. + See :term:`Glossary ` for details. + + max_leaf_nodes : int, default=None + Grow a tree with ``max_leaf_nodes`` in best-first fashion. + Best nodes are defined as relative reduction in impurity. + If None then unlimited number of leaf nodes. + + min_impurity_decrease : float, default=0.0 + A node will be split if this split induces a decrease of the impurity + greater than or equal to this value. + + The weighted impurity decrease equation is the following:: + + N_t / N * (impurity - N_t_R / N_t * right_impurity + - N_t_L / N_t * left_impurity) + + where ``N`` is the total number of samples, ``N_t`` is the number of + samples at the current node, ``N_t_L`` is the number of samples in the + left child, and ``N_t_R`` is the number of samples in the right child. + + ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum, + if ``sample_weight`` is passed. + + .. versionadded:: 0.19 + + ccp_alpha : non-negative float, default=0.0 + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. 
versionadded:: 0.22 + """ + super().__init__( + DecisionTreeRegressor(), + n_estimators, + estimator_params=( + "criterion", + "max_depth", + "min_samples_split", + "min_samples_leaf", + "min_weight_fraction_leaf", + "max_features", + "max_leaf_nodes", + "min_impurity_decrease", + "random_state", + "ccp_alpha", + "monotonic_cst", + ), + bootstrap=bootstrap, + oob_score=oob_score, + n_jobs=n_jobs, + random_state=random_state, + verbose=verbose, + warm_start=warm_start, + max_samples=max_samples, + ) + self.criterion = criterion + self.max_depth = max_depth + self.min_samples_split = min_samples_split + self.min_samples_leaf = min_samples_leaf + self.min_weight_fraction_leaf = min_weight_fraction_leaf + self.max_features = max_features + self.max_leaf_nodes = max_leaf_nodes + self.min_impurity_decrease = min_impurity_decrease + self.ccp_alpha = ccp_alpha + self.monotonic_cst = monotonic_cst + self.splitter = splitter + self.log_y = log_y + self.cross_trees_variance = cross_trees_variance + + def fit(self, X: np.ndarray, y: np.ndarray, sample_weight=None) -> None: # type: ignore + """ + Build a forest of trees from the training set (X, y). In additional to the vanilla RF fitting process, we also + need to edit the estimators' parameters after the fitting process when self.log_y is True. This ensures that the + model performance consistently compared to the pyrfr version. To compute the means of all the values, we first + need to recover the log scaled values stored in the leave nodes to their raw scale and then compute the mean + over those values. This mean value will then transformed back to the log scale. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The training input samples. Internally, its dtype will be converted + to ``dtype=np.float32``. If a sparse matrix is provided, it will be + converted into a sparse ``csc_matrix``. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + The target values (class labels in classification, real numbers in + regression). + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. If None, then samples are equally weighted. Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. In the case of + classification, splits are also ignored if they would result in any + single class carrying a negative weight in either child node. + + Returns + ------- + self : object + Fitted estimator. + """ + assert sample_weight is None, "Sample weights are not supported" + super().fit(X=X, y=y, sample_weight=sample_weight) + + self.trainX = X + self.trainY = y + if self.log_y: + for tree, samples_idx in zip(self.estimators_, self.estimators_samples_): + curX = X[samples_idx] + curY = y[samples_idx] + preds = tree.apply(curX) + for k in np.unique(preds): + tree.tree_.value[k, 0, 0] = np.log(np.exp(curY[preds == k]).mean()) + + def all_trees_pred(self, X: np.ndarray) -> np.ndarray: + """ + This function is used to parally predict the target X values. 
It is based on the RF regressor from sklearn 1.6.1:
+        https://github.com/scikit-learn/scikit-learn/blob/99bf3d8e4eed5ba5db19a1869482a238b6223ffd/sklearn/ensemble/_forest.py#L1045
+
+        Parameters
+        ----------
+        X: np.ndarray [#samples, #features]
+            input feature X
+
+        Returns
+        -------
+        preds: np.ndarray [#samples, #estimators, #output]
+            Predictions from all trees
+
+        """
+        # check_is_fitted(self)
+        # Check data
+        X = self._validate_X_predict(X)
+
+        if X.ndim == 1:
+            X = X[None, :]
+
+        # Assign chunk of trees to jobs
+        n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs)
+
+        # avoid storing the output of every estimator by summing them here
+        if self.n_outputs_ > 1:
+            preds = np.zeros((X.shape[0], self.n_estimators, self.n_outputs_), dtype=np.float64)
+        else:
+            preds = np.zeros((X.shape[0], self.n_estimators), dtype=np.float64)
+
+        # Parallel loop
+        Parallel(n_jobs=n_jobs, verbose=self.verbose, require="sharedmem")(
+            delayed(estimator_predict)(e.predict, X, preds, tree_idx) for tree_idx, e in enumerate(self.estimators_)
+        )
+        # This should be equivalent to the following implementation
+
+        # preds_ = np.zeros([len(X), self.n_estimators])
+        # for i, tree in enumerate(self.estimators_):
+        #     preds_[:, i] = tree.predict(X)
+        # assert np.allclose(preds, preds_)
+
+        return preds
+
+    def predict(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+        """
+        Predict the mean and variance of X. Here, the mean and variance are the empirical mean and variance of the
+        predictions from the individual trees.
+
+        Parameters
+        ----------
+        X: np.ndarray [#samples, #hyperparameter]
+            Input data points to be tested.
+
+        Returns
+        -------
+        means: np.ndarray [#samples, 1]
+            Predicted mean.
+        vars: np.ndarray [#samples, 1]
+            Predicted variance.
+
+        """
+        preds = self.all_trees_pred(X)
+
+        means = preds.mean(axis=1)
+        vars = preds.var(axis=1)
+
+        return means.reshape(-1, 1), vars.reshape(-1, 1)
+
+    def predict_marginalized_over_instances_batch(self, X: np.ndarray, X_feat: np.ndarray, log_y: bool) -> np.ndarray:
+        """
+        Collects the predictions for each tree in the forest for multiple configurations over a set of instances.
+        Each configuration vector is combined with all the instance feature vectors. Based on the response values
+        over all these feature vectors, the mean is computed. In the case of log transformation, the response values
+        are decompressed before averaging.
+
+        Parameters
+        ----------
+        X: np.ndarray [#samples, #hyperparameter]
+            Input data points.
+        X_feat: np.ndarray [#instance, #features]
+            Features (np.ndarray) of the instances (str). The features are incorporated into the X data,
+            on which the model is trained.
+        log_y: bool
+            Whether log_y is applied to the predictions.
+
+        Returns
+        -------
+        preds: np.ndarray [#samples, #estimators]
+            Predictions for each sample and tree. Each element in preds corresponds to the mean response value for
+            the target estimator and configuration across all the instances.
+ + """ + X = self._validate_X_predict(X, ensure_2d=False) + X_feat = self._validate_X_predict(X_feat, ensure_2d=False) + assert X.shape[-1] + X_feat.shape[-1] == self.n_features_in_ + + n_instances = len(X_feat) + + if X.ndim == 1: + X = X[None, :] + + # Assign chunk of trees to jobs + n_jobs, _, _ = _partition_estimators(self.n_estimators * n_instances, self.n_jobs) + + # avoid storing the output of every estimator by summing them here + if self.n_outputs_ > 1: + preds = np.zeros((X.shape[0], self.n_estimators, self.n_outputs_), dtype=np.float64) + else: + preds = np.zeros((X.shape[0], self.n_estimators), dtype=np.float64) + lock = threading.Lock() + # Parallel loop + Parallel(n_jobs=n_jobs, verbose=self.verbose, require="sharedmem")( + delayed(accumulate_predict_over_instances)(e.predict, X, x_feat, preds, tree_idx, n_instances, lock) + for (tree_idx, e), x_feat in product(enumerate(self.estimators_), X_feat) + ) + + return preds + + def _validate_X_predict(self, X: np.ndarray, ensure_2d: bool = True) -> np.ndarray: + """ + Validate X whenever one tries to predict, apply, predict_proba. + It is based on rf regressor from sklearn 1.6.1: + https://github.com/scikit-learn/scikit-learn/blob/99bf3d8e4eed5ba5db19a1869482a238b6223ffd/sklearn/ensemble/_forest.py#L629 + However, we add another parameter to allow the model to ignore feature checking. + This is applied for the cases where we have both hyperpameter features and instance features, the two features + will only be concatenated within each tree estimation functions. Hence, there is no need to check if their + individual number of features fit the number of features set in the RF model. + We will check if the number fo features fit the model afterwards for predict_marginalized_over_instances_batch + + Parameters + ---------- + X: np.ndarray + input features to be validated + ensure_2d: bool + if we check if the X's size match the fitted estimators' features + + """ + check_is_fitted(self) + if self.estimators_[0]._support_missing_values(X): + ensure_all_finite = "allow-nan" + else: + ensure_all_finite = True # type: ignore + X = validate_data( + self, + X, + dtype=DTYPE, + accept_sparse="csr", + reset=False, + ensure_all_finite=ensure_all_finite, + ensure_2d=ensure_2d, + ) + if issparse(X) and (X.indices.dtype != np.intc or X.indptr.dtype != np.intc): # type: ignore + raise ValueError("No support for np.int64 index based sparse matrices") + return X + + class RandomForest(AbstractRandomForest): """Random forest that takes instance features into account. @@ -22,8 +642,9 @@ class RandomForest(AbstractRandomForest): ---------- n_trees : int, defaults to `N_TREES` The number of trees in the random forest. - n_points_per_tree : int, defaults to -1 - Number of points per tree. If the value is smaller than 0, the number of samples will be used. + max_samples : int | float | None, defaults to None + Number of points per tree. If the value is None, the number of samples will be used. Otherwise, use + max_samples (if it is int) or max(round(n_samples * max_samples), 1) (if it is float value) ratio_features : float, defaults to 5.0 / 6.0 The ratio of features that are considered for splitting. min_samples_split : int, defaults to 3 @@ -32,9 +653,7 @@ class RandomForest(AbstractRandomForest): The minimum number of data points in a leaf. max_depth : int, defaults to 2**20 The maximum depth of a single tree. - eps_purity : float, defaults to 1e-8 - The minimum difference between two target values to be considered. 
- max_nodes : int, defaults to 2**20 + max_leaf_nodes : int, defaults to 2**20 The maximum total number of nodes in a tree. bootstrapping : bool, defaults to True Enables bootstrapping. @@ -52,19 +671,29 @@ class RandomForest(AbstractRandomForest): def __init__( self, configspace: ConfigurationSpace, - n_trees: int = N_TREES, - n_points_per_tree: int = -1, + max_samples: int | float | None = None, ratio_features: float = 5.0 / 6.0, - min_samples_split: int = 3, - min_samples_leaf: int = 3, - max_depth: int = 2**20, - eps_purity: float = 1e-8, - max_nodes: int = 2**20, - bootstrapping: bool = True, log_y: bool = False, instance_features: dict[str, list[int | float]] | None = None, pca_components: int | None = 7, seed: int = 0, + n_trees: int = N_TREES, + cross_trees_variance: bool = False, + criterion: str = "squared_error", + splitter: str = "random", + max_depth: int = 2**20, + min_samples_split: int = 3, + min_samples_leaf: int = 3, + min_weight_fraction_leaf: float = 0.0, + max_leaf_nodes: int = 2**20, + min_impurity_decrease: float = 1e-8, + bootstrapping: bool = True, + oob_score: bool = False, + n_jobs: int | None = -1, + verbose: int = 0, + warm_start: bool = False, + ccp_alpha: float = 0.0, + monotonic_cst: Iterable | None = None, ) -> None: super().__init__( configspace=configspace, @@ -75,64 +704,39 @@ def __init__( max_features = 0 if ratio_features > 1.0 else max(1, int(len(self._types) * ratio_features)) - self._rf_opts = regression.forest_opts() - self._rf_opts.num_trees = n_trees - self._rf_opts.do_bootstrapping = bootstrapping - self._rf_opts.tree_opts.max_features = max_features - self._rf_opts.tree_opts.min_samples_to_split = min_samples_split - self._rf_opts.tree_opts.min_samples_in_leaf = min_samples_leaf - self._rf_opts.tree_opts.max_depth = max_depth - self._rf_opts.tree_opts.epsilon_purity = eps_purity - self._rf_opts.tree_opts.max_num_nodes = max_nodes - self._rf_opts.compute_law_of_total_variance = False - self._rf: BinaryForest | None = None + self._rf: EPMRandomForest | None = None + self._rng = np.random.default_rng(seed=seed) # type: ignore + self._log_y = log_y - # Case to `int` incase we get an `np.integer` type - self._rng = regression.default_random_engine(int(seed)) - - self._n_trees = n_trees - self._n_points_per_tree = n_points_per_tree - self._ratio_features = ratio_features - self._min_samples_split = min_samples_split - self._min_samples_leaf = min_samples_leaf - self._max_depth = max_depth - self._eps_purity = eps_purity - self._max_nodes = max_nodes - self._bootstrapping = bootstrapping - self._rf = None - - # This list well be read out by save_iteration() in the solver - # self._hypers = [ - # n_trees, - # max_nodes, - # bootstrapping, - # n_points_per_tree, - # ratio_features, - # min_samples_split, - # min_samples_leaf, - # max_depth, - # eps_purity, - # self._seed, - # ] + self._rf_opts = { + "n_estimators": n_trees, + "cross_trees_variance": cross_trees_variance, + "criterion": criterion, + "splitter": splitter, + "max_depth": max_depth, + "min_samples_split": min_samples_split, + "min_samples_leaf": min_samples_leaf, + "min_weight_fraction_leaf": min_weight_fraction_leaf, + "max_leaf_nodes": max_leaf_nodes, + "min_impurity_decrease": min_impurity_decrease, + "bootstrap": bootstrapping, + "oob_score": oob_score, + "n_jobs": n_jobs, + "verbose": verbose, + "warm_start": warm_start, + "ccp_alpha": ccp_alpha, + "max_samples": max_samples, + "monotonic_cst": monotonic_cst, + "random_state": seed, + "max_features": max_features, + "log_y": log_y, + 
} @property def meta(self) -> dict[str, Any]: # noqa: D102 meta = super().meta - meta.update( - { - "n_trees": self._n_trees, - "n_points_per_tree": self._n_points_per_tree, - "ratio_features": self._ratio_features, - "min_samples_split": self._min_samples_split, - "min_samples_leaf": self._min_samples_leaf, - "max_depth": self._max_depth, - "eps_purity": self._eps_purity, - "max_nodes": self._max_nodes, - "bootstrapping": self._bootstrapping, - "pca_components": self._pca_components, - } - ) + meta.update(self._rf_opts) return meta @@ -140,51 +744,12 @@ def _train(self, X: np.ndarray, y: np.ndarray) -> RandomForest: X = self._impute_inactive(X) y = y.flatten() - # self.X = X - # self.y = y.flatten() - - if self._n_points_per_tree <= 0: - self._rf_opts.num_data_points_per_tree = X.shape[0] - else: - self._rf_opts.num_data_points_per_tree = self._n_points_per_tree - - self._rf = regression.binary_rss_forest() - self._rf.options = self._rf_opts + self._rf = EPMRandomForest(**self._rf_opts) # type: ignore - data = self._init_data_container(X, y) - self._rf.fit(data, rng=self._rng) + self._rf.fit(X, y) return self - def _init_data_container(self, X: np.ndarray, y: np.ndarray) -> DataContainer: - """Fills a pyrfr default data container s.t. the forest knows categoricals and bounds for continous data. - - Parameters - ---------- - X : np.ndarray [#samples, #hyperparameter + #features] - Input data points. - Y : np.ndarray [#samples, #objectives] - The corresponding target values. - - Returns - ------- - data : DataContainer - The filled data container that pyrfr can interpret. - """ - # Retrieve the types and the bounds from the ConfigSpace - data = regression.default_data_container(X.shape[1]) - - for i, (mn, mx) in enumerate(self._bounds): - if np.isnan(mx): - data.set_type_of_feature(i, mn) - else: - data.set_bounds_of_feature(i, mn, mx) - - for row_X, row_y in zip(X, y): - data.add_data_point(row_X, row_y) - - return data - def _predict( self, X: np.ndarray, @@ -201,45 +766,19 @@ def _predict( assert self._rf is not None X = self._impute_inactive(X) - - if self._log_y: - all_preds = [] - third_dimension = 0 - - # Gather data in a list of 2d arrays and get statistics about the required size of the 3d array - for row_X in X: - preds_per_tree = self._rf.all_leaf_values(row_X) - all_preds.append(preds_per_tree) - max_num_leaf_data = max(map(len, preds_per_tree)) - third_dimension = max(max_num_leaf_data, third_dimension) - - # Transform list of 2d arrays into a 3d array - preds_as_array = np.zeros((X.shape[0], self._rf_opts.num_trees, third_dimension)) * np.nan - for i, preds_per_tree in enumerate(all_preds): - for j, pred in enumerate(preds_per_tree): - preds_as_array[i, j, : len(pred)] = pred - - # Do all necessary computation with vectorized functions - preds_as_array = np.log(np.nanmean(np.exp(preds_as_array), axis=2) + VERY_SMALL_NUMBER) - - # Compute the mean and the variance across the different trees - means = preds_as_array.mean(axis=1) - vars_ = preds_as_array.var(axis=1) - else: - means, vars_ = [], [] - for row_X in X: - mean_, var = self._rf.predict_mean_var(row_X) - means.append(mean_) - vars_.append(var) - - means = np.array(means) - vars_ = np.array(vars_) - + means, vars_ = self._rf.predict(X) return means.reshape((-1, 1)), vars_.reshape((-1, 1)) def predict_marginalized(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """Predicts mean and variance marginalized over all instances. 
+        Under the hood: Collects the predictions for each tree in the forest
+        for multiple configurations over a set of instances. Each configuration
+        vector is combined with all the instance feature vectors. Based on the
+        response values over all these feature vectors, the mean is computed.
+        In the case of log transformation, the response values are decompressed
+        before averaging.
+
         Note
         ----
         The method is random forest specific and follows the SMAC2 implementation. It requires
         no distribution assumption to marginalize the uncertainty estimates.
@@ -277,7 +816,7 @@ def predict_marginalized(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
         assert self._rf is not None
         X = self._impute_inactive(X)
 
-        X_feat = list(self._instance_features.values())
+        X_feat = np.asarray(list(self._instance_features.values()))
         dat_ = self._rf.predict_marginalized_over_instances_batch(X, X_feat, self._log_y)
         dat_ = np.array(dat_)
 
diff --git a/tests/test_model/test_rf.py b/tests/test_model/test_rf.py
index 65e6478ff..789e8fb40 100644
--- a/tests/test_model/test_rf.py
+++ b/tests/test_model/test_rf.py
@@ -9,6 +9,7 @@
     UniformIntegerHyperparameter,
 )
 
+from smac import constants
 from smac.model.random_forest.random_forest import RandomForest
 from smac.utils.configspace import convert_configurations_to_array
 
@@ -125,6 +126,24 @@ def test_predict_marginalized():
     assert means.shape == (20, 1)
     assert variances.shape == (20, 1)
 
+    # Now we need to ensure that the prediction results are the same as when we do the prediction individually
+    n_estimators = model._rf_opts["n_estimators"]
+    n_features = len(F)
+    n_data = len(X)
+
+    all_features = np.asarray(list(F.values()))
+    all_preds = np.empty([n_data, n_estimators, n_features])
+    for i_tree in range(n_estimators):
+        for i_feat in range(n_features):
+            for i_data in range(n_data):
+                x_input = np.concatenate([X[i_data], all_features[i_feat]])[None, :]
+                all_preds[[i_data], i_tree, i_feat] = model._rf.estimators_[i_tree].predict(x_input)
+    pred_marginalized_over_instance = np.mean(all_preds, -1)
+    assert np.allclose(np.mean(pred_marginalized_over_instance, axis=-1, keepdims=True), means)
+    assert np.allclose(np.var(pred_marginalized_over_instance, axis=-1, keepdims=True), variances)
+
 
 def test_predict_marginalized_mocked():
     rs = np.random.RandomState(1)
@@ -288,3 +307,72 @@ def test_impute_inactive_hyperparameters():
         elif line[0] == 2:
             assert line[1] == 2
             assert line[2] == -1
+
+
+def test_rf_with_log_y():
+    X = np.array(
+        [
+            [0.0, 0.0, 0.0],
+            [0.0, 0.0, 1.0],
+            [0.0, 1.0, 0.0],
+            [0.0, 1.0, 1.0],
+            [1.0, 0.0, 0.0],
+            [1.0, 0.0, 1.0],
+            [1.0, 1.0, 0.0],
+            [1.0, 1.0, 1.0],
+        ],
+        dtype=np.float64,
+    )
+    y = np.array([[0.1], [0.2], [9], [9.2], [100.0], [100.2], [109.0], [109.2]], dtype=np.float64)
+    model1 = RandomForest(
+        configspace=_get_cs(3),
+        instance_features=None,
+        seed=12345,
+        ratio_features=1.0,
+        log_y=True
+    )
+    model1.train(np.vstack((X, X, X, X, X, X, X, X)), np.vstack((y, y, y, y, y, y, y, y)))
+    X_test = np.random.rand(10, 3)
+
+    mean1, var1 = model1.predict(X_test)
+    # for y_i, y_hat_i in zip(y.reshape((1, -1)).flatten(), y_hat.reshape((1, -1)).flatten()):
+    #     assert pytest.approx(y_i, 0.1) == y_hat_i
+
+    # The following should be equivalent to the log_y version
+
+    model2 = RandomForest(
+        configspace=_get_cs(3),
+        instance_features=None,
+        seed=12345,
+        ratio_features=1.0,
+        log_y=False
+    )
+    all_preds = []
+    third_dimension = 0
+
+    model2.train(np.vstack((X, X, X, X, X, X, X, X)), np.vstack((y, y, y, y, y, y, y, y)))
+
+    # Gather data in a list of 2d arrays and get statistics about the required size of the 3d array
+
for row_X in X_test: + preds_per_tree = [estimator.predict(row_X[None, :]) for estimator in model2._rf.estimators_] + #preds_per_tree = model_no_logy._rf.all_leaf_values(row_X) + all_preds.append(preds_per_tree) + max_num_leaf_data = max(map(len, preds_per_tree)) + third_dimension = max(max_num_leaf_data, third_dimension) + + # Transform list of 2d arrays into a 3d array + preds_as_array = np.zeros((X_test.shape[0], model2._rf_opts['n_estimators'], third_dimension)) * np.nan + for i, preds_per_tree in enumerate(all_preds): + for j, pred in enumerate(preds_per_tree): + preds_as_array[i, j, : len(pred)] = pred + + # Do all necessary computation with vectorized functions + preds_as_array = np.log(np.nanmean(np.exp(preds_as_array), axis=2) + constants.VERY_SMALL_NUMBER) + + # Compute the mean and the variance across the different trees + mean2 = preds_as_array.mean(axis=1, keepdims=True) + var2 = preds_as_array.var(axis=1, keepdims=True) + + assert np.allclose(mean1, mean2) + assert np.allclose(var1, var2) +
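
To make the facade switch described in the README hunk concrete, here is a minimal usage sketch (not part of the patch itself). It assumes the standard SMAC 2.x `Scenario`/facade API; the toy objective `quadratic`, the hyperparameter `x`, and the scenario names are invented for illustration, and the legacy facade additionally requires `pip install smac[pyrfr]`.

```python
from ConfigSpace import Configuration, ConfigurationSpace

from smac import HyperparameterOptimizationFacade, Scenario
from smac.facade.old import HyperparameterOptimizationRFRFacade  # needs `pip install smac[pyrfr]`


def quadratic(config: Configuration, seed: int = 0) -> float:
    """Toy objective (hypothetical): minimize x^2."""
    return config["x"] ** 2


cs = ConfigurationSpace({"x": (-5.0, 5.0)})

# Default since 2.4.0: the surrogate is the sklearn-based RandomForest.
scenario_sklearn = Scenario(cs, name="rf_sklearn", deterministic=True, n_trials=50)
incumbent = HyperparameterOptimizationFacade(scenario_sklearn, quadratic).optimize()

# Legacy behaviour: same facade interface, but the surrogate is PyrfrRandomForest.
scenario_pyrfr = Scenario(cs, name="rf_pyrfr", deterministic=True, n_trials=50)
incumbent_pyrfr = HyperparameterOptimizationRFRFacade(scenario_pyrfr, quadratic).optimize()
```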
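The aggregation performed by `predict_marginalized` / `predict_marginalized_over_instances_batch` can be summarized with plain NumPy. The sketch below is illustrative only (the array sizes and random predictions are made up); it mirrors the check added in `test_predict_marginalized`: average each tree's predictions over the instance features first, then take the mean and variance across trees.

```python
import numpy as np

rng = np.random.default_rng(0)
n_configs, n_trees, n_instances = 5, 10, 3

# per_leaf[i, t, k]: prediction of tree t for configuration i paired with instance k
per_leaf = rng.normal(size=(n_configs, n_trees, n_instances))

# 1) Marginalize over instances within each tree.
per_tree = per_leaf.mean(axis=-1)              # shape: (n_configs, n_trees)

# 2) Aggregate across trees.
means = per_tree.mean(axis=1, keepdims=True)   # shape: (n_configs, 1)
vars_ = per_tree.var(axis=1, keepdims=True)    # shape: (n_configs, 1)

assert means.shape == (n_configs, 1) and vars_.shape == (n_configs, 1)
```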
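The `log_y` handling in `EPMRandomForest.fit` can also be illustrated in isolation. When `log_y=True`, each leaf value is overwritten with `log(mean(exp(y)))` over the training targets falling into that leaf, so that averaging log-transformed costs behaves like the pyrfr implementation rather than taking a plain arithmetic mean in log space. The numbers below are made up for illustration.

```python
import numpy as np

# Hypothetical costs that land in one leaf, already log-transformed as SMAC passes them in.
y_leaf_log = np.log(np.array([0.1, 0.2, 9.0, 9.2]))

plain_sklearn_value = y_leaf_log.mean()               # arithmetic mean in log space
corrected_value = np.log(np.exp(y_leaf_log).mean())   # what fit() writes back into tree_.value

print(f"plain mean in log space:  {plain_sklearn_value:.4f}")
print(f"log of mean in raw space: {corrected_value:.4f}")
```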