Skip to content

Commit

Permalink
Merge pull request #1100 from YosefLab/scanvi_pred
Browse files Browse the repository at this point in the history
fix scanvi predict typo
  • Loading branch information
adamgayoso authored Jul 15, 2021
2 parents ba22825 + 1d67591 commit f9a0559
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion docs/release_notes/v0.12.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Bug fixes
- Fix issue where anndata setup with a layer led to errors in :class:`~scvi.external.SOLO` (`#1098`_).
- Fix `adata` parameter of :func:`scvi.external.SOLO.from_scvi_model`, which previously did nothing (`#1078`_).
- Fix default `max_epochs` of :class:`~scvi.model.SCANVI` when initializing using pre-trained model of :class:`~scvi.model.SCVI` (`#1079`_).
- Fix bug in `predict()` function of :class:`~scvi.model.SCANVI`, which only occurred for soft predictions (`#1100`_).



Expand Down Expand Up @@ -62,5 +63,5 @@ Contributors
.. _`#1090`: https://github.com/YosefLab/scvi-tools/pull/1090
.. _`#1098`: https://github.com/YosefLab/scvi-tools/pull/1098
.. _`#1099`: https://github.com/YosefLab/scvi-tools/pull/1099

.. _`#1100`: https://github.com/YosefLab/scvi-tools/pull/1100

2 changes: 1 addition & 1 deletion scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def predict(
columns=self._label_mapping[:n_labels],
index=adata.obs_names[indices],
)
return y_pred
return pred

def train(
self,
Expand Down
4 changes: 3 additions & 1 deletion tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import anndata
import numpy as np
import pandas as pd
import pytest
from pytorch_lightning.callbacks import LearningRateMonitor
from scipy.sparse import csr_matrix
Expand Down Expand Up @@ -479,7 +480,8 @@ def test_scanvi(save_path):
predictions = model.predict(adata2, indices=[1, 2, 3])
assert len(predictions) == 3
model.predict()
model.predict(adata2, soft=True)
df = model.predict(adata2, soft=True)
assert isinstance(df, pd.DataFrame)
model.predict(adata2, soft=True, indices=[1, 2, 3])
model.get_normalized_expression(adata2)
model.differential_expression(groupby="labels", group1="label_1")
Expand Down

0 comments on commit f9a0559

Please sign in to comment.