How to maintain tensor device compatibility in a custom metric #1016
-
Below is an example implementation of a custom metric which inherits from `Metric`:

```python
import torch
from typing import Any, Optional

from torchmetrics import Metric
from torchmetrics.functional import precision


class ClassPrecision(Metric):
    # noinspection SpellCheckingInspection
    def __init__(
        self,
        num_classes: int,
        ignore_index: Optional[int] = None,
        multiclass: bool = True,
        mdmc_average: str = 'global',
        **kwargs: Any
    ):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.multiclass = multiclass
        self.mdmc_average = mdmc_average
        self.add_state(
            name="_count",
            default=torch.tensor(data=0.0, dtype=torch.get_default_dtype()),
            dist_reduce_fx="sum"
        )
        self.add_state(
            name='_precision',
            default=torch.zeros(
                self.num_classes,
                dtype=torch.get_default_dtype()
            ),
            dist_reduce_fx="sum"
        )

    # noinspection SpellCheckingInspection
    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        self._count += 1.0
        self._precision += precision(
            preds=preds,
            target=target,
            average='none',
            mdmc_average=self.mdmc_average,
            num_classes=self.num_classes,
            multiclass=self.multiclass,
            ignore_index=self.ignore_index
        )

    def compute(self) -> Any:
        return self._precision / self._count
```

I use this in my `LightningModule` subclass:

```python
class LightningWrapper(LightningModule):
    def __init__(self, ...):
        ...
        self.precision = ClassPrecision(...)

    def training_step(self, ...):
        ...
        self.precision.update(...)
```

However, I get:

```
self._precision += precision(
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

How can I make sure all the tensors in the metric end up on the same device?
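For context on why a state registered with `add_state` can end up on a different device than the inputs: a `torchmetrics.Metric` is an `nn.Module`, and registered states move with the module when it is moved between devices. The stand-in below is a minimal sketch (a plain `nn.Module` buffer, not the real `Metric` class) illustrating that behaviour:

```python
import torch
from torch import nn


class TinyMetric(nn.Module):
    """Minimal stand-in for a torchmetrics.Metric: a state registered as a
    buffer travels with the module when the module is moved between devices."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.register_buffer("_precision", torch.zeros(num_classes))


metric = TinyMetric(num_classes=4)
metric = metric.to("cpu")  # with a GPU available, .to("cuda") would move _precision too
print(metric._precision.device.type)  # -> cpu
```

This is why the metric should be assigned as an attribute of the `LightningModule`: Lightning then moves it (and its states) to the training device together with the model.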
Replies: 1 comment 3 replies
-
In this case I assume that the output of

```python
precision(
    preds=preds,
    target=target,
    average='none',
    mdmc_average=self.mdmc_average,
    num_classes=self.num_classes,
    multiclass=self.multiclass,
    ignore_index=self.ignore_index
)
```

is the one that is on `cpu` and `self._precision` is on the `cuda` device. `torchmetrics.functional.precision` should output a tensor on GPU if the inputs `preds` and `target` are also on GPU. Could you check that what you pass to `self.precision.update(...)` is on the GPU?
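If the inputs and the metric state do disagree, one hedged workaround (the helper name here is illustrative, not from this thread) is to move `preds` and `target` onto the state's device before the functional call, so arithmetic like `self._precision += ...` cannot mix devices:

```python
import torch


def to_state_device(state: torch.Tensor, *tensors: torch.Tensor):
    """Illustrative helper: move input tensors onto the device of a metric
    state so updates like `state += f(preds, target)` stay on one device."""
    return tuple(t.to(state.device) for t in tensors)


state = torch.zeros(4)                # stands in for self._precision (cpu here)
preds = torch.randn(8, 4)             # stand-ins for model outputs and labels
target = torch.randint(0, 4, (8,))
preds, target = to_state_device(state, preds, target)
assert preds.device == state.device and target.device == state.device
```

That said, the cleaner fix is usually to ensure the metric itself lives on the same device as the data, rather than moving tensors inside `update`.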