Status of the JAX dynamic API #18476
sergei-mironov asked this question in General
-
Hi, I'm a bit confused by what you're asking here. One clarification: it looks like most of your examples fail because they try to create a dynamically-sized array before slicing or indexing it. Is that intentional? For example, this fails:
But that failure has nothing to do with indexing; if you perform the indexing on a statically-sized array, it's fine:
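A minimal sketch of that distinction, assuming default JAX settings (the experimental dynamic-shapes mode is not enabled) rather than the dynamic API itself: indexing a statically-sized array at a traced position compiles, while creating an array whose size is traced does not. `jax.lax.dynamic_slice` is included as the usual static-size alternative when only the start of a slice is known at runtime. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Traced index into a statically-sized array: the output shape (a scalar)
# does not depend on the value of `idx`, so this compiles under jit.
@jax.jit
def index_static(arr, idx):
    return arr[idx]

# Hypothetical counterpart: the array's *size* is a traced value, so its
# shape is unknown at trace time and default jit rejects it.
@jax.jit
def make_dynamic(n):
    return jnp.ones(n)[0]

# Fixed-size window at a traced start offset: a common static-shape
# alternative to slicing with a runtime bound.
@jax.jit
def window(arr, start):
    return jax.lax.dynamic_slice(arr, (start,), (3,))

result = index_static(jnp.arange(10.0), 3)
w = window(jnp.arange(10.0), 4)

try:
    make_dynamic(4)
    dynamic_failed = False
except Exception as err:  # the exact error type varies across JAX versions
    dynamic_failed = True
    print("dynamic size rejected:", type(err).__name__)

print(result, w)
```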
-
Hello!
Could you please share the status of, and your plans for, the JAX dynamic API covered by this test? I am particularly interested in array indexing operations.
To give some context, I am exploring options for dynamic array support in the Catalyst quantum compiler, where we use JAX as a frontend and a modified StableHLO lowering pipeline as a backend. I checked JAX 0.4.19 against some common indexing use cases; see the table below. Can you confirm that the approach is correct in general? Should we expect improvements in this area in the near future?
- `arr[42]`
- `arr[idx]`
- `arr[0,:,0]`
- `arr[0:42]`
- `arr[0:idx]`
- `jnp.ones([4,4])[0]`
a custom runtime
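As a hedged sketch, the same indexing use cases can be checked under plain `jax.make_jaxpr` tracing with default settings (the experimental dynamic-shapes mode is not enabled). The harness, the `(64, 8, 8)` input shape and the index value `5` are illustrative assumptions, not the setup behind the table. With default settings one would expect every case except `arr[0:idx]` to trace, since that slice's output length depends on a traced value.

```python
import jax
import jax.numpy as jnp

# Hypothetical harness: trace each use case on a statically-shaped input.
# make_jaxpr abstracts every positional argument, so `idx` becomes a tracer.
arr = jnp.zeros((64, 8, 8))
idx = 5

cases = {
    "arr[42]": lambda a, i: a[42],
    "arr[idx]": lambda a, i: a[i],
    "arr[0,:,0]": lambda a, i: a[0, :, 0],
    "arr[0:42]": lambda a, i: a[0:42],
    "arr[0:idx]": lambda a, i: a[0:i],  # output length depends on a tracer
    "jnp.ones([4,4])[0]": lambda a, i: jnp.ones([4, 4])[0],
}

status = {}
for name, fn in cases.items():
    try:
        jax.make_jaxpr(fn)(arr, idx)
        status[name] = "traced"
    except Exception as err:  # error type differs between JAX versions
        status[name] = f"failed: {type(err).__name__}"

for name, s in status.items():
    print(f"{name:22} {s}")
```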
All the use cases I tried are of the following form:
Where:
- `jaxpr_to_mlir` is our thin wrapper around `lower_jaxpr_to_fun`; I think it is still representative.
- `compile_and_run_mlir` runs our customized StableHLO lowering to LLVM. It might contain unrelated issues; I am showing it here for reference.
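`jaxpr_to_mlir` and `compile_and_run_mlir` are Catalyst helpers. A rough, hypothetical analogue of the same stages using only public JAX APIs might look like the sketch below: `jax.make_jaxpr` for the jaxpr, `jax.jit(...).lower(...).as_text()` for StableHLO text, and XLA's own compiler standing in for the customized StableHLO-to-LLVM lowering. `use_case` is an illustrative name for one table entry.

```python
import jax
import jax.numpy as jnp

def use_case(arr, idx):
    # One indexing use case: a traced index into a statically-shaped array.
    return arr[idx]

arr = jnp.arange(16.0).reshape(4, 4)
idx = 2

# Stage 1: the jaxpr (roughly what a jaxpr-to-MLIR wrapper would consume).
jaxpr = jax.make_jaxpr(use_case)(arr, idx)
print(jaxpr)

# Stage 2: StableHLO text from JAX's own lowering path.
stablehlo_text = jax.jit(use_case).lower(arr, idx).as_text()
print(stablehlo_text)

# Stage 3: compile and run with XLA (standing in for a custom backend).
out = jax.jit(use_case)(arr, idx)
print(out)
```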
JAX 0.4.19, commit 604dc24, with a few hacks.

Updates:
- Removed the `arr[:]` case from the table due to an error in the test case.