How to automatically construct the lowering of a function? #19395

Answered by jakevdp
cgel asked this question in General

If you want to automatically construct a lowering from a Python impl rule, you can use mlir.lower_fun:

from jax.interpreters import mlir
add_lowering = mlir.lower_fun(add_impl, multiple_results=False)
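
For context, here is a minimal end-to-end sketch of how such an automatically built lowering might be registered for a custom primitive. The primitive name my_add, the wrapper my_add, and the abstract-eval rule are illustrative assumptions, not part of the original question. The import fallback assumes newer JAX releases expose Primitive under jax.extend.core while older ones keep it in jax.core.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.interpreters import mlir

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive  # older JAX releases

# Hypothetical primitive, used only for illustration.
add_p = Primitive("my_add")


def add_impl(x, y):
    # Python impl rule written with ordinary jax.numpy operations.
    return jnp.add(x, y)


def add_abstract_eval(x, y):
    # Assumes x and y share shape and dtype, so the output
    # abstract value is the same as x's.
    return x


add_p.def_impl(add_impl)
add_p.def_abstract_eval(add_abstract_eval)

# Build the lowering rule automatically from the impl rule:
# lower_fun traces add_impl and lowers the resulting operations.
add_lowering = mlir.lower_fun(add_impl, multiple_results=False)
mlir.register_lowering(add_p, add_lowering)


def my_add(x, y):
    # Hypothetical user-facing wrapper around the primitive.
    return add_p.bind(x, y)


x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])
result = jax.jit(my_add)(x, y)  # uses the registered lowering
print(np.asarray(result))
```

Because the lowering is derived from add_impl, the impl rule and the compiled path cannot drift apart; the trade-off is that you only get whatever operations jax.numpy/jax.lax already provide.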

If you want to learn more about defining lowering rules manually, the best documentation available is the JAX source code itself. For example, here is how jax.lax defines general lowering rules for N-ary operations (in your case N=2, because there are two inputs): https://github.com/google/jax/blob/1bd22b0fe12b4259e691612d459193513a44737c/jax/_src/lax/lax.py#L1697-L1713
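
As a rough sketch of a hand-written rule in the same spirit: a lowering rule receives a lowering context plus one MLIR value per input, and returns a list of output values. This example assumes the internal module jax._src.lib.mlir.dialects.hlo (the one jax.lax itself imports its op builders from); it is not a public API and may change between releases. The primitive name my_add_manual is hypothetical. The linked jax.lax helper additionally broadcasts its operands to the output shape; this sketch omits that step, so both inputs must already share shape and dtype.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.interpreters import mlir

# Internal module: jax.lax imports its StableHLO op builders from here.
# Not a public API; the path may change between JAX releases.
from jax._src.lib.mlir.dialects import hlo

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive  # older JAX releases

# Hypothetical primitive, used only for illustration.
add_manual_p = Primitive("my_add_manual")

# Eager (non-jit) evaluation rule.
add_manual_p.def_impl(lambda x, y: jnp.add(x, y))

# Shape/dtype rule: assumes both inputs have the same shape and dtype.
add_manual_p.def_abstract_eval(lambda x, y: x)


def add_manual_lowering(ctx, x, y):
    # ctx carries the abstract input/output values (ctx.avals_in,
    # ctx.avals_out); x and y are MLIR values. Emit a single add op and
    # return the outputs as a list, one entry per result.
    # No broadcasting is done here, so operand shapes must already match.
    del ctx
    return [hlo.add(x, y)]


mlir.register_lowering(add_manual_p, add_manual_lowering)

x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])
result = jax.jit(lambda a, b: add_manual_p.bind(a, b))(x, y)
print(np.asarray(result))
```

Writing the rule by hand gives direct control over the emitted ops, at the cost of depending on internal JAX modules.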

Answer selected by cgel