The function `foo` that I want to decorate with `shard_map` has untraceable control-flow, which I want to avoid refactoring. Since `foo` uses control-flow, I can't `jax.jit` a `shard_map`-decorated version of it. However, since the inner function `jit_batch_inference` is already jitted, I expect an unjitted `shard_map`-decorated `foo` to be fast, and yet it's 100x slower. How can I speed up `run_shard_map_inference` in the code below while still allowing the control-flow inside `shard_inference`?
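
The asker's code isn't reproduced in this excerpt, so here is a minimal sketch of the setup being described. The names `foo`, `jit_batch_inference`, `shard_inference`, and `run_shard_map_inference` come from the question; the mesh axis name `"data"`, the array shapes, the `CHUNK` constant, and the particular Python loop standing in for the untraceable control flow are assumptions for illustration only.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

CHUNK = 4  # hypothetical chunk size, not from the original question

@jax.jit
def jit_batch_inference(chunk):
    # Stand-in for the asker's jitted inner inference call.
    return jnp.tanh(chunk @ chunk.T).sum(axis=-1)

def foo(local_batch):
    # Stand-in for the asker's untraceable control flow: Python-level
    # looping and host-side bookkeeping around the jitted inner function.
    outs = []
    for i in range(local_batch.shape[0] // CHUNK):
        outs.append(jit_batch_inference(local_batch[i * CHUNK:(i + 1) * CHUNK]))
    return jnp.concatenate(outs)

# One mesh axis named "data"; the leading batch axis is split across devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# The shard_map-decorated foo. It cannot be wrapped in jax.jit because of
# foo's Python control flow, so it runs eagerly (unjitted).
shard_inference = shard_map(foo, mesh, in_specs=P("data"), out_specs=P("data"))

def run_shard_map_inference(batch):
    return shard_inference(batch)

if __name__ == "__main__":
    batch = jnp.ones((8 * jax.device_count(), 16))
    print(run_shard_map_inference(batch).shape)
```

Under this reading, the question is why the eager (unjitted) `shard_map` wrapper adds so much overhead relative to the already-compiled `jit_batch_inference`, and whether that overhead can be avoided without making `foo`'s control flow traceable.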