From b6deb9a857edc71da5f5f17a295043b83606da04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Mon, 27 Jun 2022 18:29:35 -0500 Subject: [PATCH] [python-package] allow custom weighing in fobj for scikit-learn API (closes #5027) (#5211) * allow custom weighing in sklearn api * add suggestions from review Co-authored-by: Nikita Titov --- python-package/lightgbm/sklearn.py | 34 +++++++++------- tests/python_package_test/test_engine.py | 25 +++++++++--- tests/python_package_test/test_sklearn.py | 48 +++++++++++++++++++++-- tests/python_package_test/utils.py | 6 ++- 4 files changed, 89 insertions(+), 24 deletions(-) diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 682446455838..12df84d95d67 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -25,6 +25,10 @@ [np.ndarray, np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray] ], + Callable[ + [np.ndarray, np.ndarray, np.ndarray, np.ndarray], + Tuple[np.ndarray, np.ndarray] + ], ] _LGBM_ScikitCustomEvalFunction = Union[ Callable[ @@ -54,7 +58,10 @@ def __init__(self, func: _LGBM_ScikitCustomObjectiveFunction): Parameters ---------- func : callable - Expects a callable with signature ``func(y_true, y_pred)`` or ``func(y_true, y_pred, group)`` + Expects a callable with following signatures: + ``func(y_true, y_pred)``, + ``func(y_true, y_pred, weight)`` + or ``func(y_true, y_pred, weight, group)`` and returns (grad, hess): y_true : numpy 1-D array of shape = [n_samples] @@ -63,6 +70,8 @@ def __init__(self, func: _LGBM_ScikitCustomObjectiveFunction): The predicted values. Predicted values are returned before any transformation, e.g. they are raw margin instead of probability of positive class for binary task. + weight : numpy 1-D array of shape = [n_samples] + The weight of samples. Weights should be non-negative. group : numpy 1-D array Group/query data. Only used in the learning-to-rank task. @@ -107,19 +116,11 @@ def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[np.ndarray, np. if argc == 2: grad, hess = self.func(labels, preds) elif argc == 3: - grad, hess = self.func(labels, preds, dataset.get_group()) + grad, hess = self.func(labels, preds, dataset.get_weight()) + elif argc == 4: + grad, hess = self.func(labels, preds, dataset.get_weight(), dataset.get_group()) else: - raise TypeError(f"Self-defined objective function should have 2 or 3 arguments, got {argc}") - """weighted for objective""" - weight = dataset.get_weight() - if weight is not None: - if grad.ndim == 2: # multi-class - num_data = grad.shape[0] - if weight.size != num_data: - raise ValueError("grad and hess should be of shape [n_samples, n_classes]") - weight = weight.reshape(num_data, 1) - grad *= weight - hess *= weight + raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}") return grad, hess @@ -456,8 +457,9 @@ def __init__( ---- A custom objective function can be provided for the ``objective`` parameter. In this case, it should have the signature - ``objective(y_true, y_pred) -> grad, hess`` or - ``objective(y_true, y_pred, group) -> grad, hess``: + ``objective(y_true, y_pred) -> grad, hess``, + ``objective(y_true, y_pred, weight) -> grad, hess`` + or ``objective(y_true, y_pred, weight, group) -> grad, hess``: y_true : numpy 1-D array of shape = [n_samples] The target values. @@ -465,6 +467,8 @@ def __init__( The predicted values. Predicted values are returned before any transformation, e.g. they are raw margin instead of probability of positive class for binary task. + weight : numpy 1-D array of shape = [n_samples] + The weight of samples. Weights should be non-negative. group : numpy 1-D array Group/query data. Only used in the learning-to-rank task. diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index e53bb6b0e594..017e46788c55 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -2433,14 +2433,20 @@ def test_default_objective_and_metric(): assert len(evals_result['valid_0']['l2']) == 5 -def test_multiclass_custom_objective(): +@pytest.mark.parametrize('use_weight', [True, False]) +def test_multiclass_custom_objective(use_weight): def custom_obj(y_pred, ds): y_true = ds.get_label() - return sklearn_multiclass_custom_objective(y_true, y_pred) + weight = ds.get_weight() + grad, hess = sklearn_multiclass_custom_objective(y_true, y_pred, weight) + return grad, hess centers = [[-4, -4], [4, 4], [-4, 4]] X, y = make_blobs(n_samples=1_000, centers=centers, random_state=42) + weight = np.full_like(y, 2) ds = lgb.Dataset(X, y) + if use_weight: + ds.set_weight(weight) params = {'objective': 'multiclass', 'num_class': 3, 'num_leaves': 7} builtin_obj_bst = lgb.train(params, ds, num_boost_round=10) builtin_obj_preds = builtin_obj_bst.predict(X) @@ -2452,16 +2458,25 @@ def custom_obj(y_pred, ds): np.testing.assert_allclose(builtin_obj_preds, custom_obj_preds, rtol=0.01) -def test_multiclass_custom_eval(): +@pytest.mark.parametrize('use_weight', [True, False]) +def test_multiclass_custom_eval(use_weight): def custom_eval(y_pred, ds): y_true = ds.get_label() - return 'custom_logloss', log_loss(y_true, y_pred), False + weight = ds.get_weight() # weight is None when not set + loss = log_loss(y_true, y_pred, sample_weight=weight) + return 'custom_logloss', loss, False centers = [[-4, -4], [4, 4], [-4, 4]] X, y = make_blobs(n_samples=1_000, centers=centers, random_state=42) - X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=0) + weight = np.full_like(y, 2) + X_train, X_valid, y_train, y_valid, weight_train, weight_valid = train_test_split( + X, y, weight, test_size=0.2, random_state=0 + ) train_ds = lgb.Dataset(X_train, y_train) valid_ds = lgb.Dataset(X_valid, y_valid, reference=train_ds) + if use_weight: + train_ds.set_weight(weight_train) + valid_ds.set_weight(weight_valid) params = {'objective': 'multiclass', 'num_class': 3, 'num_leaves': 7} eval_result = {} bst = lgb.train( diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index 2fdd31c23be1..4fe65cd8645a 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -2,6 +2,7 @@ import itertools import math import re +from functools import partial from os import getenv from pathlib import Path @@ -1285,16 +1286,18 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task np.testing.assert_array_equal(preds_1d, preds_2d) -def test_multiclass_custom_objective(): +@pytest.mark.parametrize('use_weight', [True, False]) +def test_multiclass_custom_objective(use_weight): centers = [[-4, -4], [4, 4], [-4, 4]] X, y = make_blobs(n_samples=1_000, centers=centers, random_state=42) + weight = np.full_like(y, 2) if use_weight else None params = {'n_estimators': 10, 'num_leaves': 7} builtin_obj_model = lgb.LGBMClassifier(**params) - builtin_obj_model.fit(X, y) + builtin_obj_model.fit(X, y, sample_weight=weight) builtin_obj_preds = builtin_obj_model.predict_proba(X) custom_obj_model = lgb.LGBMClassifier(objective=sklearn_multiclass_custom_objective, **params) - custom_obj_model.fit(X, y) + custom_obj_model.fit(X, y, sample_weight=weight) custom_obj_preds = softmax(custom_obj_model.predict(X, raw_score=True)) np.testing.assert_allclose(builtin_obj_preds, custom_obj_preds, rtol=0.01) @@ -1302,6 +1305,45 @@ def test_multiclass_custom_objective(): assert callable(custom_obj_model.objective_) +@pytest.mark.parametrize('use_weight', [True, False]) +def test_multiclass_custom_eval(use_weight): + def custom_eval(y_true, y_pred, weight): + loss = log_loss(y_true, y_pred, sample_weight=weight) + return 'custom_logloss', loss, False + + centers = [[-4, -4], [4, 4], [-4, 4]] + X, y = make_blobs(n_samples=1_000, centers=centers, random_state=42) + train_test_split_func = partial(train_test_split, test_size=0.2, random_state=0) + X_train, X_valid, y_train, y_valid = train_test_split_func(X, y) + if use_weight: + weight = np.full_like(y, 2) + weight_train, weight_valid = train_test_split_func(weight) + else: + weight_train = None + weight_valid = None + params = {'objective': 'multiclass', 'num_class': 3, 'num_leaves': 7} + model = lgb.LGBMClassifier(**params) + model.fit( + X_train, + y_train, + sample_weight=weight_train, + eval_set=[(X_train, y_train), (X_valid, y_valid)], + eval_names=['train', 'valid'], + eval_sample_weight=[weight_train, weight_valid], + eval_metric=custom_eval, + ) + eval_result = model.evals_result_ + train_ds = (X_train, y_train, weight_train) + valid_ds = (X_valid, y_valid, weight_valid) + for key, (X, y_true, weight) in zip(['train', 'valid'], [train_ds, valid_ds]): + np.testing.assert_allclose( + eval_result[key]['multi_logloss'], eval_result[key]['custom_logloss'] + ) + y_pred = model.predict_proba(X) + _, metric_value, _ = custom_eval(y_true, y_pred, weight) + np.testing.assert_allclose(metric_value, eval_result[key]['custom_logloss'][-1]) + + def test_negative_n_jobs(tmp_path): n_threads = joblib.cpu_count() if n_threads <= 1: diff --git a/tests/python_package_test/utils.py b/tests/python_package_test/utils.py index 472343091566..fc142ede9fe7 100644 --- a/tests/python_package_test/utils.py +++ b/tests/python_package_test/utils.py @@ -140,7 +140,7 @@ def logistic_sigmoid(x): return 1.0 / (1.0 + np.exp(-x)) -def sklearn_multiclass_custom_objective(y_true, y_pred): +def sklearn_multiclass_custom_objective(y_true, y_pred, weight=None): num_rows, num_class = y_pred.shape prob = softmax(y_pred) grad_update = np.zeros_like(prob) @@ -148,6 +148,10 @@ def sklearn_multiclass_custom_objective(y_true, y_pred): grad = prob + grad_update factor = num_class / (num_class - 1) hess = factor * prob * (1 - prob) + if weight is not None: + weight2d = weight.reshape(-1, 1) + grad *= weight2d + hess *= weight2d return grad, hess