From 5ef97f52ea2c8f6571c1c9ad0bf54d5beae8825d Mon Sep 17 00:00:00 2001 From: vincent d warmerdam Date: Tue, 14 May 2024 16:39:59 +0200 Subject: [PATCH] Update __init__.py --- drawdata/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/drawdata/__init__.py b/drawdata/__init__.py index d8ca1c9..c78a50a 100644 --- a/drawdata/__init__.py +++ b/drawdata/__init__.py @@ -44,11 +44,11 @@ def data_as_polars(self): def data_as_X_y(self, kind='classification'): import numpy as np - colors = [_['label'] for _ in self.data] + colors = [_['color'] for _ in self.data] + X = np.array([[_['x']] for _ in self.data]) + + # Assume that we're dealing with regression in this case if np.unique(colors).shape[0] == 1: - X = np.array([[_['x']] for _ in self.data]) y = np.array([_['y'] for _ in self.data]) return X, y - X = np.array([[_['x'], _['y']] for _ in self.data]) - y = np.array([_['label'] for _ in self.data]) - return X, y \ No newline at end of file + return X, colors