Hi. I'm running a jit-compiled function with pmap in a distributed fashion with different seeds. The compiled function returns a Python dict. It works great, but pmap returns one big dict with concatenated arrays inside instead of a list of individual dicts, where each dict corresponds to one of the parallel tasks. The structure of the returned dictionary is somewhat complicated: it contains multiple arrays and other nested dicts with arrays, so splitting it in post-processing is cumbersome and error-prone. Is there a way to make pmap output a list of dicts instead of one big dict with concatenated arrays? Alternatively, is there a way to make it keep arrays from parallel tasks in collections of some sort that I can index into (as opposed to concatenating them)? Thank you!
Replies: 1 comment 1 reply
JAX vectorization and parallelization transformations, including vmap, pmap, and shard_map, all work with a struct-of-arrays storage pattern rather than an array-of-structs pattern. This means that if you vmap or pmap over a dict, you get a dict of batched arrays, not a list of dicts. There is no way to make these transformations return an array of structs, but you could take the output and transform it as a post-processing step.

For example, imagine your pmap created this nested dict of arrays with a leading batch dimension of size 4; you could use jax.tree.transpose to convert it to a sequence of dicts:

    import jax
    import jax.numpy as jnp

    # Nested dict of arrays; every leaf has a leading batch dimension of size 4.
    result = {'a': jnp.arange(4), 'b': {'c': jnp.ones((4, 2)), 'd': jnp.zeros((4, 7))}}

    # Split each leaf into a list along the batch axis, then swap the list and dict levels.
    jax.tree.transpose(jax.tree.structure(result), None, jax.tree.map(list, result))

That said, you might think about whether this is what you really want: after all, any other JAX-transformed operation you do on this will also expect the data to be in its original struct-of-arrays form, so you may find this array-of-structs version harder to work with. If instead you want to extract one batch at a time, you can do so by indexing every leaf, for example jax.tree.map(lambda x: x[1], result) for the batch at index 1.
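As a minimal sketch of the post-processing step, the example below produces a batched nested dict and then splits it into one dict per seed by indexing every leaf. It makes several assumptions that are not in the thread: `run` is a hypothetical stand-in for the per-seed function, `unstack` is a hypothetical helper name, and `jax.vmap` is used in place of `pmap` so the snippet runs on a single device (the output has the same struct-of-arrays layout either way).

```python
import jax
import jax.numpy as jnp


def run(seed):
    # Hypothetical stand-in for the per-seed function: returns a nested dict.
    return {'a': seed * 2, 'b': {'c': jnp.ones(2) * seed, 'd': jnp.zeros(7)}}


def unstack(tree):
    # Hypothetical helper. Assumes every leaf shares the same leading batch axis.
    n = jax.tree.leaves(tree)[0].shape[0]
    # Index each leaf at position i to build one dict per batch element.
    return [jax.tree.map(lambda x: x[i], tree) for i in range(n)]


seeds = jnp.arange(4)
batched = jax.vmap(run)(seeds)   # dict of arrays, each with leading axis 4
per_seed = unstack(batched)      # list of 4 dicts with the same nesting
```

Here `per_seed[1]` has the same nested structure as a single call to `run`, with the batch axis removed from every leaf. For large batches, slicing device arrays one element at a time may be slow; moving the result to the host first with `jax.device_get` is one option worth trying.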