Skip to content

Commit

Permalink
[REVIEW] Allow saving Dask RandomForest models immediately after trai…
Browse files Browse the repository at this point in the history
…ning (fixes #3331)
  • Loading branch information
jameslamb committed Jan 20, 2021
1 parent d72c54a commit eb67bd1
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 0 deletions.
17 changes: 17 additions & 0 deletions python/cuml/dask/ensemble/randomforestclassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,23 @@ def get_summary_text(self):
"""
return self._get_summary_text()

def get_combined_model(self):
    """
    Return the single-GPU model for serialization.

    When no internal model has been set yet, the result of
    ``_concat_treelite_models()`` is stored as the internal model
    before the combined model is requested from ``BaseEstimator``.

    Returns
    -------
    model : Trained single-GPU model or None if the model has not
        yet been trained.
    """
    internal_model = self._get_internal_model()
    if internal_model is None:
        # Nothing has accessed the internal model so far: build it
        # from the concatenated treelite models and store it.
        concatenated = self._concat_treelite_models()
        self._set_internal_model(concatenated)

    return BaseEstimator.get_combined_model(self)

def get_detailed_text(self):
"""
Obtain the detailed information for the random forest model, as text
Expand Down
17 changes: 17 additions & 0 deletions python/cuml/dask/ensemble/randomforestregressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,23 @@ def get_summary_text(self):
"""
return self._get_summary_text()

def get_combined_model(self):
    """
    Return the single-GPU model for serialization.

    The internal model is populated on demand: if
    ``_get_internal_model()`` yields ``None``, the output of
    ``_concat_treelite_models()`` is passed to
    ``_set_internal_model()`` before delegating to
    ``BaseEstimator.get_combined_model``.

    Returns
    -------
    model : Trained single-GPU model or None if the model has not
        yet been trained.
    """
    needs_internal_model = self._get_internal_model() is None
    if needs_internal_model:
        # First access since fitting: set the internal model now.
        treelite_concat = self._concat_treelite_models()
        self._set_internal_model(treelite_concat)

    return BaseEstimator.get_combined_model(self)

def get_detailed_text(self):
"""
Obtain the detailed information for the random forest model, as text
Expand Down
58 changes: 58 additions & 0 deletions python/cuml/test/dask/test_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
from cuml.dask.ensemble import RandomForestRegressor as cuRFR_mg
from cuml.dask.common import utils as dask_utils

from cuml.ensemble import RandomForestClassifier as cuRFC_sg
from cuml.ensemble import RandomForestRegressor as cuRFR_sg

from dask.array import from_array
from sklearn.datasets import make_regression, make_classification
from sklearn.model_selection import train_test_split
Expand Down Expand Up @@ -436,6 +439,61 @@ def predict_with_json_rf_regressor(rf, x):
np.testing.assert_almost_equal(pred, expected_pred, decimal=6)


@pytest.mark.parametrize('estimator_type', ['regression', 'classification'])
def test_rf_get_combined_model_right_aftter_fit(client, estimator_type):
    """
    Check that ``get_combined_model()`` can be called directly after
    ``fit()`` on a Dask random forest and returns an instance of the
    matching single-GPU estimator class.
    """
    max_depth = 3
    n_estimators = 5

    X, y = make_classification(
        n_samples=350,
        n_features=20,
        n_clusters_per_class=1,
        n_informative=10,
        random_state=123,
        n_classes=2
    )
    X = X.astype(np.float32)

    # Constructor arguments common to the classifier and the regressor.
    shared_params = dict(
        max_features=1.0,
        max_samples=1.0,
        n_bins=16,
        split_algo=0,
        min_samples_leaf=2,
        seed=23707,
        n_streams=1,
        n_estimators=n_estimators,
        max_leaves=-1,
        max_depth=max_depth,
    )

    if estimator_type == 'classification':
        # Only the classifier is given an explicit split criterion.
        cu_rf_mg = cuRFC_mg(split_criterion=0, **shared_params)
        y = y.astype(np.int32)
        expected_sg_type = cuRFC_sg
    elif estimator_type == 'regression':
        cu_rf_mg = cuRFR_mg(**shared_params)
        y = y.astype(np.float32)
        expected_sg_type = cuRFR_sg
    else:
        assert False

    X_dask, y_dask = _prep_training_data(
        client, X, y, partitions_per_worker=2
    )
    cu_rf_mg.fit(X_dask, y_dask)

    # No predict/other call in between: the combined model is requested
    # immediately after training.
    single_gpu_model = cu_rf_mg.get_combined_model()
    assert isinstance(single_gpu_model, expected_sg_type)


@pytest.mark.parametrize('n_estimators', [5, 10, 20])
@pytest.mark.parametrize('detailed_text', [True, False])
def test_rf_get_text(client, n_estimators, detailed_text):
Expand Down

0 comments on commit eb67bd1

Please sign in to comment.