JAXSCVI multi particles #1385
Conversation
Codecov Report

@@            Coverage Diff             @@
##           master    #1385      +/-   ##
==========================================
+ Coverage   90.96%   90.99%   +0.02%
==========================================
  Files         111      111
  Lines        8604     8614      +10
==========================================
+ Hits         7827     7838      +11
+ Misses        777      776       -1

Continue to review full report at Codecov.
scvi/model/_jaxscvi.py (Outdated)

-        out = self.bound_module(array_dict)
-        return out.qz.mean
+        out = self.bound_module(array_dict, n_samples=mc_samples)
+        if give_mean:
We should be careful when using variables within a jit-compiled function like this. A better practice here would be to include give_mean in the function signature and then mark it as a static argument if necessary to get it jit-compiled. The reason is that if give_mean is ever changed and _get_val is called with the same array_dict, it will return a cached value instead of modifying the output as appropriate.

This is just a recommendation for overall code style going forward with the jax code. Nothing forces you to do this in jax, so it's important we're aware of how jit affects the behavior. (It's not so important here, but we should still stick to the principle!)
agree
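The caching pitfall described in this thread can be sketched in a few lines. This is an illustrative example, not the actual scvi-tools code: the names `_get_val` and `get_latent` are hypothetical stand-ins, and `jnp.stack([x, x + 1.0])` merely simulates Monte Carlo samples of qz.

```python
# Hedged sketch of the jit caching pitfall; names are illustrative,
# not the actual scvi-tools API.
from functools import partial

import jax
import jax.numpy as jnp

# Anti-pattern: `give_mean` is captured from the enclosing scope, so its
# value is baked in the first time the function is traced.
give_mean = True


@jax.jit
def _get_val(x):
    if give_mean:
        return x.mean()
    return x


_ = _get_val(jnp.ones(3))      # traced and compiled with give_mean=True
give_mean = False
stale = _get_val(jnp.ones(3))  # cached trace still returns the mean (1.0)


# Preferred: pass give_mean explicitly and mark it static, so each value
# of the flag gets its own compiled trace.
@partial(jax.jit, static_argnames=("give_mean",))
def get_latent(x, give_mean):
    samples = jnp.stack([x, x + 1.0])  # stand-in for MC samples of qz
    if give_mean:  # fine at trace time: give_mean is a static argument
        return samples.mean(axis=0)
    return samples


mean_out = get_latent(jnp.ones(3), give_mean=True)   # shape (3,), all 1.5
samp_out = get_latent(jnp.ones(3), give_mean=False)  # shape (2, 3)
```

The `static_argnames` version retraces once per value of `give_mean`, which is cheap for a boolean flag and keeps the behavior consistent with the Python-level branch.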
scvi/module/_jaxvae.py (Outdated)

@@ -73,6 +73,9 @@ def __call__(self, z: jnp.ndarray, batch: jnp.ndarray, is_training: bool):
         )
         h = nn.relu(h)
         h = nn.Dropout(self.dropout_rate)(h, deterministic=not is_training)
+        if h.ndim == 3:
This is an interesting choice haha... I wonder if there is a more explicit way to bake this into the decoder.
@PierreBoyeau let's instead do the same thing I do at line 69, as this is something that can be done with jax but not pytorch efficiently
Co-authored-by: Pierre Boyeau <[email protected]>

* multi particles
* cleanup
* release note
* add neg bin property

Co-authored-by: adamgayoso <[email protected]>
No description provided.