diff --git a/drawdata/__init__.py b/drawdata/__init__.py index d8ca1c9..c78a50a 100644 --- a/drawdata/__init__.py +++ b/drawdata/__init__.py @@ -44,11 +44,11 @@ def data_as_polars(self): def data_as_X_y(self, kind='classification'): import numpy as np - colors = [_['label'] for _ in self.data] + colors = [_['color'] for _ in self.data] + X = np.array([[_['x']] for _ in self.data]) + + # Assume that we're dealing with regression in this case if np.unique(colors).shape[0] == 1: - X = np.array([[_['x']] for _ in self.data]) y = np.array([_['y'] for _ in self.data]) return X, y - X = np.array([[_['x'], _['y']] for _ in self.data]) - y = np.array([_['label'] for _ in self.data]) - return X, y \ No newline at end of file + return X, colors