Skip to content

Conversation

copybara-service[bot]
Copy link

[JAX] Add a literal array type distinct from NumPy arrays.

When JAX captures a NumPy array or Python array during tracing, it canonicalizes it, which squashes 64-bit types to 32 bits if enable_x64=False.

However, tracing is not something JAX does once: for example, we often do things such as call eval_jaxpr on a previously traced jaxpr to transform it. As part of work to enable x64 types in JAX more generally, we want to avoid repeated canonicalization: we want to canonicalize constants exactly once when they are first traced, and never if they are retraced perhaps in a different -x64 context.

Further, we need a way to pass numpy arrays to jit and device_put without having them be canonicalized. LiteralArray plays that role, and it is key to, for example, making code like jnp.array(..., dtype=jnp.int64) work even in a non-x64 mode.

LiteralArray has the following properties:

  • it wraps a NumPy array
  • it duck types as a NumPy array (necessary because users may see the type, e.g., in a custom_jvp rule, for example).
  • it has a weak_type attribute, so it is capable of representing the same set of array types a jax.Array is capable of representing
  • canonicalizing a Python scalar or NumPy array forms a LiteralArray. In the case of a Python scalar, we form a weak-typed LiteralArray.
  • LiteralArray itself is unchanged by canonicalization. Once we have chosen a type for a LiteralArray, it will not be changed by any later canonicalization.

This change is not entirely safe and requires some updates to user code. Notably, LiteralArray types do sometimes leak out of JAX API boundaries, and in a few cases this required updates to code that uses constructs like isinstance(x, np.ndarray).

@copybara-service copybara-service bot force-pushed the test_805425037 branch 6 times, most recently from 11ee642 to 8975c1e Compare September 11, 2025 18:34
When JAX captures a NumPy array or Python array during tracing, it canonicalizes it, which squashes 64-bit types to 32 bits if `enable_x64=False`.

However, tracing is not something JAX does once: for example, we often do things such as call `eval_jaxpr` on a previously traced jaxpr to transform it. As part of work to enable x64 types in JAX more generally, we want to avoid repeated canonicalization: we want to canonicalize constants exactly once when they are first traced, and never if they are retraced perhaps in a different -x64 context.

Further, we need a way to pass numpy arrays to `jit` and `device_put` without having them be canonicalized. `LiteralArray` plays that role, and it is key to, for example, making code like `jnp.array(..., dtype=jnp.int64)` work even in a non-x64 mode.

`LiteralArray` has the following properties:
* it wraps a NumPy array
* it duck types as a NumPy array (necessary because users may see the type, e.g., in a `custom_jvp` rule, for example).
* it has a weak_type attribute, so it is capable of representing the same set of array types a `jax.Array` is capable of representing
* canonicalizing a Python scalar or NumPy array forms a LiteralArray. In the case of a Python scalar, we form a weak-typed `LiteralArray`.
* LiteralArray itself is unchanged by canonicalization. Once we have chosen a type for a LiteralArray, it will not be changed by any later canonicalization.

This change is not entirely safe and requires some updates to user code. Notably, `LiteralArray` types do sometimes leak out of JAX API boundaries, and in a few cases this required updates to code that uses constructs like `isinstance(x, np.ndarray)`.

PiperOrigin-RevId: 805425037
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant