I am creating my own … To make it more concrete, I would like to have an implementation of … for the following code:

```python
import jax
from jax import core
from jax.interpreters import mlir
import numpy as np

add_p = core.Primitive("add")

def add(x, y):
    return add_p.bind(x, y)

def add_impl(x, y):
    return x + y

add_p.def_impl(add_impl)

def add_abstract_eval(x, y):
    assert x.shape == y.shape
    assert x.dtype == y.dtype
    return core.ShapedArray(x.shape, x.dtype)

add_p.def_abstract_eval(add_abstract_eval)

def add_lowering(ctx, x, y):
    # Placeholder: `automatic_lowering_from_impl` is not a real JAX function;
    # it stands for the helper this question is asking for.
    return automatic_lowering_from_impl(add_impl, ctx, x, y)

mlir.register_lowering(add_p, add_lowering)
```

Also, I would really appreciate some help to better understand the process of lowering. I'm seeing many terms like XLA, MLIR, StableHLO, MHLO, etc., but I'm not even sure which type of object I should be trying to construct in the lowering function, and I didn't manage to find any documentation for this. Some more specific questions:
Replies: 1 comment
If you want to automatically construct a lowering from a Python impl rule, you can use `mlir.lower_fun`:

```python
add_lowering = mlir.lower_fun(add_impl, multiple_results=False)
```
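Putting the pieces together, here is a minimal runnable sketch of the question's `add` primitive with its lowering built by `mlir.lower_fun`. It assumes a JAX version where `Primitive` may live in `jax.extend.core` (the import falls back to `jax.core` for older versions); the inputs `x` and `y` are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.core import ShapedArray
from jax.interpreters import mlir

try:
    # Newer JAX versions expose Primitive here.
    from jax.extend.core import Primitive
except ImportError:
    # Older versions only have it in jax.core.
    from jax.core import Primitive

add_p = Primitive("add")

def add(x, y):
    return add_p.bind(x, y)

def add_impl(x, y):
    # Eager implementation, written with ordinary traceable JAX operations.
    return x + y

add_p.def_impl(add_impl)

def add_abstract_eval(x, y):
    # Shape/dtype rule used while tracing (e.g. under jax.jit).
    assert x.shape == y.shape
    assert x.dtype == y.dtype
    return ShapedArray(x.shape, x.dtype)

add_p.def_abstract_eval(add_abstract_eval)

# lower_fun traces add_impl on the primitive's input avals and lowers the
# resulting computation; multiple_results=False because add_impl returns one array.
add_lowering = mlir.lower_fun(add_impl, multiple_results=False)
mlir.register_lowering(add_p, add_lowering)

x = jnp.arange(3.0)
y = jnp.ones(3)
print(jax.jit(add)(x, y))  # [1. 2. 3.]
```

Because `lower_fun` re-traces `add_impl`, the impl must itself be traceable by JAX: `x + y` works on tracers, but a call that forces concrete values (such as a plain NumPy function) would not.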
If you want to know more about how to define lowering rules manually, the best documentation available is the JAX source code itself. For example, here's how `jax.lax` defines general lowering rules for N-ary operations (in your case, N=2 because there are two inputs): https://github.com/google/jax/blob/1bd22b0fe12b4259e691612d459193513a44737c/jax/_src/lax/lax.py#L1697-L1713
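For comparison, here is a minimal sketch of a hand-written lowering rule for the same primitive, in the spirit of the `lax` rules linked above but without their broadcasting logic. It assumes the private module path `jax._src.lib.mlir.dialects.hlo` (the StableHLO dialect bindings that JAX's own source imports as `hlo`), which is not a public API and may move between versions. A lowering rule receives a lowering context `ctx` plus one MLIR value per operand, and returns a sequence of MLIR values, one per output of the primitive.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.core import ShapedArray
from jax.interpreters import mlir
# Assumption: private path to the StableHLO Python bindings used inside JAX.
from jax._src.lib.mlir.dialects import hlo

try:
    from jax.extend.core import Primitive
except ImportError:
    from jax.core import Primitive

add_p = Primitive("add")

def add(x, y):
    return add_p.bind(x, y)

add_p.def_impl(lambda x, y: x + y)

def add_abstract_eval(x, y):
    assert x.shape == y.shape and x.dtype == y.dtype
    return ShapedArray(x.shape, x.dtype)

add_p.def_abstract_eval(add_abstract_eval)

def add_lowering(ctx, x, y):
    # x and y are MLIR values; ctx.avals_in holds their ShapedArrays.
    # Emit one element-wise add op and return its result as a one-item list.
    return [hlo.AddOp(x, y).result]

mlir.register_lowering(add_p, add_lowering)

x = jnp.arange(3.0)
y = jnp.ones(3)
print(jax.jit(add)(x, y))  # [1. 2. 3.]

# Print the lowered module to see the IR the rule contributed.
print(jax.jit(add).lower(x, y).as_text())
```

In recent JAX versions the printed module is MLIR in the StableHLO dialect, which is the kind of object a lowering rule builds pieces of; XLA then compiles that module.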