Skip to content

Commit

Permalink
Append val/test output to instance variable in EncDecSpeakerLabelModel (
Browse files Browse the repository at this point in the history
#7562) (#7573)

* Append val/test output to the instance variable in EncDecSpeakerLabelModel



* Handle test case in evaluation_step



* Replace type with isinstance



---------

Signed-off-by: Abhishree <[email protected]>
Co-authored-by: Abhishree Thittenamane <[email protected]>
  • Loading branch information
github-actions[bot] and athitten committed Sep 29, 2023
1 parent 6560470 commit c1be0e7
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
10 changes: 10 additions & 0 deletions nemo/collections/asr/models/enhancement_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,16 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =
# Log global step
self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32), sync_dist=True)

if tag == 'val':
if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(output_dict)
else:
self.validation_step_outputs.append(output_dict)
else:
if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(output_dict)
else:
self.test_step_outputs.append(output_dict)
return output_dict

@classmethod
Expand Down
14 changes: 13 additions & 1 deletion nemo/collections/asr/models/label_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,13 +373,25 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =
self._macro_accuracy.update(preds=logits, target=labels)
stats = self._macro_accuracy._final_state()

return {
output = {
f'{tag}_loss': loss_value,
f'{tag}_correct_counts': correct_counts,
f'{tag}_total_counts': total_counts,
f'{tag}_acc_micro_top_k': acc_top_k,
f'{tag}_acc_macro_stats': stats,
}
if tag == 'val':
if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(output)
else:
self.validation_step_outputs.append(output)
else:
if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(output)
else:
self.test_step_outputs.append(output)

return output

def multi_evaluation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'):
loss_mean = torch.stack([x[f'{tag}_loss'] for x in outputs]).mean()
Expand Down

0 comments on commit c1be0e7

Please sign in to comment.