Skip to content
This repository was archived by the owner on Jul 14, 2024. It is now read-only.

Change scipy.special.logit into jnp.log #57

Merged
merged 1 commit into from
May 24, 2022
Merged

Change scipy.special.logit into jnp.log #57

merged 1 commit into from
May 24, 2022

Conversation

petergchang
Copy link
Contributor

IIUC, the logits argument for jax.random.categorical correspond to log probabilities instead of their logit values
(defined by log(p/(1-p))),
and so the usage of jax.random.categorical in hmm_forwards_filtering_backwards_sampling_jax and hmm_sample_jax should be modified to take in
jnp.log(.) instead of logit(.).

From the documentation for jax.random.catgegorical:

logits - Unnormalized log probabilities of the categorical distribution(s) to sample from, so that softmax(logits, axis) gives the corresponding probabilities.

IIUC, the `logits` argument for `jax.random.categorical` take in log probabilities instead of their logit values (defined by log(p/(1-p))),
and so the usage of `jax.random.categorical` in `hmm_forwards_filtering_backwards_sampling_jax` and `hmm_sample_jax` should be modified to take in
jnp.log(.) instead of logit(.).
@murphyk
Copy link
Member

murphyk commented May 24, 2022

Yes, you are right. thanks.

@murphyk murphyk merged commit 18c78a4 into probml:main May 24, 2022
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants