Skip to content

Commit

Permalink
Update tests/python_package_test/test_dask.py
Browse files Browse the repository at this point in the history
Co-authored-by: James Lamb <[email protected]>
  • Loading branch information
neNasko1 and jameslamb authored Oct 11, 2024
1 parent 938cb63 commit bed5ded
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1545,11 +1545,7 @@ def test_predict_with_raw_score(task, output, cluster):
@pytest.mark.parametrize("use_init_score", [False, True])
def test_predict_stump(output, use_init_score, cluster, rng):
with Client(cluster) as client:
task = "binary-classification"
n_samples = 1_000
_, _, _, _, dX, dy, _, dg = _create_data(objective=task, n_samples=n_samples, output=output)

model_factory = task_to_dask_factory[task]
_, _, _, _, dX, dy, _, _ = _create_data(objective="binary-classification", n_samples=1_000, output=output)

params = {"objective": "binary", "n_estimators": 5, "min_data_in_leaf": n_samples}

Expand All @@ -1560,7 +1556,7 @@ def test_predict_stump(output, use_init_score, cluster, rng):
else:
init_scores = dy.map_blocks(lambda x: rng.uniform(size=x.size))

model = model_factory(client=client, **params)
model = lgb.DaskLGBMClassifier(client=client, **params)
model.fit(dX, dy, group=dg, init_score=init_scores)
preds_1 = model.predict(dX, raw_score=True, num_iteration=1).compute()
preds_all = model.predict(dX, raw_score=True).compute()
Expand Down

0 comments on commit bed5ded

Please sign in to comment.