Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions python/cuml/cuml/cluster/dbscan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -434,22 +434,22 @@ class DBSCAN(UniversalBase,

@generate_docstring(skip_parameters_heading=True)
@enable_device_interop
def fit(self, X, y=None, out_dtype="int32", sample_weight=None,
def fit(self, X, y=None, sample_weight=None, out_dtype="int32",
convert_dtype=True) -> "DBSCAN":
"""
Perform DBSCAN clustering from features.

Parameters
----------
out_dtype: dtype Determines the precision of the output labels array.
default: "int32". Valid values are { "int32", np.int32,
"int64", np.int64}.

sample_weight: array-like of shape (n_samples,), default=None
Weight of each sample, such that a sample with a weight of at
least min_samples is by itself a core sample; a sample with a
negative weight may inhibit its eps-neighbor from being core.
default: None (which is equivalent to weight 1 for all samples).

out_dtype : dtype, default="int32"
    Determines the precision of the output labels array.
    Valid values are {"int32", np.int32, "int64", np.int64}.
"""
return self._fit(X, out_dtype, False, sample_weight)

Expand All @@ -459,21 +459,21 @@ class DBSCAN(UniversalBase,
'description': 'Cluster labels',
'shape': '(n_samples, 1)'})
@enable_device_interop
def fit_predict(self, X, y=None, out_dtype="int32", sample_weight=None) -> CumlArray:
def fit_predict(self, X, y=None, sample_weight=None, out_dtype="int32") -> CumlArray:
"""
Performs clustering on X and returns cluster labels.

Parameters
----------
out_dtype: dtype Determines the precision of the output labels array.
default: "int32". Valid values are { "int32", np.int32,
"int64", np.int64}.

sample_weight: array-like of shape (n_samples,), default=None
Weight of each sample, such that a sample with a weight of at
least min_samples is by itself a core sample; a sample with a
negative weight may inhibit its eps-neighbor from being core.
default: None (which is equivalent to weight 1 for all samples).

out_dtype : dtype, default="int32"
    Determines the precision of the output labels array.
    Valid values are {"int32", np.int32, "int64", np.int64}.
"""
self.fit(X, out_dtype=out_dtype, sample_weight=sample_weight)
return self.labels_
Expand Down
3 changes: 1 addition & 2 deletions python/cuml/cuml/cluster/kmeans.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -703,8 +703,7 @@ class KMeans(UniversalBase,
'description': 'Transformed data',
'shape': '(n_samples, n_clusters)'})
@enable_device_interop
def fit_transform(self, X, y=None, convert_dtype=False,
sample_weight=None) -> CumlArray:
def fit_transform(self, X, y=None, sample_weight=None, convert_dtype=False) -> CumlArray:
"""
Compute clustering and transform X to cluster-distance space.

Expand Down
3 changes: 1 addition & 2 deletions python/cuml/cuml/linear_model/elastic_net.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,7 @@ class ElasticNet(Base,

@generate_docstring()
@warn_legacy_device_interop
def fit(self, X, y, convert_dtype=True,
sample_weight=None) -> "ElasticNet":
def fit(self, X, y, sample_weight=None, convert_dtype=True) -> "ElasticNet":
"""
Fit the model with X and y.

Expand Down
3 changes: 1 addition & 2 deletions python/cuml/cuml/linear_model/linear_regression.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,7 @@ class LinearRegression(Base,

@generate_docstring()
@warn_legacy_device_interop
def fit(self, X, y, convert_dtype=True,
sample_weight=None) -> "LinearRegression":
def fit(self, X, y, sample_weight=None, convert_dtype=True) -> "LinearRegression":
"""
Fit the model with X and y.

Expand Down
2 changes: 1 addition & 1 deletion python/cuml/cuml/linear_model/ridge.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ class Ridge(Base,

@generate_docstring()
@warn_legacy_device_interop
def fit(self, X, y, convert_dtype=True, sample_weight=None) -> "Ridge":
def fit(self, X, y, sample_weight=None, convert_dtype=True) -> "Ridge":
"""
Fit the model with X and y.
"""
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/cuml/solvers/cd.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ class CD(Base,
}[self.loss]

@generate_docstring()
def fit(self, X, y, convert_dtype=True, sample_weight=None) -> "CD":
def fit(self, X, y, sample_weight=None, convert_dtype=True) -> "CD":
"""
Fit the model with X and y.

Expand Down
15 changes: 13 additions & 2 deletions python/cuml/cuml/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def test_base_children__get_param_names(child_class: str):
}
],
)
def test_sklearn_methods_with_required_y_parameter(cls):
def test_sklearn_methods_match_sklearn_interface(cls):
optional_params = {
inspect.Parameter.KEYWORD_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
Expand All @@ -229,13 +229,24 @@ def test_sklearn_methods_with_required_y_parameter(cls):
if (method := getattr(cls, name, None)) is None:
# Method not defined, skip
continue
params = list(inspect.signature(method).parameters.values())
sig = inspect.signature(method)
params = list(sig.parameters.values())
# Assert method has a 2nd parameter named y, which is required by sklearn
assert (
len(params) > 2 and params[2].name == "y"
), f"`{name}` requires a `y` parameter, even if it's ignored"

# Check that all remaining parameters are optional
for param in params[3:]:
assert (
param.kind in optional_params
), f"`{name}` parameter `{param.name}` must be optional"

# Check common kwargs have an expected order
i = 3
for kw_name in ["classes", "sample_weight"]:
if kw_name in sig.parameters:
assert (
params[i].name == kw_name
), f"`{kw_name}` should be before {params[i].name} in `{name}`"
i += 1
Loading