|
43 | 43 | from cuml.dask.ensemble import RandomForestRegressor as cuRFR_mg
|
44 | 44 | from cuml.dask.common import utils as dask_utils
|
45 | 45 |
|
| 46 | +from cuml.ensemble import RandomForestClassifier as cuRFC_sg |
| 47 | +from cuml.ensemble import RandomForestRegressor as cuRFR_sg |
| 48 | + |
46 | 49 | from dask.array import from_array
|
47 | 50 | from sklearn.datasets import make_regression, make_classification
|
48 | 51 | from sklearn.model_selection import train_test_split
|
@@ -436,6 +439,47 @@ def predict_with_json_rf_regressor(rf, x):
|
436 | 439 | np.testing.assert_almost_equal(pred, expected_pred, decimal=6)
|
437 | 440 |
|
438 | 441 |
|
| 442 | +@pytest.mark.parametrize('estimator_type', ['regression', 'classification']) |
| 443 | +def test_rf_get_combined_model_right_aftter_fit(client, estimator_type): |
| 444 | + max_depth = 3 |
| 445 | + n_estimators = 5 |
| 446 | + X, y = make_classification() |
| 447 | + X = X.astype(np.float32) |
| 448 | + if estimator_type == 'classification': |
| 449 | + cu_rf_mg = cuRFC_mg( |
| 450 | + max_features=1.0, |
| 451 | + max_samples=1.0, |
| 452 | + n_bins=16, |
| 453 | + n_streams=1, |
| 454 | + n_estimators=n_estimators, |
| 455 | + max_leaves=-1, |
| 456 | + max_depth=max_depth |
| 457 | + ) |
| 458 | + y = y.astype(np.int32) |
| 459 | + elif estimator_type == 'regression': |
| 460 | + cu_rf_mg = cuRFR_mg( |
| 461 | + max_features=1.0, |
| 462 | + max_samples=1.0, |
| 463 | + n_bins=16, |
| 464 | + n_streams=1, |
| 465 | + n_estimators=n_estimators, |
| 466 | + max_leaves=-1, |
| 467 | + max_depth=max_depth |
| 468 | + ) |
| 469 | + y = y.astype(np.float32) |
| 470 | + else: |
| 471 | + assert False |
| 472 | + X_dask, y_dask = _prep_training_data(client, X, y, partitions_per_worker=2) |
| 473 | + cu_rf_mg.fit(X_dask, y_dask) |
| 474 | + single_gpu_model = cu_rf_mg.get_combined_model() |
| 475 | + if estimator_type == 'classification': |
| 476 | + assert isinstance(single_gpu_model, cuRFC_sg) |
| 477 | + elif estimator_type == 'regression': |
| 478 | + assert isinstance(single_gpu_model, cuRFR_sg) |
| 479 | + else: |
| 480 | + assert False |
| 481 | + |
| 482 | + |
439 | 483 | @pytest.mark.parametrize('n_estimators', [5, 10, 20])
|
440 | 484 | @pytest.mark.parametrize('detailed_text', [True, False])
|
441 | 485 | def test_rf_get_text(client, n_estimators, detailed_text):
|
|
0 commit comments