Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 117 additions & 13 deletions sklearn_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,12 @@

from sklearn.model_selection import BaseCrossValidator

from sklearn.utils.validation import check_X_y, check_is_fitted
from sklearn.utils.validation import check_X_y, check_is_fitted, validate_data
from sklearn.utils.validation import check_array
from sklearn.utils.multiclass import check_classification_targets
from sklearn.metrics.pairwise import pairwise_distances


class KNearestNeighbors(BaseEstimator, ClassifierMixin):
class KNearestNeighbors(ClassifierMixin, BaseEstimator):
"""KNearestNeighbors classifier."""

def __init__(self, n_neighbors=1): # noqa: D107
def fit(self, X, y):
    """Fit the classifier by memorising the training set.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Training samples.
    y : ndarray, shape (n_samples,)
        Class label of each training sample.

    Returns
    -------
    self : instance of KNearestNeighbors
        The current instance of the classifier
    """
    # A single validation pass is enough: validate_data checks X and y
    # together (2-D finite X, 1-D y, consistent lengths), so separate
    # check_array / check_X_y calls would only repeat the same work.
    X, y = validate_data(self, X, y)
    # Reject continuous targets once y is known to be well formed.
    check_classification_targets(y)

    self.n_features_in_ = X.shape[1]
    self.is_fitted_ = True
    # k-NN has no training step: the samples themselves are the model.
    self.examples_ = X
    self.labels_ = y
    # Sorted unique labels, stored as a plain list.
    self.classes_ = list(np.unique(y))
    return self

def euclidian_distance(self, x1, x2, axis=1):
    """Compute Euclidean distances between two broadcastable arrays.

    Parameters
    ----------
    x1 : array-like, e.g. shape (n_samples, n_features)
        First set of points.
    x2 : array-like, broadcastable against ``x1``, e.g. (1, n_features)
        Second set of points.
    axis : int or None, default=1
        Axis along which the squared differences are summed (the feature
        axis for 2-D inputs).

    Returns
    -------
    distance : ndarray or float
        Distances between ``x1`` and ``x2``, reduced along ``axis``.
    """
    # np.asarray lets callers pass nested lists as well as ndarrays.
    difference = np.asarray(x1) - np.asarray(x2)
    return np.sqrt(np.sum(difference ** 2, axis=axis))

def compute_distance(self, x):
    """Compute the Euclidean distance from one sample to every fitted one.

    Parameters
    ----------
    x : array-like, shape (n_features,) or (1, n_features)
        Sample to compare against the stored training examples.

    Returns
    -------
    distance : ndarray, shape (1, n_train_samples)
        Distances between ``x`` and each fitted data point.
    """
    # Force a single row so it broadcasts against every row of examples_;
    # np.asarray also accepts plain lists.
    query = np.asarray(x).reshape(1, -1)
    distances = self.euclidian_distance(self.examples_, query)
    return distances.reshape(1, -1)

def predict(self, X):
    """Predict class labels by majority vote among the nearest neighbours.

    Parameters
    ----------
    X : ndarray, shape (n_test_samples, n_features)
        Data to predict on.

    Returns
    -------
    y : ndarray, shape (n_test_samples,)
        Predicted class labels for each test data sample.
    """
    check_is_fitted(self)
    # validate_data performs the array checks and also verifies that the
    # number of features matches what was seen in fit.
    X = validate_data(self, X, reset=False)

    # Row i holds the distances from test sample i to each training sample.
    distances = np.vstack([self.compute_distance(row) for row in X])

    # A stable argsort yields the k closest training samples exactly once
    # each, with ties broken by lowest training index. (Overwriting chosen
    # entries with the maximum distance could select the same neighbour
    # twice whenever a remaining distance equalled that maximum.)
    n_neighbors = min(self.n_neighbors, distances.shape[1])
    nearest = np.argsort(distances, axis=1, kind="stable")[:, :n_neighbors]

    classes = np.asarray(self.classes_)
    neighbor_labels = np.asarray(self.labels_)[nearest]
    # votes[i, c] = number of neighbours of sample i belonging to classes[c]
    votes = (neighbor_labels[:, :, np.newaxis] == classes).sum(axis=1)
    # Indexing `classes` keeps the label dtype intact (string labels are not
    # truncated); argmax resolves vote ties in favour of the first class.
    return classes[np.argmax(votes, axis=1)]

