Transposition of ndarray with axes not known at compile time. #22839

Answered by jakevdp
erick-xanadu asked this question in Q&A

This is an interesting question! First of all, keep in mind that because JAX requires array shapes to be known statically (at trace time), a dynamic permutation of axes is only possible when all dimensions of the array are equal. Otherwise the shape of the output array would depend on the runtime value of the permutation and could not be known statically.
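To see why the shape matters, here is a minimal sketch of the usual workaround when the permutation is available as a Python value: mark it as a static argument of `jax.jit`. The output shape `(x.shape[perm[0]], x.shape[perm[1]], ...)` then depends only on that static value. This assumes `perm` can be passed as a hashable tuple; the name `static_transpose` is hypothetical, and each distinct `perm` triggers a separate compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: perm is static (a hashable tuple), so XLA sees a
# fixed axis order. Every new perm value causes a recompilation.
@partial(jax.jit, static_argnums=1)
def static_transpose(x, perm):
    return jnp.transpose(x, perm)


x = jnp.ones((2, 3, 4))
# The output shape depends on perm: here (x.shape[2], x.shape[0], x.shape[1]).
print(static_transpose(x, (2, 0, 1)).shape)  # (4, 2, 3)
```

For a non-square array there is no way around this, since a traced `perm` would leave the output shape undetermined.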

Now, given a multi-dimensional square array (all axes the same length), how could you do a dynamic permutation of axes? I don't think there's any really efficient way to do this: XLA's native permutation (transpose) op requires the axis order to be known statically. But you could implement it via indexing; here's a quick attempt:

```python
import itertools
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def dynamic_transpose(x, perm):
  x = j…
```
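As a rough sketch of the indexing idea (not necessarily what the answer's own code does), assume `x` has shape `(n, n, ..., n)` and `perm` is a traced integer array of length `x.ndim`. The sketch builds the coordinate grid of the output, maps each output coordinate back to its source coordinate through the inverse permutation, and gathers. The name `gather_transpose` is hypothetical. It follows NumPy's convention `y = np.transpose(x, perm)`, where `y[i] = x[j]` with `j[perm[k]] = i[k]`.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def gather_transpose(x, perm):
    # Assumption: every axis has the same length n, so the output shape
    # (n, ..., n) is static no matter what value perm takes at runtime.
    if len(set(x.shape)) > 1:
        raise ValueError("all axes of x must have the same length")
    n, d = x.shape[0], x.ndim
    # idx[k] holds the k-th coordinate i_k of every output position.
    idx = jnp.indices((n,) * d)
    # Inverse permutation: inv[perm[k]] == k (a dynamic computation).
    inv = jnp.argsort(perm)
    # Source coordinate j_m = i_{inv[m]}, gathered with a traced index.
    src = idx[inv]
    # Advanced indexing with d coordinate arrays performs the final gather.
    return x[tuple(src[m] for m in range(d))]


x = jnp.arange(27).reshape(3, 3, 3)
y = gather_transpose(x, jnp.array([2, 0, 1]))
# Compare against NumPy's static transpose with the same permutation.
print(np.array_equal(y, np.transpose(np.asarray(x), (2, 0, 1))))  # True
```

Because `perm` is traced, changing its value does not trigger recompilation. The cost is a full gather over all `n**d` elements instead of a native transpose.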

Answer selected by erick-xanadu