
lax.conv and vmap #4092

Answered by mattjj
kyunghyuncho asked this question in General
Aug 18, 2020 · 3 comments · 3 replies

I think the NumPy/SciPy convolution functions might not support channels (e.g., in how the singleton axes are added here).
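To illustrate the channel point: `jax.scipy.signal.convolve2d` accepts only 2-D arrays, so its signature has no channel (or batch) axis. The sketch below assumes a (C, H, W) layout and an arbitrary 3-channel input. One way to handle the channels is to `vmap` the 2-D convolution over the channel axis. This gives an independent (depthwise) convolution per channel, not the channel-mixing convolution that `lax.conv` performs.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

# Illustrative 3-channel "image" in (C, H, W) layout and a single 3x3 kernel.
image = jnp.arange(3 * 5 * 5, dtype=jnp.float32).reshape(3, 5, 5)
kernel = jnp.ones((3, 3), dtype=jnp.float32)

# convolve2d only accepts 2-D arrays, so map it over the leading channel axis.
# Each channel is convolved independently with the same kernel.
per_channel = jax.vmap(lambda x: convolve2d(x, kernel, mode="same"))(image)
print(per_channel.shape)  # (3, 5, 5)
```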

Actually this is an old wart in JAX, and one we've never gotten around to fixing: see #381 (and probably several duplicates)! The right thing to do is adjust our conv primitive not to require a batch dimension. We've never done it, to our great shame, and that is an impediment to the vmap dream of liberating ourselves from batch dimensions. (Tangentially, another impediment was batch norm, but vmap recently gained support for naming the mapped axis and applying collectives, so that's no longer an issue!)
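You can see the batch-dimension requirement directly. `lax.conv` takes `lhs` in (N, C, H, W) layout, so a function written for a single example has to add a singleton batch axis itself. The sketch below shows that workaround and checks it against a direct batched call. The shapes are arbitrary illustrative choices, and `conv_one` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp
from jax import lax

# lax.conv expects lhs of shape (N, C_in, H, W) and rhs of shape
# (C_out, C_in, kH, kW), so even a single example needs an explicit N axis.
def conv_one(x, w):
    # x: (C_in, H, W), one example with no batch axis.
    # Add a singleton batch axis, convolve, then strip it again.
    y = lax.conv(x[None], w, window_strides=(1, 1), padding="SAME")
    return y[0]  # (C_out, H, W)

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(key_x, (8, 3, 16, 16))  # batch of 8 examples
w = jax.random.normal(key_w, (4, 3, 3, 3))     # 4 output channels, 3x3 kernel

# vmap maps conv_one over the leading axis of xs; w is shared (in_axes=None).
ys_vmap = jax.vmap(conv_one, in_axes=(0, None))(xs, w)

# Reference: call lax.conv directly on the whole batch.
ys_batched = lax.conv(xs, w, window_strides=(1, 1), padding="SAME")

print(ys_vmap.shape)  # (8, 4, 16, 16)
print(jnp.allclose(ys_vmap, ys_batched, atol=1e-4))
```

A conv primitive that did not require a batch dimension would make the `x[None]` / `y[0]` round trip unnecessary.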

We should finally fix this conv batch issue once and for all. I'll mention it…
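The batch-norm remark above refers to `vmap`'s `axis_name` argument. It lets collectives such as `lax.pmean` reduce over the mapped axis, so the function itself never sees a batch dimension. The sketch below assumes a simplified batch norm with no learned scale/shift and no running statistics. `batch_norm_one` is a hypothetical helper written for a single example.

```python
import jax
import jax.numpy as jnp
from jax import lax

def batch_norm_one(x, eps=1e-5):
    # x: features for ONE example, shape (F,). No batch axis appears here.
    # Batch statistics come from collectives over the named mapped axis "batch".
    mean = lax.pmean(x, axis_name="batch")
    var = lax.pmean((x - mean) ** 2, axis_name="batch")
    return (x - mean) / jnp.sqrt(var + eps)

xs = jnp.arange(12.0).reshape(4, 3)  # 4 examples, 3 features

# Name the mapped axis so the collectives inside batch_norm_one can refer to it.
ys = jax.vmap(batch_norm_one, axis_name="batch")(xs)

# Reference: the same normalization written with an explicit batch axis.
ref = (xs - xs.mean(axis=0)) / jnp.sqrt(xs.var(axis=0) + 1e-5)
print(jnp.allclose(ys, ref, atol=1e-5))  # True
```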

Replies: 3 comments · 3 replies (replies from @georgedahl and @cwhy)

Answer selected by kyunghyuncho
Labels: none · 5 participants