Transposition of ndarray with axes not known at compile time. #22839
Hi, I am looking for a function that transposes an ndarray when the axis permutation is not known at compile time. The documentation for … Thanks!
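To make the situation concrete, here is a minimal sketch. The function names `traced_transpose` and `static_transpose` are hypothetical. It assumes the permutation is passed as an argument to a jitted function. `jnp.transpose` fails when that argument is traced. Marking it static with `static_argnums` works only when the permutation is an ordinary hashable Python tuple, and it costs one compilation per distinct permutation.

```python
import functools

import jax
import jax.numpy as jnp


@jax.jit
def traced_transpose(x, perm):
  # Under jit, perm is traced, so jnp.transpose cannot read concrete axes.
  return jnp.transpose(x, perm)


@functools.partial(jax.jit, static_argnums=1)
def static_transpose(x, perm):
  # perm is static here: it must be hashable (e.g. a tuple of ints), and
  # each distinct value triggers a separate compilation.
  return jnp.transpose(x, perm)


x = jnp.arange(24).reshape((2, 3, 4))

try:
  traced_transpose(x, (2, 0, 1))
  traced_failed = False
except Exception:  # the exact error class may differ between JAX versions
  traced_failed = True

y = static_transpose(x, (2, 0, 1))  # y.shape == (4, 2, 3)
```

The static version does not help if the permutation is itself computed inside a traced function. That is the case the reply below addresses.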
This is an interesting question! First of all, keep in mind that because shapes must be known statically, a dynamic permutation is only possible for an array whose dimensions are all equal. Otherwise the shape of the output array, which is `x.shape` permuted by `perm`, could not be known statically.

Now, given a multi-dimensional square array (all dimensions equal), how could you do a dynamic permutation of axes? I don't think there's any really efficient way to do this: XLA's native transpose op requires the axis order to be known statically. But you could implement it via indexing; here's a quick attempt:

```python
import itertools

import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def dynamic_transpose(x, perm):
  x = jnp.asarray(x)
  # Only shape-preserving permutations work under jit: all dims must be equal.
  assert len(set(x.shape)) == 1, "all dimensions of x must be equal"
  perm = jnp.asarray(perm)
  assert perm.shape == (x.ndim,), "length of perm must equal number of dimensions of x"
  # Invert the permutation: inv[perm[k]] = k.
  perm = jnp.zeros_like(perm).at[perm].set(jnp.arange(len(perm)))
  # indices[k] holds the k-th coordinate of every output position.
  indices = jnp.stack(jnp.meshgrid(*(jnp.arange(n) for n in x.shape), indexing='ij'))
  # Gather: out[i] = x[i[inv[0]], ..., i[inv[ndim - 1]]].
  return x[tuple(indices[perm])]


# Check against jnp.transpose for every permutation of a 3x3x3 array.
x = jnp.arange(27).reshape((3, 3, 3))
for perm in itertools.permutations(range(x.ndim)):
  np.testing.assert_array_equal(jnp.transpose(x, perm), dynamic_transpose(x, perm))
```
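When the number of dimensions is small, another option works under the same equal-dimensions restriction. You can precompute one static transpose per permutation and select one at runtime with `jax.lax.switch`. This is a different technique from the indexing approach above. It is sketched here under the assumption that compiling `ndim!` branches is acceptable. The name `switch_transpose` is hypothetical.

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def switch_transpose(x, perm):
  x = jnp.asarray(x)
  # lax.switch needs every branch to return the same shape, so, as above,
  # all dimensions of x must be equal.
  assert len(set(x.shape)) == 1, "all dimensions of x must be equal"
  all_perms = list(itertools.permutations(range(x.ndim)))
  # Static table of shape (ndim!, ndim); row 0 is the identity permutation.
  table = jnp.array(all_perms)
  perm = jnp.asarray(perm)
  # Index of the table row equal to the traced permutation. If perm is not a
  # valid permutation, no row matches and argmax returns 0 (the identity).
  branch = jnp.argmax(jnp.all(table == perm, axis=1))
  # One branch per permutation, each using the static jnp.transpose.
  branches = [lambda y, p=p: jnp.transpose(y, p) for p in all_perms]
  return jax.lax.switch(branch, branches, x)


x = jnp.arange(27).reshape((3, 3, 3))
for perm in itertools.permutations(range(x.ndim)):
  np.testing.assert_array_equal(jnp.transpose(x, perm), switch_transpose(x, perm))
```

Each branch lowers to a native transpose rather than a gather. However, XLA compiles all `ndim!` branches (24 for a 4-D array, 120 for 5-D), so this approach is only reasonable for low-dimensional inputs. Which variant is faster will likely depend on the backend and the array size.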