[JAX] Add a literal array type distinct from NumPy arrays. #2369
+7
−3
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
[JAX] Add a literal array type distinct from NumPy arrays.
When JAX captures a NumPy array or Python array during tracing, it canonicalizes it, which squashes 64-bit types to 32 bits if
enable_x64=False
.However, tracing is not something JAX does once: for example, we often do things such as call
eval_jaxpr
on a previously traced jaxpr to transform it. As part of work to enable x64 types in JAX more generally, we want to avoid repeated canonicalization: we want to canonicalize constants exactly once when they are first traced, and never if they are retraced perhaps in a different -x64 context.Further, we need a way to pass numpy arrays to
jit
anddevice_put
without having them be canonicalized.LiteralArray
plays that role, and it is key to, for example, making code likejnp.array(..., dtype=jnp.int64)
work even in a non-x64 mode.LiteralArray
has the following properties:custom_jvp
rule, for example).jax.Array
is capable of representingLiteralArray
.This change is not entirely safe and requires some updates to user code. Notably,
LiteralArray
types do sometimes leak out of JAX API boundaries, and in a few cases this required updates to code that uses constructs likeisinstance(x, np.ndarray)
.