-
I'm looking to learn correlation matrices (positive semi-definite), reparametrized as

L_flat = DenseLayer(ndim=3*3)(input)
L = L_flat.reshape(3, 3)
return L @ L.T

but this is wasteful, since it is sufficient to parametrize the triangular part of L in order to cover all possible correlation matrices. In our 3x3 case it's sufficient to have a 6-dimensional parametrization (n*(n+1)/2 entries for n = 3). What is an efficient way to fill the triangular part of L from a vector in JAX?
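A minimal sketch of that 6-parameter version, assuming the goal is a PSD matrix built from a dense lower-triangular factor (the helper name psd_from_vector is hypothetical):

import jax.numpy as jnp

def psd_from_vector(v):
    # v has shape (6,) = n*(n+1)/2 for n = 3
    L = jnp.zeros((3, 3)).at[jnp.tril_indices(3)].set(v)
    return L @ L.T  # positive semi-definite by construction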
-
This is what I've been using; it basically uses indexing as you said, but it's not too slow for my purposes:

import jax.numpy as jnp

def fill_lower_tri(v, dim, out_dtype=float):
    # number of entries strictly below the diagonal
    num_nonzero = dim * (dim - 1) // 2
    mask = jnp.tri(dim, dtype=bool, k=-1)
    mask_idx = jnp.nonzero(mask, size=num_nonzero)
    out = jnp.eye(dim, dtype=out_dtype).at[mask_idx].set(v)
    return out
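For example (a sketch; with dim = 3, v holds the three strictly-lower entries and the diagonal stays at ones):

import jax.numpy as jnp

print(fill_lower_tri(jnp.array([0.1, 0.2, 0.3]), 3))
# [[1.  0.  0. ]
#  [0.1 1.  0. ]
#  [0.2 0.3 1. ]]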
-
TensorFlow Probability (on JAX) has a bijector meant for this application (tfb.FillScaleTriL). It will convert a vector of unconstrained values into a PSD matrix (and it can also convert the PSD matrix back into the vector if you want). The implementation is based on the fill_triangular function, which concatenates a vector to the "tail" of itself, reshapes, then zeros out half of the matrix. Pasted from fill_triangular's docstring:

x = np.arange(15) + 1
xc = np.concatenate([x, x[5:][::-1]])
# ==> array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13,
# 12, 11, 10, 9, 8, 7, 6])
# (We add one to the arange result to disambiguate the zeros below the
# diagonal of our upper-triangular matrix from the first entry in `x`.)
# Now, reshape this into a matrix:
y = np.reshape(xc, [5, 5])
# ==> array([[ 1, 2, 3, 4, 5],
# [ 6, 7, 8, 9, 10],
# [11, 12, 13, 14, 15],
# [15, 14, 13, 12, 11],
# [10, 9, 8, 7, 6]])
# Finally, zero the elements below the diagonal:
y = np.triu(y, k=0)
# ==> array([[ 1, 2, 3, 4, 5],
# [ 0, 7, 8, 9, 10],
# [ 0, 0, 13, 14, 15],
# [ 0, 0, 0, 12, 11],
# [ 0, 0, 0, 0, 6]])
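A hedged usage sketch of the bijector (assuming the tensorflow_probability JAX substrate is installed):

import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

bijector = tfp.bijectors.FillScaleTriL()
v = jnp.arange(6.0)           # 6 = 3*(3+1)/2 unconstrained values
L = bijector.forward(v)       # (3, 3) lower-triangular with positive diagonal
v_back = bijector.inverse(L)  # maps the matrix back to the vector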
-
Another alternative:

from functools import partial
import jax
import jax.numpy as jnp
import numpy as np

@partial(jax.jit, static_argnames='dim')
def fill_lower_tri(v, dim):
    # We could use jax.ensure_compile_time_eval + jnp.tri to do mask indexing,
    # but best practice is to use numpy for static values
    # (jnp.tril_indices is just a wrapper around np.tril_indices anyway).
    idx = np.tril_indices(dim)
    return jnp.zeros((dim, dim), dtype=v.dtype).at[idx].set(v)

print(fill_lower_tri(jnp.arange(6), 3))
-
Thanks everyone for the suggestions. Since I can only accept one answer, I chose @sharadmv's, as that is the approach I didn't think of myself :)
-
In reply to Dishank Bansal, who asked: "How can I make the function handle batches? i.e. where v is of shape (N, ...) and N is the batch length."

Use vmap.
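A minimal sketch of the vmap approach, reusing fill_lower_tri from the jit-based reply above (the (4, 6) batch shape is just an illustration):

import jax
import jax.numpy as jnp

batch = jnp.arange(24.0).reshape(4, 6)                 # (N, 6) batch of vectors
out = jax.vmap(lambda v: fill_lower_tri(v, 3))(batch)  # (4, 3, 3) stack of matrices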