def score(self, X, y):
    """Compute the accuracy of the classifier on the given samples.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Data to score on.
    y : ndarray, shape (n_samples,)
        True class labels.

    Returns
    -------
    score : float
        Accuracy of the model computed for the (X, y) pairs.
    """
    check_classification_targets(y)
    # check_X_y already validates X, so a separate check_array is redundant.
    X, y = check_X_y(X, y)
    # Fraction of samples whose predicted label equals the true label.
    return np.mean(self.predict(X) == y)


class MonthlySplit(BaseCrossValidator):
def get_n_splits(self, X, y=None, groups=None):
    """Return the number of splitting iterations in the cross-validator.

    A split exists for every calendar month present in the data whose
    immediately following calendar month is also present. The splits are
    stored, in chronological order, in ``self.splits_`` as
    ``(month_train, year_train, month_test, year_test)`` tuples.

    Parameters
    ----------
    X : DataFrame
        Data whose ``self.time_col`` column (or index, when ``time_col`` is
        ``'index'``) holds the timestamps.
    y : ignored
        Present for API compatibility.
    groups : ignored
        Present for API compatibility.

    Returns
    -------
    n_splits : int
        The number of splits.

    Raises
    ------
    ValueError
        If the time data does not have a datetime dtype.
    """
    time_data = X.index if self.time_col == 'index' else X[self.time_col]
    if not pd.api.types.is_datetime64_any_dtype(time_data):
        raise ValueError("Not a datetime column.")
    time_data = pd.DatetimeIndex(time_data)

    # Distinct (year, month) pairs; a set makes the look-ups below O(1).
    observed = set(zip(time_data.year, time_data.month))

    splits = []
    # Iterating in (year, month) order keeps the splits chronological even
    # when the data spans several years; sorting the month-first tuples
    # would put January 2021 before November 2020.
    for year, month in sorted(observed):
        if month == 12:
            next_year, next_month = year + 1, 1
        else:
            next_year, next_month = year, month + 1
        if (next_year, next_month) in observed:
            splits.append((month, year, next_month, next_year))

    self.splits_ = splits
    return len(splits)

def split(self, X, y, groups=None):
    """Generate indices to split data into training and test set.

    Each split trains on one calendar month and tests on the month that
    immediately follows it, as listed in ``self.splits_``.

    Parameters
    ----------
    X : DataFrame
        Data whose ``self.time_col`` column (or index, when ``time_col`` is
        ``'index'``) holds the timestamps. Rows need not be sorted.
    y : array-like
        Forwarded to ``get_n_splits``; not used otherwise.
    groups : array-like, default=None
        Forwarded to ``get_n_splits``; not used otherwise.

    Yields
    ------
    idx_train : ndarray
        The training set indices for that split.
    idx_test : ndarray
        The testing set indices for that split.
    """
    # get_n_splits checks the time data and fills self.splits_.
    n_splits = self.get_n_splits(X, y, groups)

    time_data = X.index if self.time_col == 'index' else X[self.time_col]
    timestamps = pd.DatetimeIndex(time_data)

    # Positional indices of the rows in chronological order. Working with
    # positions rather than index labels stays correct when labels repeat
    # (typical for a datetime index) and avoids a quadratic label lookup.
    order = np.argsort(timestamps.values, kind="stable")
    ordered = timestamps[order]
    months = np.asarray(ordered.month)
    years = np.asarray(ordered.year)

    for split_spec in self.splits_[:n_splits]:
        month_train, year_train, month_test, year_test = split_spec
        in_train = (months == month_train) & (years == year_train)
        in_test = (months == month_test) & (years == year_test)
        # Map the chronological masks back to positions in the original X.
        yield order[in_train], order[in_test]
Loading