
How to maintain tensor device compatibility in a custom metric #1016

Answered by SkafteNicki
digital-idiot asked this question in Q&A

In this case I assume that the output of

precision(
    preds=preds,
    target=target,
    average='none',
    mdmc_average=self.mdmc_average,
    num_classes=self.num_classes,
    multiclass=self.multiclass,
    ignore_index=self.ignore_index
)

is the one on the CPU, while self._precision is on the CUDA device. torchmetrics.functional.precision should return a tensor on the GPU if the inputs preds and target are also on the GPU. Could you check that what you pass to

self.precision.update(...)

is on the GPU?
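To make the point above concrete, here is a minimal sketch in plain PyTorch (no torchmetrics) of why the mismatch appears and how to catch it early. `per_class_precision` is a hypothetical stand-in for `precision(..., average='none')`, and `PrecisionAccumulator` is a hypothetical metric holding a `_precision` state; neither is the torchmetrics API. Because every intermediate is built from `preds` and `target`, the result lands on whatever device the inputs are on, so CPU inputs combined with a CUDA state is exactly the failing case.

```python
import torch


def per_class_precision(preds: torch.Tensor, target: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical stand-in for precision(..., average='none').

    All intermediates are derived from preds/target, so the result lives on
    the same device as the inputs (CPU in -> CPU out, CUDA in -> CUDA out).
    """
    true_pos = torch.bincount(preds[preds == target], minlength=num_classes).float()
    predicted = torch.bincount(preds, minlength=num_classes).float()
    # clamp avoids division by zero for classes that were never predicted
    return true_pos / predicted.clamp(min=1)


class PrecisionAccumulator:
    """Hypothetical metric that keeps its running state on a fixed device."""

    def __init__(self, num_classes: int, device: str = "cpu") -> None:
        self.num_classes = num_classes
        self._precision = torch.zeros(num_classes, device=device)

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        state_device = self._precision.device
        # Check devices up front so the error names the offending input,
        # instead of surfacing later from the in-place addition below.
        for name, tensor in (("preds", preds), ("target", target)):
            if tensor.device != state_device:
                raise RuntimeError(
                    f"{name} is on {tensor.device} but the metric state is on {state_device}"
                )
        self._precision += per_class_precision(preds, target, self.num_classes)


# Usage: keep inputs and state on the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"
metric = PrecisionAccumulator(num_classes=3, device=device)
preds = torch.tensor([0, 1, 2, 2], device=device)
target = torch.tensor([0, 1, 1, 2], device=device)
metric.update(preds, target)
print(metric._precision)  # per-class precision: 1.0, 1.0, 0.5
```

In an actual torchmetrics `Metric`, states registered with `add_state` are moved along with the metric when you call `.to(device)`, so the usual fix is to move `preds` and `target` to that device rather than moving the state.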

Replies: 1 comment, 3 replies

3 replies
@digital-idiot
@SkafteNicki
@digital-idiot
Answer selected by digital-idiot
Category: Q&A
Labels: None yet
2 participants