Vectorization and Named Arguments in JAX #13090
Hello JAX community! While working on vectorization in JAX, I noticed something confusing that I'd like to ask about. I was planning to vectorize …
My aim was to make …
Hey Göktuğ!

Indeed, the function returned by `vmap` (here named `vmap_rotate`) has an often-surprising interaction with arguments passed as keywords. The root cause is that `in_axes` refers to parameters by position, and we don't attempt to match those positions to the parameters' names (originally because Python lacked a robust mechanism for doing so; that has improved somewhat in more recent Python versions, but it is still imperfect). So rather than being matched to any entry of `in_axes`, arguments passed by keyword are always mapped over their leading axis. (I'd like to revise that behavior, but haven't gotten to it, in part because I know it'll break exist…

Luckily there's an easy workaround: just write a wrapper that ensures the vmapped function is always called with positional arguments:

```python
import jax
import dm_pix as pix

# pix.rotate receives mode by keyword; in_axes=None leaves mode unmapped.
_vmap_rotate = jax.vmap(lambda image, angle, mode: pix.rotate(image, angle, mode=mode),
                        in_axes=(0, 0, None), out_axes=0)  # note leading underscore

def vmap_rotate(images, radians, mode):
  # Forward every argument positionally so in_axes applies to each one.
  return _vmap_rotate(images, radians, mode)
```

Here the …

WDYT?