How to calculate metrics across distributed processes that do not reduce? #549
Unanswered
jmerkow asked this question in CompVision
Replies: 1 comment

I am working on moving an autoencoder metric to TorchMetrics and I'm having some trouble figuring out how to approach it. The metric keeps track of the top-k and bottom-k reconstruction losses throughout validation, then writes them to an image during on_validation_epoch_end. Essentially, I calculate the loss for all items in the batch, concatenate those with the current min/max buffers, and use torch.topk. Here is a snippet of the basic operation I am trying to do.

I am struggling to figure out how to do this operation with DDP. Reading the documentation, it's not clear how to get at the gathered state so that topk would truly be computed across the entire batch from all processes. If someone can point me to an example that could help, it would be greatly appreciated.
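
A minimal single-process sketch of the per-batch operation described above (illustrative only; the buffer names, k, and tensor shapes are hypothetical, not the original snippet):

import torch

k = 8                                      # hypothetical: keep the 8 worst reconstructions
top_loss = torch.empty(0)                  # running buffer of kept losses, starts empty
top_images = torch.empty(0, 1, 28, 28)     # hypothetical image shape (C, H, W) = (1, 28, 28)

def keep_topk(batch_loss, batch_images):   # hypothetical helper, called once per validation batch
    global top_loss, top_images
    # merge the new per-sample losses with the current buffer ...
    all_loss = torch.cat([top_loss, batch_loss], dim=0)
    all_images = torch.cat([top_images, batch_images], dim=0)
    # ... and keep only the k largest (use largest=False for the bottom-k instead)
    top_loss, idx = torch.topk(all_loss, min(k, all_loss.shape[0]), largest=True)
    top_images = all_images[idx]

keep_topk(torch.rand(32), torch.rand(32, 1, 28, 28))  # e.g. a batch of 32 samples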

Hi @jmerkow, thanks for the question. You could do something like this:

import torch
from torchmetrics import Metric


class TopkLossImageMetric(Metric):
    def __init__(self, k, largest, img_size):
        super().__init__()
        self.k = k
        self.largest = largest
        # setting dist_reduce_fx="cat" will concatenate the result from the different
        # processes when .compute is called
        self.add_state("loss", torch.zeros(k), dist_reduce_fx="cat")
        self.add_state("images", torch.zeros(k, *img_size), dist_reduce_fx="cat")
        self.add_state("recon", torch.zeros(k), dist_reduce_fx="cat")  # I assume recon are scalars?

    def update(self, loss, images, recon):
        out = self._collect(self.loss, self.images, self.recon, loss, images, recon, self.k, self.largest)
        self.loss = out[0]
        self.images = out[1]
        self.recon = out[2]

    def compute(self):
        loss, idx = torch.topk(self.loss, min(self.k, self.loss.shape[0]), dim=0, largest=self.largest)
        return loss, self.images[idx.flatten()], self.recon[idx.flatten()]

    @staticmethod
    def _collect(prev_loss, prev_imgs, prev_recon, loss, images, recon, k, largest):
        # merge the incoming batch with the buffers kept so far and keep only the top-k
        loss = torch.cat((prev_loss, loss), dim=0)
        images = torch.cat((prev_imgs, images), dim=0)
        recon = torch.cat((prev_recon, recon), dim=0)
        loss, idx = torch.topk(loss, min(k, loss.shape[0]), dim=0, largest=largest)
        return loss, images[idx.flatten()], recon[idx.flatten()]

The metric has three states that we keep updating whenever update is called: within each process, _collect concatenates the incoming batch with the buffers kept so far and keeps only the k entries with the largest (or smallest) loss. Because every state uses dist_reduce_fx="cat", the states from all processes are concatenated when .compute is called, so the final torch.topk is taken over the candidates gathered from the entire run.
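
For reference, a quick single-process usage sketch of the metric above (values and shapes are made up); under DDP the concatenation across processes is handled for you when compute is called:

# Hypothetical smoke test of the metric above on a single process.
metric = TopkLossImageMetric(k=2, largest=True, img_size=(1, 28, 28))

for _ in range(3):                         # pretend we see three validation batches
    loss = torch.rand(4)                   # one reconstruction loss per sample
    images = torch.rand(4, 1, 28, 28)      # the inputs that produced those losses
    recon = torch.rand(4)                  # assuming a scalar recon value per sample
    metric.update(loss, images, recon)

top_loss, top_images, top_recon = metric.compute()
print(top_loss.shape, top_images.shape, top_recon.shape)
# torch.Size([2]) torch.Size([2, 1, 28, 28]) torch.Size([2])
metric.reset()                             # e.g. at the end of each validation epoch

In a LightningModule you would typically call update in validation_step and compute (plus reset) in on_validation_epoch_end, which is where the selected images could then be written out.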