diff --git a/sklearn_questions.py b/sklearn_questions.py index fa02e0d..1ee6e51 100644 --- a/sklearn_questions.py +++ b/sklearn_questions.py @@ -55,22 +55,35 @@ from sklearn.model_selection import BaseCrossValidator +from sklearn.utils.validation import check_is_fitted from sklearn.utils.validation import check_X_y, check_is_fitted from sklearn.utils.validation import check_array -from sklearn.utils.multiclass import check_classification_targets +from sklearn.utils.multiclass import check_classification_targets, unique_labels from sklearn.metrics.pairwise import pairwise_distances - class KNearestNeighbors(BaseEstimator, ClassifierMixin): - """KNearestNeighbors classifier.""" + """ + K-Nearest Neighbors classifier. + + This class implements a simple K-Nearest Neighbors classifier. + """ - def __init__(self, n_neighbors=1): # noqa: D107 + def __init__(self, n_neighbors=5): + """ + Initialize the KNearestNeighbors classifier. + + Parameters + ---------- + n_neighbors : int, optional (default=5) + Number of neighbors to use for k-neighbors queries. + """ self.n_neighbors = n_neighbors def fit(self, X, y): - """Fitting function. + """ + Fit the K-Nearest Neighbors classifier from the training dataset. - Parameters + Parameters ---------- X : ndarray, shape (n_samples, n_features) Data to train the model. @@ -78,44 +91,65 @@ def fit(self, X, y): Labels associated with the training data. Returns - ---------- + ------- self : instance of KNearestNeighbors - The current instance of the classifier + The current instance of the classifier. """ + # Validation of data + X, y = check_X_y(X, y, dtype=[np.float64, np.int64]) + check_classification_targets(y) + self.X_train = X + self.y_train = y + self.n_features_in_ = X.shape[1] return self def predict(self, X): - """Predict function. + """ + Predict the class labels for the provided data. Parameters ---------- - X : ndarray, shape (n_test_samples, n_features) - Data to predict on. + X : ndarray, shape (n_samples, n_features) + Data to predict the class labels for. Returns - ---------- - y : ndarray, shape (n_test_samples,) - Predicted class labels for each test data sample. + ------- + y_pred : ndarray, shape (n_samples,) + Predicted class labels for each data sample. """ - y_pred = np.zeros(X.shape[0]) + check_is_fitted(self, ['X_train', 'y_train']) + X = check_array(X, dtype=[np.float64, np.int64]) + if X.shape[1] != self.n_features_in_: + raise ValueError(f"Number of features of the input must be {self.n_features_in_}") + + y_pred = np.empty(X.shape[0], dtype=self.y_train.dtype) + + for i, x in enumerate(X): + distances = np.linalg.norm(self.X_train - x, axis=1) + nearest_neighbors = np.argsort(distances)[:self.n_neighbors] + nearest_labels = self.y_train[nearest_neighbors] + y_pred[i] = np.bincount(nearest_labels).argmax() + return y_pred def score(self, X, y): - """Calculate the score of the prediction. + """ + Calculate the score of the prediction. Parameters ---------- X : ndarray, shape (n_samples, n_features) Data to score on. y : ndarray, shape (n_samples,) - target values. + True labels for X. Returns - ---------- + ------- score : float - Accuracy of the model computed for the (X, y) pairs. + Mean accuracy of the classifier on the given test data and labels. """ - return 0. + y_pred = self.predict(X) + return np.mean(y_pred == y) class MonthlySplit(BaseCrossValidator): @@ -155,7 +189,13 @@ def get_n_splits(self, X, y=None, groups=None): n_splits : int The number of splits. """ - return 0 + + if self.time_col == 'index': + time_column = pd.Series(X.index) + else: + time_column = X[self.time_col] + + return len(time_column.dt.to_period('M').unique()) - 1 def split(self, X, y, groups=None): """Generate indices to split data into training and test set. @@ -178,11 +218,22 @@ def split(self, X, y, groups=None): The testing set indices for that split. """ - n_samples = X.shape[0] + if self.time_col == 'index': + time_column = pd.Series(X.index) + else: + time_column = X[self.time_col] + + if not np.issubdtype(time_column.dtype, np.datetime64): + raise ValueError( + f"Column '{self.time_col}' is not of type datetime64." + ) + + monthly_periods = time_column.dt.to_period('M') + unique_periods = sorted(monthly_periods.unique()) n_splits = self.get_n_splits(X, y, groups) + for i in range(n_splits): - idx_train = range(n_samples) - idx_test = range(n_samples) - yield ( - idx_train, idx_test - ) + train_idx = np.where(monthly_periods == unique_periods[i])[0] + test_idx = np.where(monthly_periods == unique_periods[i + 1])[0] + yield train_idx, test_idx +