
JAXSCVI multi particles #1385

Merged
merged 6 commits into master on Mar 7, 2022

Conversation

PierreBoyeau
Contributor

No description provided.

@codecov

codecov bot commented Mar 1, 2022

Codecov Report

Merging #1385 (e8c4267) into master (3c838fa) will increase coverage by 0.02%.
The diff coverage is 100.00%.


```diff
@@            Coverage Diff             @@
##           master    #1385      +/-   ##
==========================================
+ Coverage   90.96%   90.99%   +0.02%
==========================================
  Files         111      111
  Lines        8604     8614      +10
==========================================
+ Hits         7827     7838      +11
+ Misses        777      776       -1
```
| Impacted Files | Coverage Δ |
| --- | --- |
| scvi/distributions/_negative_binomial.py | 88.77% <100.00%> (+0.72%) ⬆️ |
| scvi/model/_jaxscvi.py | 93.43% <100.00%> (+0.19%) ⬆️ |
| scvi/module/_jaxvae.py | 100.00% <100.00%> (ø) |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 3c838fa...e8c4267.

@PierreBoyeau PierreBoyeau changed the title multi particles JAXSCVI multi particles Mar 1, 2022
```diff
-        out = self.bound_module(array_dict)
-        return out.qz.mean
+        out = self.bound_module(array_dict, n_samples=mc_samples)
+        if give_mean:
```
Contributor:

We should be careful when using variables within a jit-compiled function like this. A better practice here would be to include give_mean in the function signature and then mark it as a static argument if necessary to get it jit-compiled. The reason is that if give_mean is ever changed and _get_val is called with the same array_dict, it will return a cached value instead of modifying the output as appropriate.

This is just a recommendation for overall code style going forward with the jax code. Nothing in jax forces you to do this, so it's important we're aware of how jit affects the behavior. (It's not so important here, but we should still stick to the principle!)

Member:

agree

```diff
@@ -73,6 +73,9 @@ def __call__(self, z: jnp.ndarray, batch: jnp.ndarray, is_training: bool):
         )
         h = nn.relu(h)
         h = nn.Dropout(self.dropout_rate)(h, deterministic=not is_training)
+        if h.ndim == 3:
```
Contributor:

this is an interesting choice haha.. wonder if there is a more explicit way to bake this into the decoder

Member:

@PierreBoyeau let's instead do the same thing I do at line 69, as this is something that can be done with jax but not pytorch efficiently
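For context, the `ndim == 3` check arises because with multiple particles `z` gains a leading Monte Carlo sample dimension, `(n_samples, batch, latent)`, while covariates like the batch one-hot stay 2-D. One way to avoid the special case is to broadcast the covariate up to `z`'s leading dims before concatenating (a hedged sketch with hypothetical names, not the actual scvi-tools decoder):

```python
import jax.numpy as jnp

def concat_batch(z: jnp.ndarray, batch_onehot: jnp.ndarray) -> jnp.ndarray:
    """Concatenate a 2-D covariate onto z, whether z is
    (batch, latent) or (n_samples, batch, latent)."""
    # Broadcast the covariate to match z's leading dimensions,
    # then concatenate along the feature axis.
    target_shape = z.shape[:-1] + (batch_onehot.shape[-1],)
    b = jnp.broadcast_to(batch_onehot, target_shape)
    return jnp.concatenate([z, b], axis=-1)

z2 = jnp.zeros((8, 10))     # (batch, latent)
z3 = jnp.zeros((5, 8, 10))  # (n_samples, batch, latent)
oh = jnp.ones((8, 3))       # one-hot batch covariate
concat_batch(z2, oh)        # shape (8, 13)
concat_batch(z3, oh)        # shape (5, 8, 13)
```

This keeps a single code path for both the single- and multi-particle cases, since `jnp.broadcast_to` is a no-op view when the shapes already match.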



@adamgayoso adamgayoso enabled auto-merge (squash) March 7, 2022 18:04
@adamgayoso adamgayoso merged commit dac2589 into master Mar 7, 2022
meeseeksmachine pushed a commit to meeseeksmachine/scvi-tools that referenced this pull request Mar 7, 2022
adamgayoso pushed a commit that referenced this pull request Mar 7, 2022
@adamgayoso adamgayoso deleted the multi_particles branch March 8, 2022 02:12
nrclaudio pushed a commit to nrclaudio/scvi-tools-tune that referenced this pull request Jun 21, 2022
* multi particles

* cleanup

* release note

* add neg bin property

Co-authored-by: adamgayoso <[email protected]>