diff --git a/sklearn_questions.py b/sklearn_questions.py index fa02e0d..8ff3e21 100644 --- a/sklearn_questions.py +++ b/sklearn_questions.py @@ -55,13 +55,12 @@ from sklearn.model_selection import BaseCrossValidator -from sklearn.utils.validation import check_X_y, check_is_fitted -from sklearn.utils.validation import check_array +from sklearn.utils.validation import check_X_y, check_is_fitted, validate_data from sklearn.utils.multiclass import check_classification_targets from sklearn.metrics.pairwise import pairwise_distances -class KNearestNeighbors(BaseEstimator, ClassifierMixin): +class KNearestNeighbors(ClassifierMixin, BaseEstimator): """KNearestNeighbors classifier.""" def __init__(self, n_neighbors=1): # noqa: D107 @@ -82,6 +81,12 @@ def fit(self, X, y): self : instance of KNearestNeighbors The current instance of the classifier """ + X, y = check_X_y(X, y) + self._X = X + self._y = y + self.n_features_in_ = X.shape[1] + check_classification_targets(y) + self.classes_ = np.unique(y) return self def predict(self, X): @@ -97,7 +102,15 @@ def predict(self, X): y : ndarray, shape (n_test_samples,) Predicted class labels for each test data sample. """ - y_pred = np.zeros(X.shape[0]) + check_is_fitted(self) + X = validate_data(self, X, reset=False) + distances = pairwise_distances(X, self._X) + nearest_neighbors = np.argsort(distances, axis=1)[:, :self.n_neighbors] + unique_classes, y_indices = np.unique(self._y, + return_inverse=True) + neighbor_labels = y_indices[nearest_neighbors] + y_pred = np.array([unique_classes[np.bincount(labels).argmax()] + for labels in neighbor_labels]) return y_pred def score(self, X, y): @@ -115,7 +128,8 @@ def score(self, X, y): score : float Accuracy of the model computed for the (X, y) pairs. """ - return 0. + y_pred = self.predict(X) + return np.mean(y_pred == y) class MonthlySplit(BaseCrossValidator): @@ -155,7 +169,17 @@ def get_n_splits(self, X, y=None, groups=None): n_splits : int The number of splits. """ - return 0 + if isinstance(X, pd.Series): + times = X.index + elif isinstance(X, pd.DataFrame): + times = X.index if self.time_col == 'index' else X[self.time_col] + else: + raise ValueError("X should be a pandas DataFrame or Series.") + + if not pd.api.types.is_datetime64_any_dtype(times): + raise ValueError("time_col must be a datetime column.") + periods = pd.Series(times).dt.to_period("M") + return len(periods.unique()) - 1 def split(self, X, y, groups=None): """Generate indices to split data into training and test set. @@ -177,12 +201,35 @@ def split(self, X, y, groups=None): idx_test : ndarray The testing set indices for that split. """ - - n_samples = X.shape[0] - n_splits = self.get_n_splits(X, y, groups) - for i in range(n_splits): - idx_train = range(n_samples) - idx_test = range(n_samples) - yield ( - idx_train, idx_test - ) + if isinstance(X, pd.Series): + times = X.index + elif isinstance(X, pd.DataFrame): + times = X.index if self.time_col == "index" else X[self.time_col] + else: + raise ValueError("X should be a pandas DataFrame or Series.") + + if not pd.api.types.is_datetime64_any_dtype(times): + raise ValueError("time_col must be a datetime column.") + sorted_indices = np.argsort(times) + times = times[sorted_indices] + X = X.iloc[sorted_indices] if isinstance( + X, pd.DataFrame + ) else X[ + sorted_indices + ] + y = y.iloc[sorted_indices] if isinstance( + y, pd.DataFrame + ) else y[ + sorted_indices + ] + + periods = pd.Series(times).dt.to_period("M") + unique_months = periods.unique() + for i in range(len(unique_months) - 1): + train_month = unique_months[i] + test_month = unique_months[i + 1] + + idx_train = np.array(periods[periods == train_month].index) + idx_test = np.array(periods[periods == test_month].index) + + yield idx_train, idx_test