diff --git a/qiskit_dynamics/dispatch/backends/jax.py b/qiskit_dynamics/dispatch/backends/jax.py index 6e6d141e1..24a3eba09 100644 --- a/qiskit_dynamics/dispatch/backends/jax.py +++ b/qiskit_dynamics/dispatch/backends/jax.py @@ -24,6 +24,14 @@ JAX_TYPES = (DeviceArray, Tracer, JaxprTracer, JVPTracer) + try: + # This class was introduced in 0.4.0 + from jax import Array + + JAX_TYPES += (Array,) + except ImportError: + pass + try: # This class is not in older versions of Jax from jax.interpreters.partial_eval import DynamicJaxprTracer diff --git a/releasenotes/notes/jax-4-compatibility-8e3398e95f758dfe.yaml b/releasenotes/notes/jax-4-compatibility-8e3398e95f758dfe.yaml new file mode 100644 index 000000000..e71fbd43d --- /dev/null +++ b/releasenotes/notes/jax-4-compatibility-8e3398e95f758dfe.yaml @@ -0,0 +1,4 @@ +--- +upgrade: + - | + The ``jax.Array`` class has been added to the dispatcher for compatibility with JAX 0.4.0. \ No newline at end of file