
[dask] support 'raw_score' in predict() #3793

Closed
jameslamb opened this issue Jan 19, 2021 · 11 comments

@jameslamb
Collaborator

Summary

See #3774 (comment) for background.

To close this issue, add tests to https://github.com/microsoft/LightGBM/blob/master/tests/python_package_test/test_dask.py confirming that .predict(X, raw_score=True) works for DaskLGBMClassifier and DaskLGBMRegressor. For DaskLGBMClassifier, this should work with both .predict() and .predict_proba().
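
A rough sketch of what such a test would exercise is shown below, assuming raw_score is forwarded through the Dask predict methods the same way lightgbm.sklearn forwards it (which is exactly what the tests should confirm). The data shapes, cluster setup, and estimator parameters here are illustrative assumptions, not code from the repository.

import dask.array as da
import numpy as np
import lightgbm as lgb
from dask.distributed import Client

# local two-worker cluster; any Dask cluster would do
client = Client(n_workers=2, threads_per_worker=1)

# toy binary-classification data split into two chunks
rng = np.random.default_rng(0)
X = rng.normal(size=(1_000, 5))
y = (X[:, 0] > 0).astype(int)
dX = da.from_array(X, chunks=(500, 5))
dy = da.from_array(y, chunks=500)

clf = lgb.DaskLGBMClassifier(n_estimators=5, num_leaves=7)
clf.fit(dX, dy)

# raw (untransformed) model scores instead of class labels / probabilities
raw_scores = clf.predict(dX, raw_score=True).compute()
raw_scores_from_proba = clf.predict_proba(dX, raw_score=True).compute()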

Motivation

Adding this feature would allow users to get raw predictions from LightGBM instead of converted output, which might be used as inputs to another model or for other exploratory purposes.

This would bring the Dask interface closer to parity with lightgbm.sklearn.

References

When I first attempted this in #3774, I hit some issues related to how LightGBM uses dask.DataFrame.map_partitions() and dask.Array.map_blocks(). These functions allow you to apply a function to each partition of a distributed data collection, like X.map_partitions(some_func). In some situations, they require you to provide additional information about the shape and data types of the return value of some_func().

You can learn more about this in the map_partitions() docs: https://docs.dask.org/en/latest/dataframe-api.html#dask.dataframe.DataFrame.map_partitions
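
As a hedged illustration of that requirement (the frame and function names below are made up, not taken from the LightGBM code), map_partitions() accepts a meta argument describing the output when Dask cannot infer it on its own; dask.array.map_blocks() has a similar dtype argument.

import dask.dataframe as dd
import pandas as pd

# toy frame split into two partitions
pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0]})
ddf = dd.from_pandas(pdf, npartitions=2)

def shift_values(part):
    # runs once per partition, receiving a plain pandas DataFrame
    return part + 1

# Without `meta`, Dask calls shift_values() on a dummy partition to guess the
# output schema; passing `meta` states the output columns and dtypes directly.
result = ddf.map_partitions(shift_values, meta={"a": "float64"})
print(result.compute())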

For testing this feature, you may find https://github.com/jameslamb/lightgbm-dask-testing useful.

@jameslamb
Collaborator Author

Closing this in favor of #2302, where we store all feature requests. Anyone is welcome to contribute this feature. Leave a comment on this issue if you'd like to work on it.

@jmoralez
Collaborator

I'd like to take this as well, James.

@jameslamb
Collaborator Author

Sure, thanks! It's possible that this already works and that all we need are tests. Look to test_sklearn.py for inspiration.

@jameslamb reopened this Feb 25, 2021
@jmoralez
Collaborator

jmoralez commented Feb 25, 2021

I've repeated the part of the tests that sets up the model factory and data from task and output a couple of times now. I want to extract that to a function like _model_factory_and_data_from_task_and_output. Should I open a different PR for that kind of refactor?

@jameslamb
Collaborator Author

I'd prefer to keep the data stuff as-is for now, please. There are already several layers of indirection in the data loading for those tests and I'd prefer not to add more right now.

You could reduce duplication in the model_factory stuff simply with a dictionary.

task_to_dask_factory = {
    "regression": lgb.DaskLGBMRegressor,
    "classification": lgb.DaskLGBMClassifier,
    "ranking": lgb.DaskLGBMRanker
}

...
def test_whatever(task):
    dask_model_factory = task_to_dask_factory[task]

I'd welcome a PR that does that.

@jmoralez
Collaborator

Should this be a test that it just doesn't break? Like just calling model.predict and model.predict_proba? I looked at test_sklearn but there they're testing equality in the predictions from the sklearn api vs the training api.

I'd prefer to keep the data stuff as-is for now, please.

Are you sure? I really want to do it haha, it removes like 90 lines in the file.

@jameslamb
Collaborator Author

jameslamb commented Feb 25, 2021

Should this be a test that it just doesn't break? Like just calling model.predict and model.predict_proba? I looked at test_sklearn but there they're testing equality in the predictions from the sklearn api vs the training api.

No, the test should check that you're getting the raw scores.

Something similar to this:

import lightgbm as lgb
import numpy as np

# X and y are assumed to be the Dask collections created elsewhere in the test
reg = lgb.DaskLGBMRegressor(
    n_estimators=1,
    num_leaves=2,
    min_sum_hessian=0
)
reg.fit(X, y)
preds = reg.predict(X, raw_score=True).compute()

# with a single 2-leaf tree, rows 1 and 2 of trees_to_dataframe() are the leaf nodes
tree_df = reg.booster_.trees_to_dataframe()
leaf_values = tree_df['value']
leaf_weights = tree_df['weight']
sum_weights = leaf_weights[1] + leaf_weights[2]

expected_mean = leaf_values[1] * (leaf_weights[1] / sum_weights) + leaf_values[2] * (leaf_weights[2] / sum_weights)
assert np.mean(preds) == expected_mean

UPDATE: forgot to add min_sum_hessian when I first posted this.

Are you sure? I really want to do it haha, it removes like 90 lines in the file.

Yes, please don't touch it. It isn't important that each of the calls to _create_data() uses identical parameters, so condensing them to one function would be adding a layer of indirection for a consistency benefit that isn't relevant to the correctness or speed of the tests.

@jmoralez
Collaborator

How about for regression:

model = lgb.DaskLGBMRegressor(objective='poisson')
model.fit(dX, dy)
regular_predictions = model.predict(dX).compute()
raw_predictions = model.predict(dX, raw_score=True).compute()
np.testing.assert_equal(np.exp(raw_predictions), regular_predictions)

I believe for classification it would be applying the sigmoid. For ranking I'll have to check.
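
For binary classification, that check might look something like the sketch below (not tested here; model and dX are assumed to be the fitted DaskLGBMClassifier and its Dask input from the test setup):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# raw scores vs. the predicted probability of the positive class (column 1)
raw_predictions = model.predict(dX, raw_score=True).compute()
proba_predictions = model.predict_proba(dX).compute()
np.testing.assert_allclose(sigmoid(raw_predictions), proba_predictions[:, 1])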

@jameslamb
Collaborator Author

If you follow my suggestion, you don't have to think about what the objective is. Please try that before exploring other options that require using non-default objectives.

@jmoralez
Collaborator

What's the preferred way to compare scalar floats? I have an array like [-0.2, 0.2, -0.2, 0.2, ...] whose mean should be zero but comes out as something on the order of 1e-17, and both np.testing.assert_equal and np.testing.assert_allclose fail. Should I use np.testing.assert_almost_equal?

@jameslamb
Collaborator Author

abs(x) < 1e-14 or something would be fine.
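
Applied to the test above, that suggestion would be something like the following (raw_predictions being the computed array from the test; the exact tolerance is a judgment call):

import numpy as np

# the mean of the raw scores should be ~0, allowing for floating-point noise
assert abs(np.mean(raw_predictions)) < 1e-14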
