Hi,
I'm working on an application with normalizing flows that requires sampling constrained values for some dimensions of the flow.
I found the dimension verification in `distrax.Transformed` to be a bit unpredictable.
For example, a simple sigmoid transformation of a multivariate normal works well with the TFP JAX substrate but doesn't work with distrax:
```python
from jax import numpy as jnp
import haiku as hk
import distrax
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

prng_seq = hk.PRNGSequence(123)
event_shape = (3,)

# TFP (JAX substrate)
base_dist = tfd.MultivariateNormalDiag(
    loc=jnp.zeros(event_shape),
    scale_diag=jnp.ones(event_shape),
)
q_distr = tfd.TransformedDistribution(base_dist, tfb.Sigmoid())
q_distr.sample(seed=next(prng_seq))  # All good :)

# distrax
base_dist = distrax.MultivariateNormalDiag(
    loc=jnp.zeros(event_shape),
    scale_diag=jnp.ones(event_shape),
)
q_distr = distrax.Transformed(base_dist, distrax.Sigmoid())  # Doesn't work :(
# ValueError: Base distribution 'MultivariateNormalDiag' has event shape (3,),
# but bijector 'Sigmoid' expects events to have 0 dimensions.
# Perhaps use `distrax.Block` or `distrax.Independent`?

q_distr = distrax.Transformed(base_dist, tfb.Exp())  # Doesn't work either :(
```
Would be very nice to have this supported by distrax.
Thank you!!!
All best,
Chris
Thank you for your comment. This was an intentional design choice for distrax, to keep bijectors closer to their mathematical definition. Note that `distrax.Sigmoid` is a bijector that transforms a scalar value x into another scalar value y by applying y = sigmoid(x). Since the bijector acts on scalars, combining it with a base distribution whose events are 3-dimensional vectors raises an error.
To avoid the error, you need a bijector that transforms a vector x into a vector y by applying y_i = sigmoid(x_i) to each component i of the vector. To obtain such a bijector, as the error message suggests, you can use `distrax.Block()`. For example, this code snippet should work:
Thanks so much for the quick and insightful reply.
I agree: declaring the `Block` explicitly brings clarity in many cases. I'll modify my scripts as suggested.