Skip to content

Commit df67553

Browse files
authored
Allow saving Dask RandomForest models immediately after training (fixes #3331) (#3388)
This attempts to fix #3331. See that issue for a lot more details. Today, `.get_combined_model()` for the Dask RandomForest model objects returns `None` if it's called immediately after training. That pattern is recommended in ["Distributed Model Pickling"](https://docs.rapids.ai/api/cuml/stable/pickling_cuml_models.html#Distributed-Model-Pickling). Without this support, there is not a way to save a Dask RandomForest model using only public methods / attributes on those classes. Per #3331 (comment), this PR proposes populating the internal model object whenever `get_combined_model()` is called. ## Notes for Reviewers * I have not tested this locally. I spent about 3 hours trying to build `cuml` from source following https://github.com/rapidsai/cuml/blob/main/BUILD.md, and was not successful. If there is a containerized setup for developing `cuml`, I'd greatly appreciate it and would be happy to try it out. I've added a unit test for this change, so I hope that will be enough to confirm that this works and that CI will catch any mistakes I've made. Thanks for your time and consideration. Authors: - James Lamb (@jameslamb) - John Zedlewski (@JohnZed) Approvers: - John Zedlewski (@JohnZed) URL: #3388
1 parent da87257 commit df67553

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

python/cuml/dask/ensemble/base.py

+31
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
import numpy as np
2020
import warnings
2121

22+
from collections.abc import Iterable
23+
from dask.distributed import Future
24+
2225
from cuml.dask.common.input_utils import DistributedDataHandler, \
2326
concatenate
2427
from cuml.dask.common.utils import get_client, wait_and_raise_from_futures
@@ -257,6 +260,34 @@ def _get_json(self):
257260
combined_dump.extend(obj)
258261
return json.dumps(combined_dump)
259262

263+
def get_combined_model(self):
264+
"""
265+
Return single-GPU model for serialization.
266+
267+
Returns
268+
-------
269+
270+
model : Trained single-GPU model or None if the model has not
271+
yet been trained.
272+
"""
273+
274+
# set internal model if it hasn't been accessed before
275+
if self._get_internal_model() is None:
276+
self._set_internal_model(self._concat_treelite_models())
277+
278+
internal_model = self._check_internal_model(self._get_internal_model())
279+
280+
if isinstance(self.internal_model, Iterable):
281+
# This function needs to return a single instance of cuml.Base,
282+
# even if the class is just a composite.
283+
raise ValueError("Expected a single instance of cuml.Base "
284+
"but got %s instead." % type(self.internal_model))
285+
286+
elif isinstance(self.internal_model, Future):
287+
internal_model = self.internal_model.result()
288+
289+
return internal_model
290+
260291

261292
def _func_fit(model, input_data, convert_dtype):
262293
X = concatenate([item[0] for item in input_data])

python/cuml/test/dask/test_random_forest.py

+44
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@
4343
from cuml.dask.ensemble import RandomForestRegressor as cuRFR_mg
4444
from cuml.dask.common import utils as dask_utils
4545

46+
from cuml.ensemble import RandomForestClassifier as cuRFC_sg
47+
from cuml.ensemble import RandomForestRegressor as cuRFR_sg
48+
4649
from dask.array import from_array
4750
from sklearn.datasets import make_regression, make_classification
4851
from sklearn.model_selection import train_test_split
@@ -436,6 +439,47 @@ def predict_with_json_rf_regressor(rf, x):
436439
np.testing.assert_almost_equal(pred, expected_pred, decimal=6)
437440

438441

442+
@pytest.mark.parametrize('estimator_type', ['regression', 'classification'])
443+
def test_rf_get_combined_model_right_aftter_fit(client, estimator_type):
444+
max_depth = 3
445+
n_estimators = 5
446+
X, y = make_classification()
447+
X = X.astype(np.float32)
448+
if estimator_type == 'classification':
449+
cu_rf_mg = cuRFC_mg(
450+
max_features=1.0,
451+
max_samples=1.0,
452+
n_bins=16,
453+
n_streams=1,
454+
n_estimators=n_estimators,
455+
max_leaves=-1,
456+
max_depth=max_depth
457+
)
458+
y = y.astype(np.int32)
459+
elif estimator_type == 'regression':
460+
cu_rf_mg = cuRFR_mg(
461+
max_features=1.0,
462+
max_samples=1.0,
463+
n_bins=16,
464+
n_streams=1,
465+
n_estimators=n_estimators,
466+
max_leaves=-1,
467+
max_depth=max_depth
468+
)
469+
y = y.astype(np.float32)
470+
else:
471+
assert False
472+
X_dask, y_dask = _prep_training_data(client, X, y, partitions_per_worker=2)
473+
cu_rf_mg.fit(X_dask, y_dask)
474+
single_gpu_model = cu_rf_mg.get_combined_model()
475+
if estimator_type == 'classification':
476+
assert isinstance(single_gpu_model, cuRFC_sg)
477+
elif estimator_type == 'regression':
478+
assert isinstance(single_gpu_model, cuRFR_sg)
479+
else:
480+
assert False
481+
482+
439483
@pytest.mark.parametrize('n_estimators', [5, 10, 20])
440484
@pytest.mark.parametrize('detailed_text', [True, False])
441485
def test_rf_get_text(client, n_estimators, detailed_text):

0 commit comments

Comments
 (0)