Vectorization and Named Arguments in JAX #13090

Answered by mattjj
GoktugGuvercin asked this question in Q&A

Hey Göktuğ!

Indeed, the function returned by `vmap` (here named `vmap_rotate`) interacts with arguments passed as keywords in an often-surprising way. The root cause is that `in_axes` refers to parameters by position, and we don't attempt to match those positions to the parameters' names. Originally that was because Python had no robust mechanism for doing so; more recent Python versions have improved on this, but it's still imperfect. So instead of being matched to any entry of `in_axes`, arguments passed by keyword are always mapped over their leading axis. (I'd like to revise that behavior, but haven't gotten to it, in part because I know it'll break exist…
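The snippet below is a minimal sketch of the behavior described above, not code from the original question: `scale`, `xs`, and `factor` are hypothetical stand-ins for the question's rotate function and its inputs. It shows that the same `factor` gives different results depending on whether it is passed positionally (where `in_axes=(0, None)` broadcasts it) or by keyword (where it is mapped over axis 0). It then shows one possible way to sidestep the issue: binding the unbatched argument with `functools.partial` before calling `jax.vmap`.

```python
import functools

import jax
import jax.numpy as jnp


def scale(x, factor):
    # Hypothetical stand-in for the question's rotate function:
    # multiplies one vector elementwise by a weight vector.
    return x * factor


xs = jnp.arange(9.0).reshape(3, 3)        # a batch of three 3-vectors
factor = jnp.array([1.0, 10.0, 100.0])    # weights meant to be shared by every example

# Positional call: in_axes=(0, None) maps xs over axis 0 and broadcasts factor,
# so row i of the result is xs[i] * factor.
out_pos = jax.vmap(scale, in_axes=(0, None))(xs, factor)

# Keyword call: in_axes only describes positional arguments, so factor=...
# is mapped over its leading axis and example i silently gets the scalar factor[i].
out_kw = jax.vmap(scale, in_axes=(0,))(xs, factor=factor)

# Workaround: bind the unbatched argument before vmapping. vmap never sees
# factor as an input, so it is treated as a constant shared by all examples.
out_partial = jax.vmap(functools.partial(scale, factor=factor))(xs)

print(out_pos)      # rows are xs[i] * factor
print(out_kw)       # rows are xs[i] * factor[i]
print(out_partial)  # same as out_pos
```

Passing the shared argument positionally with an explicit `None` in `in_axes` works just as well; `functools.partial` is simply convenient when the call site already uses keywords.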

Answer selected by GoktugGuvercin