Correct way to implement custom metrics - Tensors keep filling up #390
-
I wrote a custom metric class:

```python
import torch
from torch import Tensor
import torchmetrics


class BrierScore(torchmetrics.Metric):
    def __init__(self, cfg, mode='logit'):
        super().__init__()
        self.cfg = cfg
        self.mode = mode
        if self.mode == 'logit':
            self.pred_transform = lambda x: torch.nn.functional.softmax(x, dim=1)
        else:
            self.pred_transform = lambda x: x
        self.add_state("preds", default=Tensor(), dist_reduce_fx="cat")
        self.add_state("target", default=Tensor(), dist_reduce_fx="cat")

    def update(self, preds: Tensor, target: Tensor) -> None:
        preds = self.pred_transform(preds.detach())
        self.preds = torch.cat([self.preds, preds])
        self.target = torch.cat([
            self.target,
            torch.nn.functional.one_hot(target, num_classes=self.cfg.DATA.NUM_CLASSES),
        ])

    def compute(self):
        score = (self.preds - self.target) ** 2
        score = score.sum() / score.shape[0]
        return score
```
Replies: 1 comment 2 replies
-
Hi,

First, those states are only emptied after a call to `.reset()` (which is automated within Lightning), not after every call to `compute`.

Second, I would suggest using lists for your state, since you are concatenating tensors, which can be a memory-allocating operation (and is not necessary at that time). So you could probably go with

`self.add_state("preds", default=[], dist_reduce_fx="cat")`

and `self.preds.append(preds)` instead.