diff --git a/sklearn_questions.py b/sklearn_questions.py index fa02e0d..0dd2d6f 100644 --- a/sklearn_questions.py +++ b/sklearn_questions.py @@ -55,21 +55,35 @@ from sklearn.model_selection import BaseCrossValidator -from sklearn.utils.validation import check_X_y, check_is_fitted -from sklearn.utils.validation import check_array +from sklearn.utils.validation import (check_X_y, check_is_fitted, + validate_data) from sklearn.utils.multiclass import check_classification_targets from sklearn.metrics.pairwise import pairwise_distances -class KNearestNeighbors(BaseEstimator, ClassifierMixin): - """KNearestNeighbors classifier.""" +class KNearestNeighbors(ClassifierMixin, BaseEstimator): + """KNearestNeighbors classifier. + + This class implements a K-Nearest Neighbors classifier for classification + tasks. The classifier predicts the label of a test point based on the + majority class of its nearest neighbors in the training dataset. + + Parameters + ---------- + n_neighbors : int, default=1 + Number of neighbors to use for classification. + """ def __init__(self, n_neighbors=1): # noqa: D107 + """Initialize the classifier with the specified number of neighbors.""" self.n_neighbors = n_neighbors def fit(self, X, y): """Fitting function. + This method stores the training data and labels for later use + during prediction. + Parameters ---------- X : ndarray, shape (n_samples, n_features) @@ -80,8 +94,14 @@ def fit(self, X, y): Returns ---------- self : instance of KNearestNeighbors - The current instance of the classifier + The fitted instance of the classifier """ + X, y = check_X_y(X, y) + self.X_train_ = X + self.y_train_ = y + self.n_features_in_ = X.shape[1] + check_classification_targets(y) + self.classes_ = np.unique(y) return self def predict(self, X): @@ -97,7 +117,15 @@ def predict(self, X): y : ndarray, shape (n_test_samples,) Predicted class labels for each test data sample. """ - y_pred = np.zeros(X.shape[0]) + check_is_fitted(self) + X = validate_data(self, X, reset=False) + distances = pairwise_distances(X, self.X_train_) + nearest_neighbors = np.argsort(distances, axis=1)[:, :self.n_neighbors] + unique_classes, y_indices = np.unique(self.y_train_, + return_inverse=True) + neighbor_labels = y_indices[nearest_neighbors] + y_pred = np.array([unique_classes[np.bincount(labels).argmax()] + for labels in neighbor_labels]) return y_pred def score(self, X, y): @@ -113,9 +141,11 @@ def score(self, X, y): Returns ---------- score : float - Accuracy of the model computed for the (X, y) pairs. + Accuracy of the model computed as the + mean for correctly predicted labels. """ - return 0. + y_pred = self.predict(X) + return np.mean(y_pred == y) class MonthlySplit(BaseCrossValidator): @@ -153,9 +183,19 @@ def get_n_splits(self, X, y=None, groups=None): Returns ------- n_splits : int - The number of splits. + The number of splits based on unique months in the data. """ - return 0 + if isinstance(X, pd.Series): + times = X.index + elif isinstance(X, pd.DataFrame): + times = X.index if self.time_col == 'index' else X[self.time_col] + else: + raise ValueError("X should be a pandas DataFrame or Series.") + + if not pd.api.types.is_datetime64_any_dtype(times): + raise ValueError("time_col must be a datetime column.") + periods = pd.Series(times).dt.to_period("M") + return len(periods.unique()) - 1 def split(self, X, y, groups=None): """Generate indices to split data into training and test set. @@ -177,12 +217,49 @@ def split(self, X, y, groups=None): idx_test : ndarray The testing set indices for that split. """ + # Determine time column + if isinstance(X, pd.DataFrame): + if self.time_col == 'index': + times = X.index + else: + times = X[self.time_col] + elif isinstance(X, pd.Series): + times = X.index + else: + raise ValueError("X should be a pandas DataFrame or Series.") + + # Ensure time column is datetime + if not pd.api.types.is_datetime64_any_dtype(times): + raise ValueError("time_col must be a datetime column.") + + # Create a copy of the data + X_copy = X.copy() + y_copy = y.copy() if y is not None else None + + # Sort the copy of the data by time + if isinstance(X_copy, pd.DataFrame) and self.time_col != 'index': + sorted_data = X_copy.sort_values(by=self.time_col) + else: + sorted_data = X_copy.sort_index() + + # Extract the sorted indices + sorted_indices = sorted_data.index + + # Map sorted indices to original indices + times = pd.Series(times.values, index=sorted_indices).sort_index() + + # Sort y_copy if it exists + if y_copy is not None: + y_copy = y_copy.loc[sorted_indices] + + # Group by unique months + periods = times.dt.to_period("M") + unique_periods = sorted(periods.unique()) + + n_splits = self.get_n_splits(X_copy, y_copy, groups) - n_samples = X.shape[0] - n_splits = self.get_n_splits(X, y, groups) for i in range(n_splits): - idx_train = range(n_samples) - idx_test = range(n_samples) - yield ( - idx_train, idx_test - ) + idx_train = np.where(periods == unique_periods[i])[0] + idx_test = np.where(periods == unique_periods[i + 1])[0] + + yield idx_train, idx_test