Performance / optimization with identical array arguments #23388

Closed Answered by mattjj
JeffGreen asked this question in Q&A

Thanks for the question!

If you want the lowered computation to take a single argument, wrap the call so the function itself has one parameter, and lower it with one argument: jit(lambda x: some_func(x, x)).lower(my_arr).as_text().
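As a rough sketch of the difference, the snippet below lowers a two-parameter function and its one-parameter wrapper and counts the traced inputs of each. Here `some_func` and `my_arr` are hypothetical stand-ins for the names in the question, not the asker's actual code.

```python
import jax
import jax.numpy as jnp


def some_func(a, b):
    # Hypothetical stand-in for the function from the question.
    return a * b + a


my_arr = jnp.arange(4.0)

# Lowering with two arguments traces a two-parameter computation,
# even though both arguments are the same array object.
two_arg_lowered = jax.jit(some_func).lower(my_arr, my_arr)

# Wrapping in a one-parameter lambda makes the sharing explicit,
# so the traced computation has a single input.
one_arg_lowered = jax.jit(lambda x: some_func(x, x)).lower(my_arr)

# Count the traced inputs via the jaxpr of each version.
n_inputs_two = len(jax.make_jaxpr(some_func)(my_arr, my_arr).jaxpr.invars)
n_inputs_one = len(jax.make_jaxpr(lambda x: some_func(x, x))(my_arr).jaxpr.invars)

# Each compiled executable keeps the calling convention it was lowered with.
out_two = two_arg_lowered.compile()(my_arr, my_arr)
out_one = one_arg_lowered.compile()(my_arr)
```

Both versions compute the same values; only the number of parameters in the lowered module differs.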

When you write f.lower(*args), it's really a convenience layer over specifying a type at which you want to trace f. We don't look at the values of args, or their object ids; we effectively replace them with something like args = [jax.ShapeDtypeStruct(x.shape, x.dtype) for x in args].
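To illustrate that only shapes and dtypes matter, here is a minimal sketch that lowers from `jax.ShapeDtypeStruct` placeholders instead of concrete arrays. `some_func` and `my_arr` are again hypothetical stand-ins, and the snippet does not claim to mirror what JAX does internally.

```python
import jax
import jax.numpy as jnp


def some_func(a, b):
    # Hypothetical stand-in for the function from the question.
    return a * b + a


my_arr = jnp.arange(4.0)

# An abstract description of my_arr: shape and dtype only, no values.
spec = jax.ShapeDtypeStruct(my_arr.shape, my_arr.dtype)

# Lowering from concrete arrays and from abstract specs both work,
# because tracing only needs the argument types.
text_from_values = jax.jit(some_func).lower(my_arr, my_arr).as_text()
text_from_specs = jax.jit(some_func).lower(spec, spec).as_text()

# The executable built from specs accepts any arrays of matching
# shape and dtype, including two different arrays.
compiled = jax.jit(some_func).lower(spec, spec).compile()
other_arr = jnp.ones(4, dtype=my_arr.dtype)
result = compiled(my_arr, other_arr)
```

Since the specs carry no values or object identity, nothing in this call tells the tracer that the two arguments happen to be the same array.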

You could imagine an example-arguments-to-type policy which automatically de-duplicated, by value or object id. But that would have its own surprises: what would the calling convention be? To pass a single argument (even though we called lower with two arguments)? Or to pass…

Answer selected by JeffGreen