Commit

Append val/test output to instance variable in EncDecSpeakerLabelModel (#7562)

* Append val/test output to the instance variable in EncDecSpeakerLabelModel

Signed-off-by: Abhishree <[email protected]>

* Handle test case in evaluation_step

Signed-off-by: Abhishree <[email protected]>

* Replace type with isinstance

Signed-off-by: Abhishree <[email protected]>

---------

Signed-off-by: Abhishree <[email protected]>
athitten committed Sep 29, 2023
1 parent 0bc7e5b commit eccf98f
Showing 2 changed files with 22 additions and 4 deletions.
12 changes: 9 additions & 3 deletions in nemo/collections/asr/models/enhancement_models.py

@@ -434,10 +434,16 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =
         # Log global step
         self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32), sync_dist=True)
 
-        if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
-            self.validation_step_outputs[dataloader_idx].append(output_dict)
+        if tag == 'val':
+            if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
+                self.validation_step_outputs[dataloader_idx].append(output_dict)
+            else:
+                self.validation_step_outputs.append(output_dict)
         else:
-            self.validation_step_outputs.append(output_dict)
+            if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
+                self.test_step_outputs[dataloader_idx].append(output_dict)
+            else:
+                self.test_step_outputs.append(output_dict)
         return output_dict
 
     @classmethod
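This hunk also carries the "Replace type with isinstance" item from the commit message. As a minimal standalone sketch (plain Python, not NeMo or Lightning code; DataloaderList and the loader names are made up for illustration), the difference is that an exact-type comparison rejects tuples and list subclasses, while isinstance against (list, tuple) accepts both:

# Illustrative only: why `isinstance(x, (list, tuple))` is the safer check.
# `DataloaderList` is a hypothetical list subclass standing in for whatever
# container a framework might hand back for `trainer.val_dataloaders`.

class DataloaderList(list):
    pass


loaders = DataloaderList(['dl_0', 'dl_1'])
print(type(loaders) == list)                 # False: exact-type check misses the subclass
print(isinstance(loaders, (list, tuple)))    # True: list subclasses are accepted

loaders_as_tuple = ('dl_0', 'dl_1')
print(type(loaders_as_tuple) == list)                # False: a tuple is not a list
print(isinstance(loaders_as_tuple, (list, tuple)))   # True: tuples are covered too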
14 changes: 13 additions & 1 deletion in nemo/collections/asr/models/label_models.py

@@ -373,13 +373,25 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =
         self._macro_accuracy.update(preds=logits, target=labels)
         stats = self._macro_accuracy._final_state()
 
-        return {
+        output = {
             f'{tag}_loss': loss_value,
             f'{tag}_correct_counts': correct_counts,
             f'{tag}_total_counts': total_counts,
             f'{tag}_acc_micro_top_k': acc_top_k,
             f'{tag}_acc_macro_stats': stats,
         }
+        if tag == 'val':
+            if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
+                self.validation_step_outputs[dataloader_idx].append(output)
+            else:
+                self.validation_step_outputs.append(output)
+        else:
+            if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
+                self.test_step_outputs[dataloader_idx].append(output)
+            else:
+                self.test_step_outputs.append(output)
+
+        return output
 
     def multi_evaluation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'):
         loss_mean = torch.stack([x[f'{tag}_loss'] for x in outputs]).mean()
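Both hunks implement the same pattern: each validation/test step appends its output dict to the module-level validation_step_outputs or test_step_outputs list (indexed by dataloader_idx when more than one dataloader is configured), and an epoch-end hook later reduces and clears those lists. A minimal standalone sketch of that pattern, assuming PyTorch Lightning 2.x hooks; TinyEvalModule is a hypothetical module for illustration, not NeMo's actual model code:

# Sketch of the accumulate-then-reduce pattern (assumes pytorch_lightning 2.x).
import torch
import pytorch_lightning as pl


class TinyEvalModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)
        # Instance-level buffers that the step hooks append to.
        self.validation_step_outputs = []
        self.test_step_outputs = []

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        output = {'val_loss': loss}
        self.validation_step_outputs.append(output)  # accumulate per-batch results
        return output

    def on_validation_epoch_end(self):
        # Reduce everything collected during the epoch, then free the buffer.
        loss_mean = torch.stack([o['val_loss'] for o in self.validation_step_outputs]).mean()
        self.log('val_loss', loss_mean, sync_dist=True)
        self.validation_step_outputs.clear()

    # test_step / on_test_epoch_end would mirror the two methods above,
    # appending to and clearing self.test_step_outputs instead.

With multiple validation dataloaders, the buffer would instead hold one list per dataloader and each step would append to validation_step_outputs[dataloader_idx], which is the branch the diffs guard with the isinstance(..., (list, tuple)) and len(...) > 1 check.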
