From 6495162aaa47c4320a813a1f60fce1d72291bba2 Mon Sep 17 00:00:00 2001 From: jnsbck <65561470+jnsbck@users.noreply.github.com> Date: Mon, 4 Nov 2024 13:58:12 +0100 Subject: [PATCH] Remove tensorflow dependency. (#484) * enh: remove tensorflow dependency. closes #481 * fix: amend last commit * fix: change params * fix: checkout tutorial 07 on main and apply changes to reduce diff --- docs/tutorials/07_gradient_descent.ipynb | 34 ++++++++++++++++++++---- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/docs/tutorials/07_gradient_descent.ipynb b/docs/tutorials/07_gradient_descent.ipynb index 061a1d4b..6322d2f5 100644 --- a/docs/tutorials/07_gradient_descent.ipynb +++ b/docs/tutorials/07_gradient_descent.ipynb @@ -755,8 +755,33 @@ "metadata": {}, "outputs": [], "source": [ - "import tensorflow as tf\n", - "from tensorflow.data import Dataset" + "class Dataset:\n", + " def __init__(self, inputs: np.ndarray, labels: np.ndarray):\n", + " \"\"\"Simple Dataloader.\n", + " \n", + " Args:\n", + " inputs: Array of shape (num_samples, num_dim)\n", + " labels: Array of shape (num_samples,)\n", + " \"\"\"\n", + " assert len(inputs) == len(labels), \"Inputs and labels must have same length\"\n", + " self.inputs = inputs\n", + " self.labels = labels\n", + " self.num_samples = len(inputs)\n", + " \n", + " def shuffle(self, seed=None):\n", + " \"\"\"Shuffle the dataset in-place\"\"\"\n", + " if seed is not None:\n", + " np.random.seed(seed)\n", + " indices = np.random.permutation(self.num_samples)\n", + " self.inputs = self.inputs[indices]\n", + " self.labels = self.labels[indices]\n", + " return self\n", + " \n", + " def batch(self, batch_size):\n", + " \"\"\"Create batches of the data\"\"\"\n", + " for start in range(0, self.num_samples, batch_size):\n", + " end = min(start + batch_size, self.num_samples)\n", + " yield self.inputs[start:end], self.labels[start:end]" ] }, { @@ -768,9 +793,8 @@ "source": [ "batch_size = 4\n", - "tf.random.set_seed(1)\n", - 
"dataloader = Dataset.from_tensor_slices((inputs, labels))\n", - "dataloader = dataloader.shuffle(dataloader.cardinality()).batch(batch_size)" + "dataloader = Dataset(inputs, labels)\n", + "dataloader = dataloader.shuffle(seed=1).batch(batch_size)" ] }, {