Here is some code for context (you can also look here for more context: https://www.learnpytorch.io/01_pytorch_workflow/, scroll down a bit). I marked the line of code I was having trouble understanding (see the comment in the snippet below). The comment tries to explain why .type(torch.float) was used, and I realized it is necessary (otherwise the subsequent plot breaks), but it still doesn't make any sense to me why. I checked the datatype of both test_pred and y_test and they are both torch.float32 even without this operation. I also compared the test_loss_values and train_loss_values lists with and without .type(torch.float) and I don't see a difference, yet when you make the plot without this operation it breaks. Can someone please help me understand? Thank you.
import torch
import matplotlib.pyplot as plt

with torch.inference_mode():
    # 1. Forward pass on test data
    test_pred = model_0(X_test)

    # 2. Calculate loss on test data (the notebook's comment says predictions
    #    come in torch.float, so comparisons need tensors of the same type)
    test_loss = loss_fn(test_pred, y_test.type(torch.float))  # <- the line I was asking about

# Plot the loss curves
plt.plot(epoch_count, train_loss_values, label="Train loss")
plt.plot(epoch_count, test_loss_values, label="Test loss")
plt.title("Training and test loss curves")
plt.ylabel("Loss")
plt.xlabel("Epochs")
plt.legend();
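For anyone hitting the same question, a minimal check (a sketch reusing the model_0, loss_fn, X_test, and y_test objects from the notebook linked above) prints the dtypes and computes the loss both with and without the cast:

with torch.inference_mode():
    test_pred = model_0(X_test)

    # Both typically print torch.float32 in this notebook
    print(test_pred.dtype, y_test.dtype)

    # If the dtypes really do match, these two values should be identical
    print(loss_fn(test_pred, y_test))
    print(loss_fn(test_pred, y_test.type(torch.float)))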
Replies: 1 comment

Figured the issue out. It has to do with the fact that you're in inference_mode() while also calling .detach() on the tensors you get from test_loss. Even though gradients aren't being tracked in the first place, I think the combination causes some kind of odd interaction. Getting rid of .detach() and just using .numpy() fixes the issue for me.
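Concretely, that fix would look something like this in the loss-tracking step (a sketch using the notebook's variable names; .item() would work just as well as .numpy() for scalar losses):

with torch.inference_mode():
    test_pred = model_0(X_test)
    test_loss = loss_fn(test_pred, y_test.type(torch.float))

if epoch % 10 == 0:
    epoch_count.append(epoch)
    # The training loss still tracks gradients, so it keeps .detach()
    train_loss_values.append(loss.detach().numpy())
    # test_loss was computed under inference_mode(), so no .detach() is needed
    test_loss_values.append(test_loss.numpy())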