Hello! Judging from the StableHLO / `make_jaxpr()` output of:
it seems that JAX doesn't recognize that the two array arguments are identical. I'm guessing this means that things like load optimization and cache performance are sub-optimal. Is there a way to tell JAX that the arguments to a particular invocation of a given signature are identical?
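A minimal sketch that reproduces the observation. The original snippet did not survive, so `some_func` and `my_arr` here are hypothetical stand-ins. It passes the same array object twice and counts the inputs in both the jaxpr (via `jax.make_jaxpr`) and the lowered StableHLO text.

```python
import jax
import jax.numpy as jnp


def some_func(a, b):
    # Hypothetical stand-in for the function in the question.
    return a * b + a


my_arr = jnp.arange(4.0)

# Trace with the *same* array object passed twice.
closed_jaxpr = jax.make_jaxpr(some_func)(my_arr, my_arr)
n_jaxpr_inputs = len(closed_jaxpr.jaxpr.invars)

# Lower the jitted function and grab the StableHLO module as text.
hlo_text = jax.jit(some_func).lower(my_arr, my_arr).as_text()

print(closed_jaxpr)
print(hlo_text)
```

Both views should show two separate inputs, even though the caller passed one array object twice.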
Thanks for the question!

If you want a single argument, you should write
`jit(lambda x: some_func(x, x)).lower(my_arr).as_text()`

When you write `f.lower(*args)`, it's really a convenience layer over specifying a type at which you want to trace `f`. We don't look at the values of `args`, or their object ids; we effectively replace them with something like `args = [jax.ShapeDtypeStruct(x.shape, x.dtype) for x in args]`.

You could imagine an example-arguments-to-type policy which would automatically de-duplicate, by value or object id. But that would have its own surprises: what would the calling convention be? To pass a single argument (even though we called `lower` with two arguments)? Or to pass…

When things are ambiguous, we tend to prefer the simplest policies, requiring the user to be explicit. That is, in this case we require you to de-duplicate the inputs yourself. Luckily you can always write a wrapper which performs that logic automatically, if you prefer it!

What do you think?
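A sketch of the kind of wrapper the reply suggests. `dedup_jit` is a hypothetical helper, not a JAX API, and it assumes de-duplication by object id (one of the two policies mentioned). To keep the calling convention explicit, it returns the jitted function together with the unique arguments you should pass to it. `some_func` and `my_arr` are again hypothetical stand-ins.

```python
import jax
import jax.numpy as jnp

# Hypothetical cache: one jitted function per (f, argument-slot pattern),
# so repeated calls with the same aliasing pattern reuse the same trace.
_DEDUP_CACHE = {}


def _make_dedup_fn(f, slots):
    # slots[i] says which unique argument feeds position i of f.
    def inner(*unique_args):
        return f(*(unique_args[s] for s in slots))
    return inner


def dedup_jit(f, *args):
    """Hypothetical helper: collapse positional args that are the same object.

    Returns (jitted_fn, unique_args); call jitted_fn(*unique_args) or
    jitted_fn.lower(*unique_args).
    """
    unique_args, slots, seen = [], [], {}
    for a in args:
        if id(a) not in seen:
            seen[id(a)] = len(unique_args)
            unique_args.append(a)
        slots.append(seen[id(a)])
    key = (f, tuple(slots))
    if key not in _DEDUP_CACHE:
        _DEDUP_CACHE[key] = jax.jit(_make_dedup_fn(f, tuple(slots)))
    return _DEDUP_CACHE[key], unique_args


def some_func(a, b):
    # Hypothetical stand-in for the function in the question.
    return a * b + a


my_arr = jnp.arange(4.0)
fn, uniq = dedup_jit(some_func, my_arr, my_arr)

# Lowering sees a single parameter, at concrete or abstract arguments alike.
hlo_text = fn.lower(*uniq).as_text()
abstract_text = fn.lower(
    *[jax.ShapeDtypeStruct(x.shape, x.dtype) for x in uniq]
).as_text()
result = fn(*uniq)
```

Keying on object id keeps the check cheap and predictable; de-duplicating by value would mean comparing array contents on every call. Because the helper hands back `uniq`, the caller always knows exactly how many arguments the compiled function takes.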