Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 78 additions & 27 deletions sklearn_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,67 +55,101 @@

from sklearn.model_selection import BaseCrossValidator

from sklearn.utils.validation import check_is_fitted
from sklearn.utils.validation import check_X_y, check_is_fitted
from sklearn.utils.validation import check_array
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.multiclass import check_classification_targets, unique_labels
from sklearn.metrics.pairwise import pairwise_distances


class KNearestNeighbors(BaseEstimator, ClassifierMixin):
    """K-nearest-neighbors classifier based on Euclidean distance.

    Each sample is assigned the majority class among its ``n_neighbors``
    closest training samples. Vote ties are resolved in favour of the
    class that comes first in sorted order.

    Parameters
    ----------
    n_neighbors : int, default=5
        Number of neighbors taking part in the vote.

    Attributes
    ----------
    classes_ : ndarray, shape (n_classes,)
        Sorted unique class labels seen during ``fit``.
    X_train : ndarray, shape (n_samples, n_features)
        Validated training data.
    y_train : ndarray, shape (n_samples,)
        Validated training labels.
    n_features_in_ : int
        Number of features seen during ``fit``.
    """

    def __init__(self, n_neighbors=5):
        """Store the hyper-parameters without validating them."""
        self.n_neighbors = n_neighbors

    def fit(self, X, y):
        """Memorise the training set.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Data to train the model.
        y : ndarray, shape (n_samples,)
            Labels associated with the training data.

        Returns
        -------
        self : instance of KNearestNeighbors
            The current instance of the classifier.
        """
        X, y = check_X_y(X, y, dtype=[np.float64, np.int64])
        check_classification_targets(y)
        # Encode labels as positions in ``classes_`` so that voting works
        # for any label type (strings, negative ints, ...), not only for
        # the non-negative integers np.bincount accepts.
        self.classes_, self._y_codes = np.unique(y, return_inverse=True)
        self.X_train = X
        self.y_train = y
        self.n_features_in_ = X.shape[1]
        return self

    def predict(self, X):
        """Predict the class labels for the provided data.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Data to predict the class labels for.

        Returns
        -------
        y_pred : ndarray, shape (n_samples,)
            Predicted class labels for each data sample.

        Raises
        ------
        ValueError
            If ``X`` does not have ``n_features_in_`` columns or if
            ``n_neighbors`` is smaller than 1.
        """
        check_is_fitted(self, ['X_train', 'y_train', 'classes_'])
        X = check_array(X, dtype=[np.float64, np.int64])
        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                "Number of features of the input must be "
                f"{self.n_features_in_}"
            )
        # A negative value would silently slice "all but the last k"
        # neighbors below, so reject it explicitly.
        if self.n_neighbors < 1:
            raise ValueError(
                f"n_neighbors must be >= 1, got {self.n_neighbors}."
            )

        n_classes = len(self.classes_)
        y_pred = np.empty(X.shape[0], dtype=self.classes_.dtype)
        for i, x in enumerate(X):
            distances = np.linalg.norm(self.X_train - x, axis=1)
            # Slicing past the end is harmless: with fewer training
            # samples than n_neighbors, all of them vote.
            nearest = np.argsort(distances)[:self.n_neighbors]
            votes = np.bincount(self._y_codes[nearest], minlength=n_classes)
            y_pred[i] = self.classes_[votes.argmax()]

        return y_pred

    def score(self, X, y):
        """Return the mean accuracy on the given data and labels.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Data to score on.
        y : ndarray, shape (n_samples,)
            True labels for X.

        Returns
        -------
        score : float
            Fraction of samples in ``X`` whose prediction equals ``y``.

        Raises
        ------
        ValueError
            If ``y`` does not hold exactly one label per row of ``X``.
        """
        y_pred = self.predict(X)
        # Flatten so that a column vector compares element-wise instead
        # of broadcasting into an (n, n) matrix.
        y_true = np.asarray(y).ravel()
        if y_true.shape[0] != y_pred.shape[0]:
            raise ValueError(
                f"X has {y_pred.shape[0]} samples but y has "
                f"{y_true.shape[0]}."
            )
        return float(np.mean(y_pred == y_true))


class MonthlySplit(BaseCrossValidator):
Expand Down Expand Up @@ -155,7 +189,13 @@ def get_n_splits(self, X, y=None, groups=None):
n_splits : int
The number of splits.
"""
return 0

if self.time_col == 'index':
time_column = pd.Series(X.index)
else:
time_column = X[self.time_col]

return len(time_column.dt.to_period('M').unique()) - 1

def split(self, X, y, groups=None):
    """Generate indices to split data into training and test set.

    The rows are grouped by calendar month of the time column. For each
    pair of consecutive months present in the data, the earlier month is
    the training set and the later one is the test set.

    Parameters
    ----------
    X : DataFrame
        Data to split. The time information is read from the column named
        ``self.time_col``, or from the index when ``self.time_col`` is
        ``'index'``.
    y : object
        Not used by this method.
    groups : object, default=None
        Not used by this method.

    Yields
    ------
    idx_train : ndarray
        The training set indices (row positions) for that split.
    idx_test : ndarray
        The testing set indices (row positions) for that split.

    Raises
    ------
    ValueError
        If the time column is not of a datetime dtype.
    """
    if self.time_col == 'index':
        time_column = pd.Series(X.index)
    else:
        time_column = X[self.time_col]

    # np.issubdtype raises TypeError on pandas extension dtypes (tz-aware
    # datetimes, nullable ints, ...); the pandas helper handles every
    # dtype, so non-datetime input always ends in the ValueError below.
    if not pd.api.types.is_datetime64_any_dtype(time_column):
        raise ValueError(
            f"Column '{self.time_col}' is not of type datetime64."
        )

    monthly_periods = time_column.dt.to_period('M')
    unique_periods = sorted(monthly_periods.unique())

    # Pair each month with the next one found in the data; fewer than two
    # distinct months yields no split at all.
    for train_period, test_period in zip(unique_periods, unique_periods[1:]):
        # to_numpy() drops the pandas index so positions are returned.
        train_idx = np.flatnonzero(
            (monthly_periods == train_period).to_numpy()
        )
        test_idx = np.flatnonzero(
            (monthly_periods == test_period).to_numpy()
        )
        yield train_idx, test_idx