Zero addition not simplified by jit #23266
jeffreymepstein asked this question in Q&A
-
Hi - thanks for the question! The issue is that jaxprs are just a representation of the actual operations your program defines, not a representation of the optimizations done by the compiler. If you want to see the result of compilation, you can use the tools described in *Ahead-of-time lowering and compilation*. For example:

```python
>>> print(jax.jit(lambda x: 0+x).lower(1).compile().as_text())
HloModule jit__lambda_, entry_computation_layout={(s32[])->s32[]}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}

ENTRY %main.2 (Arg_0.1: s32[]) -> s32[] {
  %Arg_0.1 = s32[] parameter(0)
  ROOT %copy = s32[] copy(s32[] %Arg_0.1)
}
```

Here you can see that after compilation, the zero-addition is elided by the compiler, and the output of the JIT-compiled version is a simple copy of the input.
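A self-contained version of the comparison above, as a minimal sketch assuming a recent JAX release on the default CPU backend; the helper name `zero_plus` is hypothetical. It prints the jaxpr (which records the program as written) next to the compiled HLO for both an integer and a float input, so you can check whether the addition survives compilation in each case. The exact HLO text varies across JAX/XLA versions and backends.

```python
import jax


def zero_plus(x):
    # Hypothetical helper: the same function as the lambda in the question.
    return 0 + x


# The jaxpr reflects the traced program, so the `add` equation is still there.
jaxpr_text = str(jax.make_jaxpr(jax.jit(zero_plus))(1))
print(jaxpr_text)

# The compiled HLO reflects XLA's optimizations. Compile once for an int32
# example argument and once for a float32 one, keyed by the Python type name.
compiled = {}
for example_arg in (1, 1.0):
    key = type(example_arg).__name__
    compiled[key] = jax.jit(zero_plus).lower(example_arg).compile().as_text()
    print(f"--- {key} input ---")
    print(compiled[key])
```

Reading the two HLO dumps side by side shows how XLA treats the integer and floating-point cases on your installation, without relying on what the jaxpr shows.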
-
I'm wondering why `jax.make_jaxpr(jax.jit(lambda x: 0+x))(1)` produces a jaxpr that still contains `add 0 c`. It seems like that addition could just be dropped. Some googling suggests that C compilers avoid a similar simplification because `0. + x` evaluates to `+0.` rather than `x` when `x = -0.`.
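A quick check of the signed-zero point, as a minimal sketch using JAX's float32 arithmetic (variable names are illustrative). Under IEEE 754 round-to-nearest, `(+0.0) + (-0.0)` is `+0.0`, so `0. + x` is not bitwise identical to `x` when `x` is negative zero, even though the two values compare equal. Integers have no signed zero, so `0 + x == x` always holds for them.

```python
import jax.numpy as jnp

neg_zero = jnp.float32(-0.0)
result = jnp.float32(0.0) + neg_zero  # (+0.0) + (-0.0) rounds to +0.0

print(jnp.signbit(neg_zero))  # True: the input is negative zero
print(jnp.signbit(result))    # False: the sum is positive zero
print(result == neg_zero)     # True: +0.0 and -0.0 compare equal
```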
Is the same issue what's going on here? It seems like this shouldn't be an issue for integers. Is there some config setting that allows this simplification to be made for integers, floats, or both?

For context, the question I'm really interested in, which I think is related, is why the jit-compiled function `jax.jit(lambda x: jnp.eye(n) @ x)` is slower than `jax.jit(lambda x: x)`.
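A minimal sketch for investigating the follow-up question empirically, assuming a CPU or GPU backend and an arbitrary size `n = 512` (the question leaves `n` unspecified); `bench` is a hypothetical helper. It prints each function's compiled HLO, so you can see whether XLA kept a matrix multiply for `jnp.eye(n) @ x`, and times both after a warm-up call so compilation is excluded. `block_until_ready()` is needed because JAX dispatches work asynchronously.

```python
import timeit

import jax
import jax.numpy as jnp

n = 512  # hypothetical size; the question does not specify n

identity_matmul = jax.jit(lambda x: jnp.eye(n) @ x)
passthrough = jax.jit(lambda x: x)

x = jnp.ones((n, n), dtype=jnp.float32)

# Compiled HLO for each function: check whether a dot (or a backend-specific
# matmul call) remains in the first one.
hlo_matmul = identity_matmul.lower(x).compile().as_text()
hlo_passthrough = passthrough.lower(x).compile().as_text()
print(hlo_matmul)
print(hlo_passthrough)


def bench(fn, arg, number=20):
    """Average wall-clock seconds per call, excluding compilation."""
    fn(arg).block_until_ready()  # warm-up call triggers compilation
    total = timeit.timeit(lambda: fn(arg).block_until_ready(), number=number)
    return total / number


print(f"eye(n) @ x: {bench(identity_matmul, x):.2e} s/call")
print(f"x:          {bench(passthrough, x):.2e} s/call")
```

If the compiled HLO for the first function still contains a matrix multiply, XLA has not recognized the product as an identity, which would be consistent with the timing gap; exact instruction names depend on the backend.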