You'll probably find that the second approach is more efficient, because it avoids mapped indexing operations over the arrays.
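Below is a minimal sketch of the two approaches the question describes. It assumes the 17 properties are stored along the last axis of a `(N0, N1, N2, N3, 17)` array. The property indices `MASS` and `VEL_X` and the functions `kinetic_from_index` and `kinetic_from_values` are hypothetical stand-ins, not the poster's code. In approach A the whole tensor is passed in and the function fetches its own data using batched indices, which vmap turns into a gather. In approach B the caller slices the properties out first, so the mapped function only sees the values it needs.

```python
import jax
import jax.numpy as jnp

# Hypothetical property indices along the last (17-long) axis.
MASS, VEL_X = 0, 1

def kinetic_from_index(state, i, j, k, l):
    # Approach A: receive the whole tensor and fetch data inside.
    # With i, j, k, l batched by vmap, this indexing becomes a gather.
    m = state[i, j, k, l, MASS]
    v = state[i, j, k, l, VEL_X]
    return 0.5 * m * v ** 2

def kinetic_from_values(m, v):
    # Approach B: the caller has already sliced out the properties.
    return 0.5 * m * v ** 2

state = jax.random.uniform(jax.random.PRNGKey(0), (2, 3, 4, 5, 17))
N0, N1, N2, N3, _ = state.shape
idx = (jnp.arange(N0), jnp.arange(N1), jnp.arange(N2), jnp.arange(N3))

# Approach A: map one index per level; the tensor itself is not mapped.
f_a = kinetic_from_index
f_a = jax.vmap(f_a, in_axes=(None, None, None, None, 0))  # over l
f_a = jax.vmap(f_a, in_axes=(None, None, None, 0, None))  # over k
f_a = jax.vmap(f_a, in_axes=(None, None, 0, None, None))  # over j
f_a = jax.vmap(f_a, in_axes=(None, 0, None, None, None))  # over i
out_a = f_a(state, *idx)

# Approach B: slice outside the function, then map over the values.
f_b = jax.vmap(jax.vmap(jax.vmap(jax.vmap(kinetic_from_values))))
out_b = f_b(state[..., MASS], state[..., VEL_X])

# Inspect the traced program of approach A to see the gather it contains.
jaxpr_a = str(jax.make_jaxpr(f_a)(state, *idx))
```

Both produce the same `(2, 3, 4, 5)` result. Printing `jax.make_jaxpr` for each version is a cheap way to check what your own functions lower to before benchmarking them.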
Sup Jax,
I have an interesting use case: I have a struct that contains a 4-dimensional tensor with millions of elements, in which the tensor holds 17 physical properties for each element.
A function is applied to each of those physical properties, or to a set of them. For example, `gravity()`, `computeVelocity()`, `computeStrain()`, etc.
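As a sketch of the setup described above, assume the properties sit along the last axis of a `(N0, N1, N2, N3, 17)` array and that a function such as `computeVelocity()` works on two of them. Here `compute_velocity`, `MASS` and `MOM_X` are hypothetical names chosen for illustration. Slicing one property out with `state[..., MASS]` yields a 4-D array, and for purely elementwise math the nested `vmap` gives the same result as plain `jnp` broadcasting.

```python
import jax
import jax.numpy as jnp

# Hypothetical property indices along the last (17-long) axis.
MASS, MOM_X = 0, 4

def compute_velocity(mass, momentum):
    # Hypothetical per-element function acting on a set of two properties.
    return momentum / mass

# minval keeps mass away from zero so the division is well defined.
state = jax.random.uniform(jax.random.PRNGKey(0), (2, 3, 4, 5, 17),
                           minval=1.0, maxval=2.0)

mass = state[..., MASS]       # shape (2, 3, 4, 5)
momentum = state[..., MOM_X]  # shape (2, 3, 4, 5)

# One vmap level per grid axis.
vel_vmap = jax.vmap(jax.vmap(jax.vmap(jax.vmap(compute_velocity))))(mass, momentum)

# Elementwise jnp operations already broadcast over all four axes.
vel_direct = compute_velocity(mass, momentum)
```

Nested `vmap` becomes necessary only when the per-element function does something that does not broadcast on its own, such as reductions or indexing over a per-element vector.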
In JAX, would it be more efficient to pass references to the entire axis of a tensor and fetch the data inside the functions, or to fetch each property outside the functions (at the call site)?
I use a 4-dimensional `vmap` for computation, like `vmap(vmap(vmap(vmap(fun))))(a, b, c)`.
To clarify:
or
on them?
How is this actually handled in JAX? Would there be any difference?
Thanks.