
Expected all tensors to be on the same device #341

Answered by SkafteNicki
jaffe-fly asked this question in Q&A

Hi @jaffe-fly,
Modular metrics are `nn.Module` instances and therefore need to be moved to the same device as their inputs. This happens automatically if you define them in the `__init__` method of your model. You can read more here: https://torchmetrics.readthedocs.io/en/latest/pages/overview.html#metrics-and-devices
However, a possibly easier fix in your case is to use the functional version of the F1 metric, which is a plain function with no internal state to move:
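To illustrate the "define it in `__init__`" point, here is a minimal pure-PyTorch sketch. `RunningAccuracy` and `Classifier` are hypothetical names, and `RunningAccuracy` is only a stand-in for a torchmetrics modular metric: it keeps its state in registered buffers, whereas the real torchmetrics classes use `add_state`. The mechanism is the same, though: because the metric is assigned as an attribute in `__init__`, it becomes a submodule, and `model.to(device)` moves its state along with the model weights.

```python
import torch
from torch import nn


class RunningAccuracy(nn.Module):
    """Hypothetical stand-in for a modular metric: its state lives in buffers."""

    def __init__(self):
        super().__init__()
        # Buffers are moved by .to()/.cuda() together with the parent module.
        self.register_buffer("correct", torch.tensor(0))
        self.register_buffer("total", torch.tensor(0))

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # In-place updates keep the state on the device the buffers already live on.
        self.correct += (preds == target).sum()
        self.total += target.numel()

    def compute(self) -> torch.Tensor:
        return self.correct.float() / self.total


class Classifier(nn.Module):
    def __init__(self, in_features: int, num_classes: int):
        super().__init__()
        self.net = nn.Linear(in_features, num_classes)
        # Assigning the metric here registers it as a submodule,
        # so model.to(device) moves its buffers as well.
        self.val_acc = RunningAccuracy()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Classifier(in_features=4, num_classes=3).to(device)

x = torch.randn(8, 4, device=device)
y = torch.randint(0, 3, (8,), device=device)
preds = model(x).argmax(dim=1)
model.val_acc.update(preds, y)
print(model.val_acc.compute().item())
```

If the metric were instead created outside the model (or inside `validation_step` without being registered), moving the model would leave the metric state on the CPU, which is one way to end up with the "Expected all tensors to be on the same device" error.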

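```python
import torch
from torchmetrics.functional import accuracy, f1  # later releases name this f1_score

def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self.forward(x)
    loss = self.loss_fn(logits, y)
    preds = torch.argmax(logits, dim=1)
    acc = accuracy(preds, y)
    # Don't reuse the name `f1` for the result: assigning to it inside the function
    # would shadow the imported function and raise UnboundLocalError on this line.
    val_f1 = f1(preds, y, num_classes=self.args.num_classes)
```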
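Why the functional version sidesteps the device problem can be seen in a small sketch. `macro_f1` below is a hypothetical, simplified stand-in and not torchmetrics' implementation: it averages F1 over all `num_classes` classes (a class that never appears contributes 0) and skips the library's input validation. The point is that every intermediate tensor is derived from `preds` and `target`, so the result is computed on whatever device the inputs are on and there is no stored state to move.

```python
import torch
import torch.nn.functional as F


def macro_f1(preds: torch.Tensor, target: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical stateless macro-F1; all tensors are created from the inputs."""
    # One-hot encodings of shape [N, num_classes] give per-class counts by summing.
    p = F.one_hot(preds, num_classes)
    t = F.one_hot(target, num_classes)
    tp = (p * t).sum(dim=0).float()
    fp = (p * (1 - t)).sum(dim=0).float()
    fn = ((1 - p) * t).sum(dim=0).float()
    # F1 = 2TP / (2TP + FP + FN); clamp avoids 0/0 for classes absent from both tensors.
    per_class = 2 * tp / (2 * tp + fp + fn).clamp(min=1)
    return per_class.mean()


preds = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 1, 2])
print(macro_f1(preds, target, num_classes=3).item())  # ≈ 0.778
```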

Replies: 1 comment

Answer selected by Borda
Category: Q&A
Labels: help wanted (Extra attention is needed), working as intended
2 participants

This discussion was converted from issue #340 on July 01, 2021 09:15.