Hi @gabrieltseng, I've read your paper and find it really interesting work!
Thanks a lot for sharing your code as well!
I'm trying to adapt your downstream task notebook to finetune the pretrained Presto model on the same dataset used in the notebook.
My approach is based on the README instructions, the code in the notebook, and the functions evaluate and finetune found in cropharvest_eval.py. The main part of my code is:
## using Presto for finetuning
# based on the functions evaluate and finetune in presto/eval/cropharvest_eval.py
# assumes train_dl, test_dl and device are defined as in the downstream task notebook
import torch
from torch import nn
from torch.optim import Adam
from tqdm import tqdm

import presto

pretrained_model = presto.Presto.load_pretrained()
print(type(pretrained_model))
pretrained_model.eval()

# build finetuning model: encoder + linear transformation (FinetuningHead)
num_outputs = 2
regression = False
finetuning_model = pretrained_model.construct_finetuning_model(
    num_outputs=num_outputs,
    regression=regression,
)

opt = Adam(finetuning_model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss(reduction="mean")

# train finetuning model
max_epochs = 5
train_loss = []
for epoch in range(max_epochs):
    print(f"Training for epoch: {(epoch + 1):03}")
    finetuning_model.train()
    epoch_train_loss = 0.0
    for (x, mask, dw, latlons, y, month) in tqdm(train_dl):
        # move the batch to the same device as the model,
        # mirroring the prediction loop below
        x, mask = x.to(device), mask.to(device)
        dw, month = dw.to(device).long(), month.to(device)
        latlons = latlons.to(device).float()
        # zero the parameter gradients
        opt.zero_grad()
        # forward + backward + optimize
        preds = finetuning_model(
            x,
            dynamic_world=dw,
            mask=mask,
            latlons=latlons,
            month=month,
        )
        loss = loss_fn(preds, y.to(device).long())
        epoch_train_loss += loss.item()
        loss.backward()
        opt.step()
    train_loss.append(epoch_train_loss / len(train_dl))

# make predictions using finetuning model
finetuning_model.eval()
test_preds = []
for (x, mask, dw, latlons, month) in tqdm(test_dl):
    x = x.to(device)
    dw = dw.to(device).long()
    mask = mask.to(device)
    latlons = latlons.to(device).float()
    month = month.to(device)
    with torch.no_grad():
        preds = (
            finetuning_model(
                x, dynamic_world=dw, mask=mask, latlons=latlons, month=month
            )
            .cpu()
            .numpy()
        )
        # preds = np.argmax(preds, axis=-1)
    test_preds.append(preds)

print("predicting with finetuning model...")
print(len(test_preds))
print(test_preds[0])
And from the print outputs I see, for example, that the predictions in test_preds[0] are:

predicting with finetuning model...
53
[[-9.505264 , 9.816492 ],
 [-9.501129 , 9.811971 ],
 [-9.496433 , 9.806909 ],
 [-9.49617  , 9.806579 ],
 [-9.495665 , 9.805991 ],
 [-9.4937105, 9.803866 ],
 [-9.497982 , 9.808611 ],
 [-9.507018 , 9.818317 ],
 [-9.520019 , 9.832625 ],
 [-9.512251 , 9.824137 ],
 ...
 [-9.4941025, 9.804452 ],
 [-9.506046 , 9.817224 ],
 [-9.48634  , 9.795685 ],
 [-9.496958 , 9.807267 ]]

and I get similar numbers for the remaining elements in test_preds.
But if these numbers are predictions, I would expect them to be probabilities that sum to 1. Should that not be the case here?
I guess there is some step I'm missing, but I can't figure out what it could be.
Could you please give me a hint? I would really appreciate your help.
Cheers,
Hugo
The output of the model contains the un-normalized logits for each class (which is also the input expected by nn.CrossEntropyLoss). If you want to go from un-normalized logits to something that more closely resembles a probability, I recommend applying a softmax to the output.
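For example (a minimal sketch, assuming test_preds holds the (batch_size, 2) numpy arrays collected in your prediction loop), applying a softmax along the class dimension gives per-class probabilities that sum to 1:

import torch

logits = torch.from_numpy(test_preds[0])  # shape (batch_size, 2), un-normalized logits
probs = torch.softmax(logits, dim=-1)     # each row now sums to 1
print(probs.sum(dim=-1))                  # tensor([1., 1., ...])
classes = probs.argmax(dim=-1)            # predicted class index per sample

Note that the argmax is unchanged by the softmax, so this only matters if you want probability-like scores rather than hard labels.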
However, if you are only predicting two classes then is this a binary classification problem? If so, you can have a single output (num_outputs = 1) (in which case a sigmoid activation is automatically applied by the model) and train the model using nn.BCELoss. You can then interpret the outputs as the probability of the positive class.
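A minimal sketch of that binary setup, assuming the same batch variables (x, mask, dw, latlons, y, month) as in your training loop; I'm also assuming the head returns shape (batch_size, 1), hence the squeeze, and BCELoss needs float targets:

finetuning_model = pretrained_model.construct_finetuning_model(
    num_outputs=1,
    regression=False,
)
loss_fn = nn.BCELoss()  # a sigmoid is already applied inside the model

# inside the training loop:
preds = finetuning_model(
    x, dynamic_world=dw, mask=mask, latlons=latlons, month=month
)
loss = loss_fn(preds.squeeze(-1), y.to(device).float())  # targets in {0., 1.}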