How to efficiently loop over classes in JAX? #14435
What you're describing is an array-of-structs pattern. JAX is built for working with a struct-of-arrays pattern, so the best way to do what you're asking in JAX is to change your code to the latter. Here's a brief example; imagine you have a simple `Point` class:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Point(NamedTuple):
    x: jax.Array
    y: jax.Array
```

An array-of-structs approach to a collection of points would look like this:

```python
points = [Point(x, y) for x, y in zip(jnp.arange(10), jnp.ones(10))]
```

A struct-of-arrays approach to a collection of points would look like this:

```python
points = Point(jnp.arange(10), jnp.ones(10))
```

The latter can be used with JAX functions directly, e.g.:

```python
@jax.vmap
def f(point):
    return jnp.hypot(point.x, point.y)

print(f(points))
```
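For comparison, here is a minimal sketch that runs the same computation both ways, and also shows one way to convert an existing list of `Point` objects into the struct-of-arrays form. When `jax.tree_util.tree_map` is given several trees of the same structure, it passes the matching leaves from all of them to the function, so `jnp.stack` builds one array per field. The `Point` class is redefined so the block runs on its own; the names `points_list`, `loop_result`, and `vmap_result` are illustrative only.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Point(NamedTuple):
    x: jax.Array
    y: jax.Array


# Array-of-structs: a Python list holding one Point per point.
points_list = [Point(x, y) for x, y in zip(jnp.arange(10), jnp.ones(10))]

# Looping in Python dispatches one tiny computation per element.
loop_result = jnp.stack([jnp.hypot(p.x, p.y) for p in points_list])

# Convert to struct-of-arrays: tree_map receives every Point as a separate
# tree, so `*leaves` is all the x's (or all the y's), stacked into one array.
points = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *points_list)


@jax.vmap
def f(point):
    return jnp.hypot(point.x, point.y)


# One vectorized call over the leading axis of every field.
vmap_result = f(points)
print(jnp.allclose(loop_result, vmap_result))
```

The conversion is a one-time cost; after it, every JAX transform sees a single pytree with batched leaves instead of a Python list.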
The same computation on an array-of-structs pattern would be far less efficient, both in how it is expressed in code and in computation time. So, to answer your question: if you have a list of class objects that you want to use with JAX, you should change your code to instead use a single class whose attributes are JAX arrays (and make sure that class is a pytree so it works with JAX transforms; a `NamedTuple` such as `Point` is a pytree automatically).
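For a class that is not a `NamedTuple`, such as the "panel" class described in the question, one option is to register it with `jax.tree_util.register_pytree_node_class`. The sketch below assumes a hypothetical `Panel` with `corners` and `gamma` fields and a hypothetical per-panel function `collocation_point`; none of these names come from the question. The key point is that one `Panel` holds the data for all N panels along a leading axis, so `jax.vmap` can map over it.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Panel:
    """Hypothetical panel: four corner points and a vortex strength."""

    def __init__(self, corners, gamma):
        self.corners = corners  # (4, 3) for one panel, (N, 4, 3) when batched
        self.gamma = gamma      # () for one panel, (N,) when batched

    def tree_flatten(self):
        # The array fields are the pytree children; there is no static aux data.
        return (self.corners, self.gamma), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def collocation_point(panel):
    # Hypothetical per-panel quantity: the mean of the four corners.
    return jnp.mean(panel.corners, axis=0)


# Struct-of-arrays: ONE Panel holding the data of n panels, not a list of n Panels.
n = 5
panels = Panel(
    corners=jnp.arange(n * 4 * 3, dtype=jnp.float32).reshape(n, 4, 3),
    gamma=jnp.ones(n),
)

# vmap maps over the leading axis of every leaf (corners and gamma).
centroids = jax.vmap(collocation_point)(panels)  # shape (n, 3)
print(centroids.shape)
```

Keeping `__init__` free of validation or computation matters here: JAX may call `tree_unflatten` with placeholder leaves while applying transforms.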
Hi guys,
I am trying to implement an Unsteady Vortex Lattice Method solver in JAX.
First, I followed the Google Brax pytree rationale and have a class, i.e. "panel", where I store all relevant data as jnp.ndarray. But how can I efficiently loop over N instances of this pytree class, since I can't vmap over them?
Thanks in advance,
Eduardo