Skip to content

Commit

Permalink
[ENH] online update capability for probabilistic regressors (#462)
Browse files Browse the repository at this point in the history
Adds framework support for online update capability for probabilistic
regressors, and a simple composite strategy that refits on all data, for
testing the framework. Closes #463

Contains:

* extension of the regressor and survival base class with a potential
`update` / `_update` method for batch updates
* addition of a tag `capability:online` for respective estimators
* addition of a composite `OnlineRefit` that adds the
`capability:online` tag and refits the regressor upon all data seen so
far. This is a separate estimator to avoid that all estimators remember
(and clutter `self`) with the data
* a similar composite `OnlineDontRefit` that turns off online capability
* a specific test case for online updates, in `TestAllRegressors`
  • Loading branch information
fkiraly authored Sep 27, 2024
1 parent 8782435 commit ba2aae5
Show file tree
Hide file tree
Showing 10 changed files with 520 additions and 1 deletion.
12 changes: 12 additions & 0 deletions docs/source/api_reference/regression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ Model selection and tuning

evaluate

Online learning
---------------

.. currentmodule:: skpro.regression.online

.. autosummary::
:toctree: auto_generated/
:template: class.rst

OnlineRefit
OnlineDontRefit

Reduction - adding ``predict_proba``
------------------------------------

Expand Down
6 changes: 6 additions & 0 deletions skpro/registry/_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@
"bool",
"whether estimator supports missing values",
),
(
"capability:online",
"regressor_proba",
"bool",
"whether estimator supports online updates via update",
),
(
"X_inner_mtype",
"regressor_proba",
Expand Down
69 changes: 69 additions & 0 deletions skpro/regression/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class BaseProbaRegressor(BaseEstimator):
"capability:survival": False,
"capability:multioutput": False,
"capability:missing": True,
"capability:online": False,
"X_inner_mtype": "pd_DataFrame_Table",
"y_inner_mtype": "pd_DataFrame_Table",
"C_inner_mtype": "pd_DataFrame_Table",
Expand Down Expand Up @@ -136,6 +137,74 @@ def _fit(self, X, y, C=None):
"""
raise NotImplementedError

def update(self, X, y, C=None):
"""Update regressor with a new batch of training data.
Only estimators with the ``capability:online`` tag (value ``True``)
provide this method, otherwise the method ignores the call and
discards the data passed.
State required:
Requires state to be "fitted".
Writes to self:
Updates fitted model attributes ending in "_".
Parameters
----------
X : pandas DataFrame
feature instances to fit regressor to
y : pd.DataFrame, must be same length as X
labels to fit regressor to
C : ignored, optional (default=None)
censoring information for survival analysis
All probabilistic regressors assume data to be uncensored
Returns
-------
self : reference to self
"""
capa_online = self.get_tag("capability:online")
capa_surv = self.get_tag("capability:survival")

if not capa_online:
return self

check_ret = self._check_X_y(X, y, C, return_metadata=True)

# get inner X, y, C
X_inner = check_ret["X_inner"]
y_inner = check_ret["y_inner"]
if capa_surv:
C_inner = check_ret["C_inner"]

if not capa_surv:
return self._update(X_inner, y_inner)
else:
return self._update(X_inner, y_inner, C=C_inner)

def _update(self, X, y, C=None):
"""Update regressor with a new batch of training data.
State required:
Requires state to be "fitted".
Writes to self:
Updates fitted model attributes ending in "_".
Parameters
----------
X : pandas DataFrame
feature instances to fit regressor to
y : pandas DataFrame, must be same length as X
labels to fit regressor to
Returns
-------
self : reference to self
"""
raise NotImplementedError

def predict(self, X):
"""Predict labels for data from features.
Expand Down
32 changes: 32 additions & 0 deletions skpro/regression/base/_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,38 @@ def _fit(self, X, y, C=None):
estimator.fit(X=X, y=y, C=C)
return self

def _update(self, X, y, C=None):
"""Update regressor with a new batch of training data.
State required:
Requires state to be "fitted" = self.is_fitted=True
Writes to self:
Updates fitted model attributes ending in "_".
Parameters
----------
X : pandas DataFrame
feature instances to fit regressor to
y : pd.DataFrame, must be same length as X
labels to fit regressor to
C : pd.DataFrame, optional (default=None)
censoring information for survival analysis,
should have same column name as y, same length as X and y
should have entries 0 and 1 (float or int)
0 = uncensored, 1 = (right) censored
if None, all observations are assumed to be uncensored
Can be passed to any probabilistic regressor,
but is ignored if capability:survival tag is False.
Returns
-------
self : reference to self
"""
estimator = self._get_delegate()
estimator.update(X=X, y=y, C=C)
return self

def _predict(self, X):
"""Predict labels for data from features.
Expand Down
38 changes: 37 additions & 1 deletion skpro/regression/compose/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,11 @@ def __init__(self, steps):

super().__init__()

tags_to_clone = ["capability:multioutput", "capability:survival"]
tags_to_clone = [
"capability:multioutput",
"capability:survival",
"capability:online",
]
self.clone_tags(self.regressor_, tags_to_clone)

@property
Expand Down Expand Up @@ -427,6 +431,38 @@ def _fit(self, X, y, C=None):

return self

def _update(self, X, y, C=None):
"""Update regressor with a new batch of training data.
State required:
Requires state to be "fitted" = self.is_fitted=True
Writes to self:
Updates fitted model attributes ending in "_".
Parameters
----------
X : pandas DataFrame
feature instances to fit regressor to
y : pd.DataFrame, must be same length as X
labels to fit regressor to
C : pd.DataFrame, optional (default=None)
censoring information for survival analysis,
should have same column name as y, same length as X and y
should have entries 0 and 1 (float or int)
0 = uncensored, 1 = (right) censored
if None, all observations are assumed to be uncensored
Can be passed to any probabilistic regressor,
but is ignored if capability:survival tag is False.
Returns
-------
self : reference to self
"""
X = self._transform(X)
self.regressor_.update(X=X, y=y, C=C)
return self

def _predict(self, X):
"""Predict labels for data from features.
Expand Down
7 changes: 7 additions & 0 deletions skpro/regression/online/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Meta-algorithms to build online regression models."""
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file)

from skpro.regression.online._dont_refit import OnlineDontRefit
from skpro.regression.online._refit import OnlineRefit

__all__ = ["OnlineDontRefit", "OnlineRefit"]
106 changes: 106 additions & 0 deletions skpro/regression/online/_dont_refit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Meta-strategy for online learning: turn off online update."""

__author__ = ["fkiraly"]
__all__ = ["OnlineDontRefit"]

from skpro.regression.base import _DelegatedProbaRegressor


class OnlineDontRefit(_DelegatedProbaRegressor):
"""Simple online regression strategy, turns off any refitting.
In ``fit``, behaves like the wrapped regressor.
In ``update``, does nothing, overriding any other logic.
This strategy is useful when the wrapped regressor is already an online regressor,
to create a "no-op" online regressor for comparison.
Parameters
----------
estimator : skpro regressor, descendant of BaseProbaRegressor
regressor to be update-refitted on all data, blueprint
Attributes
----------
estimator_ : skpro regressor, descendant of BaseProbaRegressor
clone of the regressor passed in the constructor, fitted on all data
"""

_tags = {"capability:online": False}

def __init__(self, estimator):
self.estimator = estimator

super().__init__()

tags_to_clone = [
"capability:missing",
"capability:survival",
]
self.clone_tags(estimator, tags_to_clone)

self.estimator_ = self.estimator.clone()

def _update(self, X, y, C=None):
"""Update regressor with new batch of training data.
State required:
Requires state to be "fitted".
Writes to self:
Updates fitted model attributes ending in "_".
Parameters
----------
X : pandas DataFrame
feature instances to fit regressor to
y : pandas DataFrame, must be same length as X
labels to fit regressor to
C : pd.DataFrame, optional (default=None)
censoring information for survival analysis,
should have same column name as y, same length as X and y
should have entries 0 and 1 (float or int)
0 = uncensored, 1 = (right) censored
if None, all observations are assumed to be uncensored
Can be passed to any probabilistic regressor,
but is ignored if capability:survival tag is False.
Returns
-------
self : reference to self
"""
return self

@classmethod
def get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator.
Parameters
----------
parameter_set : str, default="default"
Name of the set of test parameters to return, for use in tests. If no
special parameters are defined for a value, will return `"default"` set.
Returns
-------
params : dict or list of dict, default = {}
Parameters to create testing instances of the class
Each dict are parameters to construct an "interesting" test instance, i.e.,
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
`create_test_instance` uses the first (or only) dictionary in `params`
"""
from sklearn.linear_model import LinearRegression

from skpro.regression.residual import ResidualDouble
from skpro.survival.coxph import CoxPH
from skpro.utils.validation._dependencies import _check_estimator_deps

regressor = ResidualDouble(LinearRegression())

params = [{"estimator": regressor}]

if _check_estimator_deps(CoxPH, severity="none"):
coxph = CoxPH()
params.append({"estimator": coxph})

return params
Loading

0 comments on commit ba2aae5

Please sign in to comment.