Skip to content

fix: Subclass sklearn.model_selection._RepeatedSplits and BaseShuffleSplit from BaseCrossValidator #349

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions stubs/sklearn/calibration.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ from .base import (
)
from .isotonic import IsotonicRegression
from .model_selection import BaseCrossValidator, check_cv as check_cv, cross_val_predict as cross_val_predict
from .model_selection._split import BaseShuffleSplit
from .preprocessing import LabelEncoder as LabelEncoder, label_binarize as label_binarize
from .svm import LinearSVC as LinearSVC
from .utils import check_matplotlib_support as check_matplotlib_support, column_or_1d as column_or_1d, indexable as indexable
Expand Down Expand Up @@ -51,7 +50,7 @@ class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator)
estimator: None | BaseEstimator = None,
*,
method: Literal["sigmoid", "isotonic"] = "sigmoid",
cv: int | BaseCrossValidator | Iterable | None | str | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None | str = None,
n_jobs: None | Int = None,
ensemble: bool = True,
base_estimator: str | BaseEstimator = "deprecated",
Expand Down
3 changes: 1 addition & 2 deletions stubs/sklearn/covariance/_graph_lasso.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ from .._typing import ArrayLike, Float, Int, MatrixLike
from ..exceptions import ConvergenceWarning as ConvergenceWarning
from ..linear_model import lars_path_gram as lars_path_gram
from ..model_selection import BaseCrossValidator, check_cv as check_cv, cross_val_score as cross_val_score
from ..model_selection._split import BaseShuffleSplit
from ..utils._param_validation import Interval as Interval, StrOptions as StrOptions
from ..utils.parallel import Parallel as Parallel, delayed as delayed
from ..utils.validation import check_random_state as check_random_state, check_scalar as check_scalar
Expand Down Expand Up @@ -117,7 +116,7 @@ class GraphicalLassoCV(BaseGraphicalLasso):
*,
alphas: ArrayLike | int = 4,
n_refinements: Int = 4,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
tol: Float = 1e-4,
enet_tol: Float = 1e-4,
max_iter: Int = 100,
Expand Down
5 changes: 2 additions & 3 deletions stubs/sklearn/ensemble/_stacking.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ from ..exceptions import NotFittedError as NotFittedError
from ..linear_model._logistic import LogisticRegression
from ..linear_model._ridge import RidgeCV
from ..model_selection import BaseCrossValidator, check_cv as check_cv, cross_val_predict as cross_val_predict
from ..model_selection._split import BaseShuffleSplit
from ..pipeline import Pipeline
from ..preprocessing import LabelEncoder as LabelEncoder
from ..utils import Bunch
Expand Down Expand Up @@ -78,7 +77,7 @@ class StackingClassifier(ClassifierMixin, _BaseStacking):
estimators: Sequence[tuple[str, BaseEstimator]],
final_estimator: None | BaseEstimator | LogisticRegression = None,
*,
cv: int | BaseCrossValidator | Iterable | None | str | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None | str = None,
stack_method: Literal["auto", "predict_proba", "decision_function", "predict"] = "auto",
n_jobs: None | Int = None,
passthrough: bool = False,
Expand Down Expand Up @@ -108,7 +107,7 @@ class StackingRegressor(RegressorMixin, _BaseStacking):
estimators: Sequence[tuple[str, BaseEstimator]] | list[tuple[str, Pipeline]],
final_estimator: None | BaseEstimator | RidgeCV = None,
*,
cv: int | BaseCrossValidator | Iterable | None | str | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None | str = None,
n_jobs: None | Int = None,
passthrough: bool = False,
verbose: Int = 0,
Expand Down
3 changes: 1 addition & 2 deletions stubs/sklearn/feature_selection/_rfe.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ from ..base import BaseEstimator, MetaEstimatorMixin, clone as clone, is_classif
from ..linear_model._logistic import LogisticRegression
from ..metrics import check_scoring as check_scoring
from ..model_selection import BaseCrossValidator, check_cv as check_cv
from ..model_selection._split import BaseShuffleSplit
from ..utils._param_validation import HasMethods as HasMethods, Interval as Interval
from ..utils.metaestimators import available_if as available_if
from ..utils.parallel import Parallel as Parallel, delayed as delayed
Expand Down Expand Up @@ -73,7 +72,7 @@ class RFECV(RFE):
*,
step: float = 1,
min_features_to_select: Int = 1,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
scoring: None | str | Callable = None,
verbose: Int = 0,
n_jobs: None | int = None,
Expand Down
3 changes: 1 addition & 2 deletions stubs/sklearn/feature_selection/_sequential.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ from .._typing import ArrayLike, Float, Int, MatrixLike
from ..base import BaseEstimator, MetaEstimatorMixin, clone as clone
from ..metrics import get_scorer_names as get_scorer_names
from ..model_selection import BaseCrossValidator, cross_val_score as cross_val_score
from ..model_selection._split import BaseShuffleSplit
from ..utils._param_validation import HasMethods as HasMethods, Hidden as Hidden, Interval as Interval, StrOptions as StrOptions
from ..utils.validation import check_is_fitted as check_is_fitted
from ._base import SelectorMixin
Expand All @@ -34,7 +33,7 @@ class SequentialFeatureSelector(SelectorMixin, MetaEstimatorMixin, BaseEstimator
tol: None | Float = None,
direction: Literal["forward", "backward"] = "forward",
scoring: None | str | Callable = None,
cv: Iterable | int | BaseShuffleSplit | BaseCrossValidator = 5,
cv: Iterable | int | BaseCrossValidator = 5,
n_jobs: None | Int = None,
) -> None: ...
def fit(
Expand Down
9 changes: 4 additions & 5 deletions stubs/sklearn/linear_model/_coordinate_descent.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ from scipy.sparse._coo import coo_matrix
from .._typing import ArrayLike, Float, Int, MatrixLike
from ..base import MultiOutputMixin, RegressorMixin
from ..model_selection import BaseCrossValidator, check_cv as check_cv
from ..model_selection._split import BaseShuffleSplit
from ..utils import check_array as check_array, check_scalar as check_scalar
from ..utils._param_validation import Interval as Interval, StrOptions as StrOptions
from ..utils.extmath import safe_sparse_dot as safe_sparse_dot
Expand Down Expand Up @@ -206,7 +205,7 @@ class LassoCV(RegressorMixin, LinearModelCV):
max_iter: Int = 1000,
tol: Float = 1e-4,
copy_X: bool = True,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
verbose: int | bool = False,
n_jobs: None | Int = None,
positive: bool = False,
Expand Down Expand Up @@ -241,7 +240,7 @@ class ElasticNetCV(RegressorMixin, LinearModelCV):
precompute: Literal["auto"] | MatrixLike | bool = "auto",
max_iter: Int = 1000,
tol: Float = 1e-4,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
copy_X: bool = True,
verbose: int | bool = 0,
n_jobs: None | Int = None,
Expand Down Expand Up @@ -333,7 +332,7 @@ class MultiTaskElasticNetCV(RegressorMixin, LinearModelCV):
fit_intercept: bool = True,
max_iter: Int = 1000,
tol: Float = 1e-4,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
copy_X: bool = True,
verbose: int | bool = 0,
n_jobs: None | Int = None,
Expand Down Expand Up @@ -370,7 +369,7 @@ class MultiTaskLassoCV(RegressorMixin, LinearModelCV):
max_iter: Int = 1000,
tol: Float = 1e-4,
copy_X: bool = True,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
verbose: int | bool = False,
n_jobs: None | Int = None,
random_state: RandomState | None | Int = None,
Expand Down
5 changes: 2 additions & 3 deletions stubs/sklearn/linear_model/_least_angle.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ from .._typing import ArrayLike, Float, Int, MatrixLike
from ..base import MultiOutputMixin, RegressorMixin
from ..exceptions import ConvergenceWarning as ConvergenceWarning
from ..model_selection import BaseCrossValidator, check_cv as check_cv
from ..model_selection._split import BaseShuffleSplit
from ..utils import arrayfuncs as arrayfuncs, as_float_array as as_float_array, check_random_state as check_random_state
from ..utils._param_validation import Hidden as Hidden, Interval as Interval, StrOptions as StrOptions
from ..utils.parallel import Parallel as Parallel, delayed as delayed
Expand Down Expand Up @@ -160,7 +159,7 @@ class LarsCV(Lars):
max_iter: Int = 500,
normalize: str | bool = "deprecated",
precompute: Literal["auto"] | ArrayLike | bool = "auto",
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
max_n_alphas: Int = 1000,
n_jobs: None | int = None,
eps: Float = ...,
Expand Down Expand Up @@ -193,7 +192,7 @@ class LassoLarsCV(LarsCV):
max_iter: Int = 500,
normalize: str | bool = "deprecated",
precompute: Literal["auto"] | bool = "auto",
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
max_n_alphas: Int = 1000,
n_jobs: None | int = None,
eps: Float = ...,
Expand Down
3 changes: 1 addition & 2 deletions stubs/sklearn/linear_model/_logistic.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ from .._loss.loss import HalfBinomialLoss as HalfBinomialLoss, HalfMultinomialLo
from .._typing import ArrayLike, Float, Int, MatrixLike
from ..metrics import get_scorer as get_scorer, get_scorer_names as get_scorer_names
from ..model_selection import BaseCrossValidator, check_cv as check_cv
from ..model_selection._split import BaseShuffleSplit
from ..preprocessing import LabelBinarizer as LabelBinarizer, LabelEncoder as LabelEncoder
from ..utils import (
check_array as check_array,
Expand Down Expand Up @@ -108,7 +107,7 @@ class LogisticRegressionCV(LogisticRegression, LinearClassifierMixin, BaseEstima
*,
Cs: Sequence[float] | int = 10,
fit_intercept: bool = True,
cv: int | None | BaseShuffleSplit | BaseCrossValidator = None,
cv: int | None | BaseCrossValidator = None,
dual: bool = False,
penalty: Literal["l1", "l2", "elasticnet"] = "l2",
scoring: None | str | Callable = None,
Expand Down
3 changes: 1 addition & 2 deletions stubs/sklearn/linear_model/_omp.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ from scipy.linalg.lapack import get_lapack_funcs as get_lapack_funcs
from .._typing import ArrayLike, Float, Int, MatrixLike
from ..base import MultiOutputMixin, RegressorMixin
from ..model_selection import BaseCrossValidator, check_cv as check_cv
from ..model_selection._split import BaseShuffleSplit
from ..utils import as_float_array as as_float_array, check_array as check_array
from ..utils._param_validation import Hidden as Hidden, Interval as Interval, StrOptions as StrOptions
from ..utils.parallel import Parallel as Parallel, delayed as delayed
Expand Down Expand Up @@ -90,7 +89,7 @@ class OrthogonalMatchingPursuitCV(RegressorMixin, LinearModel):
fit_intercept: bool = True,
normalize: str | bool = "deprecated",
max_iter: None | Int = None,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
n_jobs: None | Int = None,
verbose: int | bool = False,
) -> None: ...
Expand Down
3 changes: 1 addition & 2 deletions stubs/sklearn/linear_model/_ridge.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ from ..base import MultiOutputMixin, RegressorMixin, is_classifier as is_classif
from ..exceptions import ConvergenceWarning as ConvergenceWarning
from ..metrics import check_scoring as check_scoring, get_scorer_names as get_scorer_names
from ..model_selection import BaseCrossValidator, GridSearchCV as GridSearchCV
from ..model_selection._split import BaseShuffleSplit
from ..preprocessing import LabelBinarizer as LabelBinarizer
from ..utils import (
check_array as check_array,
Expand Down Expand Up @@ -237,7 +236,7 @@ class RidgeClassifierCV(_RidgeClassifierMixin, _BaseRidgeCV):
*,
fit_intercept: bool = True,
scoring: None | str | Callable = None,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
class_weight: None | Mapping | str = None,
store_cv_values: bool = False,
) -> None: ...
Expand Down
3 changes: 1 addition & 2 deletions stubs/sklearn/model_selection/_plot.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ from numpy.random import RandomState
from .._typing import ArrayLike, Float, Int, MatrixLike
from ..utils import check_matplotlib_support as check_matplotlib_support
from . import BaseCrossValidator, learning_curve as learning_curve
from ._split import BaseShuffleSplit

class LearningCurveDisplay:
fill_between_: Artist | None = ...
Expand Down Expand Up @@ -48,7 +47,7 @@ class LearningCurveDisplay:
*,
groups: None | ArrayLike = None,
train_sizes: ArrayLike = ...,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
scoring: None | str | Callable = None,
exploit_incremental_learning: bool = False,
n_jobs: None | Int = None,
Expand Down
6 changes: 3 additions & 3 deletions stubs/sklearn/model_selection/_search.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ from ..utils.parallel import Parallel as Parallel, delayed as delayed
from ..utils.random import sample_without_replacement as sample_without_replacement
from ..utils.validation import check_is_fitted as check_is_fitted, indexable as indexable
from . import BaseCrossValidator
from ._split import BaseShuffleSplit, check_cv as check_cv
from ._split import check_cv as check_cv

BaseSearchCV_Self = TypeVar("BaseSearchCV_Self", bound=BaseSearchCV)
BaseEstimatorT = TypeVar("BaseEstimatorT", bound=BaseEstimator, default=BaseEstimator, covariant=True)
Expand Down Expand Up @@ -110,7 +110,7 @@ class GridSearchCV(BaseSearchCV, Generic[BaseEstimatorT]):
scoring: ArrayLike | None | tuple | Callable | Mapping | str = None,
n_jobs: None | Int = None,
refit: str | Callable | bool = True,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
verbose: Int = 0,
pre_dispatch: str | int = "2*n_jobs",
error_score: str | Float = ...,
Expand Down Expand Up @@ -142,7 +142,7 @@ class RandomizedSearchCV(BaseSearchCV):
scoring: ArrayLike | None | tuple | Callable | Mapping | str = None,
n_jobs: None | Int = None,
refit: str | Callable | bool = True,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
verbose: Int = 0,
pre_dispatch: str | int = "2*n_jobs",
random_state: RandomState | None | Int = None,
Expand Down
6 changes: 3 additions & 3 deletions stubs/sklearn/model_selection/_search_successive_halving.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ from ..utils import resample as resample
from ..utils.multiclass import check_classification_targets as check_classification_targets
from . import BaseCrossValidator, ParameterGrid as ParameterGrid, ParameterSampler as ParameterSampler
from ._search import BaseSearchCV
from ._split import BaseShuffleSplit, check_cv as check_cv
from ._split import check_cv as check_cv

BaseSuccessiveHalving_Self = TypeVar("BaseSuccessiveHalving_Self", bound=BaseSuccessiveHalving)

Expand Down Expand Up @@ -87,7 +87,7 @@ class HalvingGridSearchCV(BaseSuccessiveHalving):
max_resources: str | Int = "auto",
min_resources: int | Literal["exhaust", "smallest"] = "exhaust",
aggressive_elimination: bool = False,
cv: Iterable | int | BaseShuffleSplit | BaseCrossValidator = 5,
cv: Iterable | int | BaseCrossValidator = 5,
scoring: None | str | Callable = None,
refit: bool = True,
error_score: str | Float = ...,
Expand Down Expand Up @@ -132,7 +132,7 @@ class HalvingRandomSearchCV(BaseSuccessiveHalving):
max_resources: str | Int = "auto",
min_resources: Literal["exhaust", "smallest"] | int = "smallest",
aggressive_elimination: bool = False,
cv: Iterable | int | BaseShuffleSplit | BaseCrossValidator = 5,
cv: Iterable | int | BaseCrossValidator = 5,
scoring: None | str | Callable = None,
refit: bool = True,
error_score: str | Float = ...,
Expand Down
8 changes: 4 additions & 4 deletions stubs/sklearn/model_selection/_split.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class BaseCrossValidator(metaclass=ABCMeta):
groups: None | ArrayLike = None,
) -> Iterator[tuple[ndarray, ndarray]]: ...
@abstractmethod
def get_n_splits(self, X=None, y=None, groups=None): ...
def get_n_splits(self, X=None, y=None, groups=None) -> int: ...

class LeaveOneOut(BaseCrossValidator):
def get_n_splits(self, X: MatrixLike, y: Any = None, groups: Any = None) -> int: ...
Expand Down Expand Up @@ -119,7 +119,7 @@ class LeavePGroupsOut(BaseCrossValidator):
def get_n_splits(self, X: Any = None, y: Any = None, groups: None | ArrayLike = None) -> int: ...
def split(self, X: MatrixLike, y: None | ArrayLike = None, groups: None | ArrayLike = None): ...

class _RepeatedSplits(metaclass=ABCMeta):
class _RepeatedSplits(BaseCrossValidator, metaclass=ABCMeta):
def __init__(
self,
cv: Callable,
Expand Down Expand Up @@ -151,7 +151,7 @@ class RepeatedStratifiedKFold(_RepeatedSplits):
random_state: RandomState | None | Int = None,
) -> None: ...

class BaseShuffleSplit(metaclass=ABCMeta):
class BaseShuffleSplit(BaseCrossValidator, metaclass=ABCMeta):
def __init__(self, n_splits: int = 10, *, test_size=None, train_size=None, random_state=None) -> None: ...
def split(
self, X: MatrixLike, y: None | ArrayLike = None, groups: None | ArrayLike = None
Expand Down Expand Up @@ -201,7 +201,7 @@ class _CVIterableWrapper(BaseCrossValidator):
def split(self, X: Any = None, y: Any = None, groups: Any = None): ...

def check_cv(
cv: Iterable | int | BaseShuffleSplit | BaseCrossValidator | None = 5,
cv: Iterable | int | BaseCrossValidator | None = 5,
y: None | ArrayLike = None,
*,
classifier: bool = False,
Expand Down
14 changes: 7 additions & 7 deletions stubs/sklearn/model_selection/_validation.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ from ..svm._classes import SVC, LinearSVC
from ..utils import check_random_state as check_random_state, indexable as indexable
from ..utils.parallel import Parallel as Parallel, delayed as delayed
from . import BaseCrossValidator
from ._split import BaseShuffleSplit, check_cv as check_cv
from ._split import check_cv as check_cv

# Author: Alexandre Gramfort <[email protected]>
# Gael Varoquaux <[email protected]>
Expand All @@ -47,7 +47,7 @@ def cross_validate(
*,
groups: None | ArrayLike = None,
scoring: ArrayLike | None | tuple | Callable | Mapping | str = None,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
n_jobs: None | Int = None,
verbose: Int = 0,
fit_params: None | dict = None,
Expand All @@ -63,7 +63,7 @@ def cross_val_score(
*,
groups: None | ArrayLike = None,
scoring: None | str | Callable = None,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
n_jobs: None | Int = None,
verbose: Int = 0,
fit_params: None | dict = None,
Expand All @@ -76,7 +76,7 @@ def cross_val_predict(
y: None | MatrixLike | ArrayLike = None,
*,
groups: None | ArrayLike = None,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
n_jobs: None | Int = None,
verbose: Int = 0,
fit_params: None | dict = None,
Expand All @@ -89,7 +89,7 @@ def permutation_test_score(
y: None | MatrixLike | ArrayLike,
*,
groups: None | ArrayLike = None,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
n_permutations: Int = 100,
n_jobs: None | Int = None,
random_state: RandomState | None | Int = 0,
Expand All @@ -104,7 +104,7 @@ def learning_curve(
*,
groups: None | ArrayLike = None,
train_sizes: ArrayLike = ...,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
scoring: None | str | Callable = None,
exploit_incremental_learning: bool = False,
n_jobs: None | Int = None,
Expand All @@ -124,7 +124,7 @@ def validation_curve(
param_name: str,
param_range: ArrayLike,
groups: None | ArrayLike = None,
cv: int | BaseCrossValidator | Iterable | None | BaseShuffleSplit = None,
cv: int | BaseCrossValidator | Iterable | None = None,
scoring: None | str | Callable = None,
n_jobs: None | Int = None,
pre_dispatch: str | Int = "all",
Expand Down
Loading