diff --git a/CHANGELOG.md b/CHANGELOG.md index 85b0f336de..ff887cd417 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -98,6 +98,7 @@ to [Semantic Versioning]. Full commit history is available in the #### Fixed +- Fix logging of accuracy for cases with 1 sample per class in scANVI {pr}`2938`. - Disable adversarial classifier if training with a single batch. Previously this raised a None error {pr}`2914`. - {meth}`~scvi.model.SCVI.get_normalized_expression` fixed for Poisson distribution and diff --git a/src/scvi/train/_trainingplans.py b/src/scvi/train/_trainingplans.py index 2f111e5dc1..b6eea8ab7e 100644 --- a/src/scvi/train/_trainingplans.py +++ b/src/scvi/train/_trainingplans.py @@ -741,7 +741,7 @@ def compute_and_log_metrics( return classification_loss = loss_output.classification_loss - true_labels = loss_output.true_labels.squeeze() + true_labels = loss_output.true_labels.squeeze(-1) logits = loss_output.logits predicted_labels = torch.argmax(logits, dim=-1)