-
I'm having trouble converting a numpy code to Jax and I'm hoping the community can help me out. Here's the numpy code:

```python
import numpy as np

t, n = 10, 5
x = np.random.normal(size=(t, n))
r = np.random.uniform(size=(t, t))
z = np.zeros((t, t, n))
for t1 in range(t - 1):
    y = x[t1]
    z[t1, t1 + 1] = np.copy(y)
    for t2 in range(t1 + 1, t - 1):
        y = y + x[t2] * (r[t2, t1:t2] @ z[t1, t1:t2])
        z[t1, t2 + 1] = np.copy(y)
result = z.sum(axis=-2)
```

I've tried to use `jax.vmap`:

```python
@jax.vmap
def f(xi, t1):  # This code fails due to t1 going into range
    z = jnp.zeros(xi.shape)
    for t2 in range(t1+1, t-1):
        xi = xi + x[t2] * (r[t2, t1:t2] @ z[t1:t2])
        z[t1, t2 + 1] = xi
    return z

result = f(x, jnp.arange(t)).sum(axis=-2)
```

Any help would be greatly appreciated!
-
This is not an easy function to express in terms of JAX transformations or higher-order primitives. You can't use `vmap`, because your function is not a purely batch-wise function (the result at `z[t1]` depends on all `x[t2]` where `t2 > t1`).

Further, you won't be able to directly express this within `lax.fori_loop` or similar, because each iteration constructs dynamically-shaped intermediate arrays (if you're iterating over `t1` and `t2` using `fori_loop`, the size of `z[t1:t2]` will be dynamic).

That said, you could use this within JAX nearly as-written if you modify it to use JAX's functional array updates:

```python
import jax

@jax.jit
def f(x, r, z):
    t, n = x.shape
    assert r.shape == (t, t)
    assert z.shape == (t, t, n)
    for t1 in range(t - 1):
        y = x[t1]
        z = z.at[t1, t1 + 1].set(y)
        for t2 in range(t1 + 1, t - 1):
            y = y + x[t2] * (r[t2, t1:t2] @ z[t1, t1:t2])
            z = z.at[t1, t2 + 1].set(y)
    return z.sum(axis=-2)

result_jax = f(x, r, z)
print(result_jax)
```

As long as [...] If that compilation time for the flattened loops is problematic, your best bet is probably to use a nested [...]
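One way to get static shapes inside `lax.fori_loop` is to replace the slice product `r[t2, t1:t2] @ z[t1, t1:t2]` with a full-length product in which entries outside `[t1, t2)` are masked to zero. The sketch below is an assumption about how that could look, not something taken from the reply above; `f_masked`, `outer`, and `inner` are hypothetical names, and it has only been reasoned through for small inputs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


@jax.jit
def f_masked(x, r):
    t, n = x.shape
    idx = jnp.arange(t)

    def outer(t1, z):
        # Same as z[t1, t1 + 1] = x[t1] in the NumPy version.
        z = z.at[t1, t1 + 1].set(x[t1])

        def inner(t2, carry):
            y, z = carry
            # Static-shape stand-in for r[t2, t1:t2] @ z[t1, t1:t2]:
            # zero out the weights outside the half-open range [t1, t2).
            mask = (idx >= t1) & (idx < t2)
            w = jnp.where(mask, r[t2], 0.0)
            y = y + x[t2] * (w @ z[t1])
            z = z.at[t1, t2 + 1].set(y)
            return y, z

        # The lower bound is traced, so this lowers to a while loop;
        # it runs zero iterations when t1 + 1 >= t - 1.
        _, z = lax.fori_loop(t1 + 1, t - 1, inner, (x[t1], z))
        return z

    z0 = jnp.zeros((t, t, n), dtype=x.dtype)
    z = lax.fori_loop(0, t - 1, outer, z0)
    return z.sum(axis=-2)


# Small illustrative inputs (sizes chosen arbitrarily).
rng = np.random.default_rng(0)
x = rng.normal(size=(6, 3))
r = rng.uniform(size=(6, 6))
print(f_masked(x, r))
```

The trade-off is that every inner step does a full length-`t` product instead of a length-`t2 - t1` one, in exchange for a trace whose size does not grow with `t`. With JAX's default settings the computation runs in float32, so results will differ slightly from the float64 NumPy loop.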
-
I rewrote the program into one with a simplified loop structure. Given another rainy afternoon, I may finish the rewrite and come up with a fully vectorized version. For now, here it is:

```python
import numpy as np
import numpy.typing as npt


def f_np(x: npt.NDArray, r: npt.NDArray) -> npt.NDArray:
    t, n = x.shape
    assert r.shape == (t, t)
    z = np.zeros((t, t, n))
    # Seed every z[i, j] with j > i with x[i].
    I, J = np.triu_indices(t, 1)
    z[I, J] = x[I]
    # Visit pairs i < j <= t - 2 in row-major order and add each term
    # to all later columns at once.
    I, J = np.triu_indices(t - 1, 1)
    for i, j in zip(I, J):
        z[i, j + 1:] += x[j] * (r[j, :j] @ z[i, :j])
    return z.sum(axis=-2)
```
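A quick way to gain confidence in a rewrite like this is to compare it against the original double loop on random inputs. The sketch below does that; `f_np` is copied from the post above, while `f_loops` is a hypothetical name for the question's loop wrapped in a function, and the input sizes are arbitrary.

```python
import numpy as np


def f_loops(x, r):
    """The double loop from the question, wrapped in a function."""
    t, n = x.shape
    z = np.zeros((t, t, n))
    for t1 in range(t - 1):
        y = x[t1]
        z[t1, t1 + 1] = y
        for t2 in range(t1 + 1, t - 1):
            y = y + x[t2] * (r[t2, t1:t2] @ z[t1, t1:t2])
            z[t1, t2 + 1] = y
    return z.sum(axis=-2)


def f_np(x, r):
    """The simplified single-loop version."""
    t, n = x.shape
    z = np.zeros((t, t, n))
    I, J = np.triu_indices(t, 1)
    z[I, J] = x[I]
    I, J = np.triu_indices(t - 1, 1)
    for i, j in zip(I, J):
        z[i, j + 1:] += x[j] * (r[j, :j] @ z[i, :j])
    return z.sum(axis=-2)


rng = np.random.default_rng(0)
x = rng.normal(size=(10, 5))
r = rng.uniform(size=(10, 10))
# Raises AssertionError if the two versions disagree beyond float tolerance.
np.testing.assert_allclose(f_np(x, r), f_loops(x, r))
```

The two versions agree because `z[i, :i + 1]` stays zero, so the full prefix product `r[j, :j] @ z[i, :j]` picks up the same terms as the sliced `r[t2, t1:t2] @ z[t1, t1:t2]`.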