diff --git a/docs/tutorials/07_gradient_descent.ipynb b/docs/tutorials/07_gradient_descent.ipynb index 6322d2f5..59b91a50 100644 --- a/docs/tutorials/07_gradient_descent.ipynb +++ b/docs/tutorials/07_gradient_descent.ipynb @@ -823,8 +823,7 @@ "for epoch in range(10):\n", " epoch_loss = 0.0\n", " for batch_ind, batch in enumerate(dataloader):\n", - " current_batch = batch[0].numpy()\n", - " label_batch = batch[1].numpy()\n", + " current_batch, label_batch = batch\n", " loss_val, gradient = jitted_grad(opt_params, current_batch, label_batch)\n", " updates, opt_state = optimizer.update(gradient, opt_state)\n", " opt_params = optax.apply_updates(opt_params, updates)\n",