
BUG: Improve sklearn compliance #13065

Merged
12 commits merged on Jan 22, 2025
1 change: 1 addition & 0 deletions doc/changes/devel/13065.bugfix.rst
@@ -0,0 +1 @@
Improved sklearn class compatibility and compliance, which resulted in some parameters of classes like :class:`mne.decoding.FilterEstimator` having an underscore appended to their name (e.g., ``picks`` passed to the initializer is set as ``est.picks_`` during the ``fit`` phase so the original can be preserved in ``est.picks``) by `Eric Larson`_.
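For context, a minimal sketch (a toy estimator, not MNE code) of the scikit-learn convention this entry describes: ``__init__`` stores parameters verbatim, and anything resolved during ``fit`` gets a trailing underscore, so the original argument survives in the unsuffixed attribute.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class PickChannels(TransformerMixin, BaseEstimator):
    """Toy transformer illustrating the ``picks`` / ``picks_`` convention."""

    def __init__(self, picks=None):
        # __init__ only stores parameters verbatim (required for clone/get_params)
        self.picks = picks

    def fit(self, X, y=None):
        # the value resolved at fit time gets a trailing underscore
        self.picks_ = list(range(X.shape[1])) if self.picks is None else self.picks
        return self

    def transform(self, X):
        return X[:, self.picks_]


est = PickChannels().fit(np.zeros((3, 4)))
print(est.picks, est.picks_)  # None [0, 1, 2, 3]
```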
2 changes: 1 addition & 1 deletion examples/decoding/linear_model_patterns.py
@@ -79,7 +79,7 @@

# Extract and plot spatial filters and spatial patterns
for name, coef in (("patterns", model.patterns_), ("filters", model.filters_)):
# We fitted the linear model onto Z-scored data. To make the filters
# We fit the linear model on Z-scored data. To make the filters
# interpretable, we must reverse this normalization step
coef = scaler.inverse_transform([coef])[0]

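A self-contained sketch of the un-scaling step the comment above refers to, using a toy logistic regression in place of the example's model and assuming a ``StandardScaler`` did the Z-scoring:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
y = (X[:, 0] + 0.1 * rng.normal(size=100) > 0).astype(int)

scaler = StandardScaler()
Xs = scaler.fit_transform(X)           # Z-score the features
clf = LogisticRegression().fit(Xs, y)

# the coefficients live in Z-scored space; map them back to original units,
# mirroring scaler.inverse_transform([coef])[0] in the example above
coef = scaler.inverse_transform(clf.coef_)[0]
print(coef.shape)  # (5,)
```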
4 changes: 2 additions & 2 deletions mne/cov.py
@@ -1226,7 +1226,7 @@ def _compute_rank_raw_array(
from .io import RawArray

return _compute_rank(
RawArray(data, info, copy=None, verbose=_verbose_safe_false()),
RawArray(data, info, copy="auto", verbose=_verbose_safe_false()),
rank,
scalings,
info,
@@ -1405,7 +1405,7 @@ def _compute_covariance_auto(
# project back
cov = np.dot(eigvec.T, np.dot(cov, eigvec))
# undo bias
cov *= data.shape[0] / (data.shape[0] - 1)
cov *= data.shape[0] / max(data.shape[0] - 1, 1)
# undo scaling
_undo_scaling_cov(cov, picks_list, scalings)
method_ = method[ei]
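The second hunk guards the unbiased-covariance rescaling against a single-sample input; a short standalone sketch of the arithmetic (not MNE code):

```python
import numpy as np


def unbias_cov(cov, n_samples):
    # rescale a biased (1/N) covariance estimate to the unbiased (1/(N-1)) form;
    # max(..., 1) keeps the factor finite when only one sample is available
    return cov * (n_samples / max(n_samples - 1, 1))


cov = np.eye(3)
print(unbias_cov(cov, 10))  # scaled by 10/9
print(unbias_cov(cov, 1))   # returned unchanged instead of dividing by zero
```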
11 changes: 8 additions & 3 deletions mne/decoding/base.py
@@ -19,7 +19,7 @@
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import check_scoring
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
from sklearn.utils import check_array, indexable
from sklearn.utils import check_array, check_X_y, indexable

from ..parallel import parallel_func
from ..utils import _pl, logger, verbose, warn
@@ -76,9 +76,9 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
)

def __init__(self, model=None):
# TODO: We need to set this to get our tag checking to work properly
if model is None:
model = LogisticRegression(solver="liblinear")

self.model = model

def __sklearn_tags__(self):
@@ -122,7 +122,11 @@ def fit(self, X, y, **fit_params):
self : instance of LinearModel
Returns the modified instance.
"""
X = check_array(X, input_name="X")
if y is None:
X = check_array(X)
else:
X, y = check_X_y(X, y)
self.n_features_in_ = X.shape[1]
if y is not None:
y = check_array(y, dtype=None, ensure_2d=False, input_name="y")
if y.ndim > 2:
@@ -133,6 +137,7 @@

# fit the Model
self.model.fit(X, y, **fit_params)
self.model_ = self.model # for better sklearn compat

# Computes patterns using Haufe's trick: A = Cov_X . W . Precision_Y

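The comment about Haufe's trick refers to Haufe et al. (2014): forward-model patterns are recovered from backward-model filters via the data covariance, A = Cov_X · W · Precision_ŝ. A minimal numerical sketch for a single filter (illustrative only, not the exact MNE implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 6))   # n_samples x n_features
w = rng.normal(size=6)          # backward-model (filter) weights
s = X @ w                       # estimated latent source

# A = Cov_X @ W @ Precision_s; with a single source, the precision is 1 / var(s)
cov_X = np.cov(X, rowvar=False)
pattern = cov_X @ w / s.var()
print(pattern.shape)  # (6,)
```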
99 changes: 43 additions & 56 deletions mne/decoding/csp.py
@@ -6,7 +6,8 @@

import numpy as np
from scipy.linalg import eigh
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_is_fitted

from .._fiff.meas_info import create_info
from ..cov import _compute_rank_raw_array, _regularized_covariance, _smart_eigh
@@ -19,10 +20,11 @@
fill_doc,
pinv,
)
from .transformer import MNETransformerMixin


@fill_doc
class CSP(TransformerMixin, BaseEstimator):
class CSP(MNETransformerMixin, BaseEstimator):
"""M/EEG signal decomposition using the Common Spatial Patterns (CSP).

This class can be used as a supervised decomposition to estimate spatial
@@ -112,49 +114,44 @@ def __init__(
component_order="mutual_info",
):
# Init default CSP
if not isinstance(n_components, int):
raise ValueError("n_components must be an integer.")
self.n_components = n_components
self.rank = rank
self.reg = reg

# Init default cov_est
if not (cov_est == "concat" or cov_est == "epoch"):
raise ValueError("unknown covariance estimation method")
self.cov_est = cov_est

# Init default transform_into
self.transform_into = _check_option(
"transform_into", transform_into, ["average_power", "csp_space"]
)

# Init default log
if transform_into == "average_power":
if log is not None and not isinstance(log, bool):
raise ValueError(
'log must be a boolean if transform_into == "average_power".'
)
else:
if log is not None:
raise ValueError('log must be a None if transform_into == "csp_space".')
self.transform_into = transform_into
self.log = log

_validate_type(norm_trace, bool, "norm_trace")
self.norm_trace = norm_trace
self.cov_method_params = cov_method_params
self.component_order = _check_option(
"component_order", component_order, ("mutual_info", "alternate")
self.component_order = component_order

def _validate_params(self, *, y):
_validate_type(self.n_components, int, "n_components")
if hasattr(self, "cov_est"):
_validate_type(self.cov_est, str, "cov_est")
_check_option("cov_est", self.cov_est, ("concat", "epoch"))
if hasattr(self, "norm_trace"):
_validate_type(self.norm_trace, bool, "norm_trace")
_check_option(
"transform_into", self.transform_into, ["average_power", "csp_space"]
)

def _check_Xy(self, X, y=None):
"""Check input data."""
if not isinstance(X, np.ndarray):
raise ValueError(f"X should be of type ndarray (got {type(X)}).")
if y is not None:
if len(X) != len(y) or len(y) < 1:
raise ValueError("X and y must have the same length.")
if X.ndim < 3:
raise ValueError("X must have at least 3 dimensions.")
if self.transform_into == "average_power":
_validate_type(
self.log,
(bool, None),
"log",
extra="when transform_into is 'average_power'",
)
else:
_validate_type(
self.log, None, "log", extra="when transform_into is 'csp_space'"
)
_check_option(
"component_order", self.component_order, ("mutual_info", "alternate")
)
self.classes_ = np.unique(y)
n_classes = len(self.classes_)
if n_classes < 2:
raise ValueError(f"n_classes must be >= 2, but got {n_classes} class")

def fit(self, X, y):
"""Estimate the CSP decomposition on epochs.
@@ -171,12 +168,9 @@ def fit(self, X, y):
self : instance of CSP
Returns the modified instance.
"""
self._check_Xy(X, y)

self._classes = np.unique(y)
n_classes = len(self._classes)
if n_classes < 2:
raise ValueError("n_classes must be >= 2.")
X, y = self._check_data(X, y=y, fit=True, return_y=True)
self._validate_params(y=y)
n_classes = len(self.classes_)
if n_classes > 2 and self.component_order == "alternate":
raise ValueError(
"component_order='alternate' requires two classes, but data contains "
@@ -225,13 +219,8 @@ def transform(self, X):
If self.transform_into == 'csp_space' then returns the data in CSP
space and shape is (n_epochs, n_components, n_times).
"""
if not isinstance(X, np.ndarray):
raise ValueError(f"X should be of type ndarray (got {type(X)}).")
if self.filters_ is None:
raise RuntimeError(
"No filters available. Please first fit CSP decomposition."
)

check_is_fitted(self, "filters_")
X = self._check_data(X)
pick_filters = self.filters_[: self.n_components]
X = np.asarray([np.dot(pick_filters, epoch) for epoch in X])

@@ -577,7 +566,7 @@ def _compute_covariance_matrices(self, X, y):

covs = []
sample_weights = []
for ci, this_class in enumerate(self._classes):
for ci, this_class in enumerate(self.classes_):
cov, weight = cov_estimator(
X[y == this_class],
cov_kind=f"class={this_class}",
@@ -689,7 +678,7 @@ def _normalize_eigenvectors(self, eigen_vectors, covs, sample_weights):
def _order_components(
self, covs, sample_weights, eigen_vectors, eigen_values, component_order
):
n_classes = len(self._classes)
n_classes = len(self.classes_)
if component_order == "mutual_info" and n_classes > 2:
mutual_info = self._compute_mutual_info(covs, sample_weights, eigen_vectors)
ix = np.argsort(mutual_info)[::-1]
@@ -889,10 +878,8 @@ def fit(self, X, y):
self : instance of SPoC
Returns the modified instance.
"""
self._check_Xy(X, y)

if len(np.unique(y)) < 2:
raise ValueError("y must have at least two distinct values.")
X, y = self._check_data(X, y=y, fit=True, return_y=True)
self._validate_params(y=y)

# The following code is directly copied from pyRiemann

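The CSP changes above follow the usual scikit-learn compliance pattern: ``__init__`` only stores parameters, validation moves into a helper called from ``fit``, and ``transform`` gates on fitted state with ``check_is_fitted``. A toy sketch of that pattern (hypothetical ``ToyDecomposer``, not MNE code):

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted


class ToyDecomposer(TransformerMixin, BaseEstimator):
    def __init__(self, n_components=4):
        self.n_components = n_components  # no validation here

    def _validate_params(self):
        if not isinstance(self.n_components, int) or self.n_components < 1:
            raise ValueError("n_components must be a positive integer.")

    def fit(self, X, y=None):
        self._validate_params()            # validate at fit time, not in __init__
        X = np.asarray(X)
        # stand-in for a real decomposition: principal directions of the data
        _, _, vt = np.linalg.svd(X - X.mean(0), full_matrices=False)
        self.filters_ = vt[: self.n_components]
        return self

    def transform(self, X):
        check_is_fitted(self, "filters_")  # raises NotFittedError before fit
        return np.asarray(X) @ self.filters_.T
```

Deferring validation to ``fit`` is what lets ``clone`` and grid-search utilities construct estimators with arbitrary parameter values and only fail when the estimator is actually fit.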
23 changes: 18 additions & 5 deletions mne/decoding/ems.py
@@ -5,15 +5,16 @@
from collections import Counter

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.base import BaseEstimator

from .._fiff.pick import _picks_to_idx, pick_info, pick_types
from ..parallel import parallel_func
from ..utils import logger, verbose
from .base import _set_cv
from .transformer import MNETransformerMixin


class EMS(TransformerMixin, BaseEstimator):
class EMS(MNETransformerMixin, BaseEstimator):
"""Transformer to compute event-matched spatial filters.

This version of EMS :footcite:`SchurgerEtAl2013` operates on the entire
@@ -37,6 +38,16 @@ class EMS(TransformerMixin, BaseEstimator):
.. footbibliography::
"""

def __sklearn_tags__(self):
"""Return sklearn tags."""
from sklearn.utils import ClassifierTags

tags = super().__sklearn_tags__()
if tags.classifier_tags is None:
tags.classifier_tags = ClassifierTags()
tags.classifier_tags.multi_class = False
return tags

def __repr__(self): # noqa: D105
if hasattr(self, "filters_"):
return (
@@ -64,11 +75,12 @@ def fit(self, X, y):
self : instance of EMS
Returns self.
"""
classes = np.unique(y)
if len(classes) != 2:
X, y = self._check_data(X, y=y, fit=True, return_y=True)
classes, y = np.unique(y, return_inverse=True)
if len(classes) > 2:
raise ValueError("EMS only works for binary classification.")
self.classes_ = classes
filters = X[y == classes[0]].mean(0) - X[y == classes[1]].mean(0)
filters = X[y == 0].mean(0) - X[y == 1].mean(0)
filters /= np.linalg.norm(filters, axis=0)[None, :]
self.filters_ = filters
return self
@@ -86,6 +98,7 @@ def transform(self, X):
X : array, shape (n_epochs, n_times)
The input data transformed by the spatial filters.
"""
X = self._check_data(X)
Xt = np.sum(X * self.filters_, axis=1)
return Xt

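The ``__sklearn_tags__`` override added to ``EMS`` uses the dataclass-based tags API (assumed here to be scikit-learn 1.6+). A toy classifier showing how such a tag is declared and read back:

```python
from sklearn.base import BaseEstimator, ClassifierMixin


class BinaryOnlyClassifier(ClassifierMixin, BaseEstimator):
    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()          # ClassifierMixin supplies classifier_tags
        tags.classifier_tags.multi_class = False   # advertise binary-only support
        return tags

    def fit(self, X, y):
        return self

    def predict(self, X):
        return [0] * len(X)


tags = BinaryOnlyClassifier().__sklearn_tags__()
print(tags.classifier_tags.multi_class)  # False
```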