-
I would like to understand why binding primitive ops with First, I constructed my own primitive and defined an implementation rule that properly handles
This seems to occur because binding with My question is:
Replies: 4 comments 1 reply
-
Can you say more about what you're trying to accomplish when you were attempting to bind a primitive with a raw
-
As Jake says, by default JAX only supports symbolic zeros in the limited context of autodiff rules (which is a private API). Symbolic zeros are not supported in general programs. If you'd like to use them in general programs, then try Quax instead. It even includes a toy
-
Here's an example of how to get a symbolic zero from Qax:

import jax
import jax.numpy as jnp
import qax
from qax.symbols import SymbolicConstant

@qax.use_implicit_args
def f(x, y):
    return jnp.sin(x * y) + x

x = jnp.ones(3)
zero = SymbolicConstant(0, shape=(5, 3), dtype=jnp.float32)

print(f(x, zero))
# Output:
# [[1. 1. 1.]
# [1. 1. 1.]
# [1. 1. 1.]
# [1. 1. 1.]
# [1. 1. 1.]]
print(jax.make_jaxpr(f)(x, zero))
# Output:
# { lambda ; a:f32[3]. let
# _:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] a
# b:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] a
# c:f32[3] = squeeze[dimensions=(0,)] b
# d:f32[5,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(5, 3)] c
# in (d,) }

You can see from the generated jaxpr that
If you have a specific primitive,

@qax.primitive_handler(my_primitive)
def handler(primitive, x: SymbolicConstant):
    if x.value != 0:
        # If the value is non-zero, just create a dense array and run your primitive
        return primitive.bind(x.materialize())
    return my_special_logic_for_zero(x.shape)
-
I have my own primitive and I'm trying to make jvp and vjp work with it. Now I think I understand the intended way to do so. I am supposed to "intercept" ad.Zero when defining the jvp and also, if the jvp involves custom primitives, I need to make sure their transposition rule also handles ad.Zero objects. Basically, I need to write an implementation of jvp and transpose for every possible combination of ad.Zero on the input arguments. This is what I was doing wrong: I was using the arguments directly, which resulted in binding primitives with invalid JAX types.

@jakevdp I'm a bit confused about the "philosophy" of

Also, thanks for the pointer to Qax @patrick-kidger and @davisyoshida. It is a very interesting use of all the tracing machinery to perform these symbolic optimizations!
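
For anyone landing here later, the "intercept the zeros before using the tangents" pattern can be sketched with the public jax.custom_jvp API rather than a hand-written primitive. This is only an analogue of the private ad.Zero mechanism discussed above, not the same code path: with symbolic_zeros=True, the rule receives jax.custom_derivatives.SymbolicZero objects in place of tangents for unperturbed inputs. The function mul, the rule mul_jvp, and the seen list are hypothetical names made up for this illustration.

```python
import jax
import jax.numpy as jnp
from jax.custom_derivatives import SymbolicZero


@jax.custom_jvp
def mul(x, y):
    return x * y


# Records which tangents arrived as symbolic zeros (illustration only).
seen = []


def mul_jvp(primals, tangents):
    x, y = primals
    tx, ty = tangents
    seen.append((isinstance(tx, SymbolicZero), isinstance(ty, SymbolicZero)))

    out = mul(x, y)
    # Start from a concrete zero and add only the contributions whose
    # tangents are real arrays. A SymbolicZero is not a valid JAX type,
    # so it must never be passed to jnp functions or bound to a primitive.
    t_out = jnp.zeros_like(out)
    if not isinstance(tx, SymbolicZero):
        t_out = t_out + tx * y
    if not isinstance(ty, SymbolicZero):
        t_out = t_out + x * ty
    return out, t_out


# Opt in to receiving symbolic zeros for unperturbed inputs.
mul.defjvp(mul_jvp, symbolic_zeros=True)

# Differentiating with respect to x only leaves y unperturbed,
# so its tangent reaches the rule as a SymbolicZero.
g = jax.grad(mul, argnums=0)(2.0, 3.0)
print(g)  # 3.0
```

The branching on each tangent is the same "one case per combination of zeros" bookkeeping described above. For a real primitive, the rules registered through jax.interpreters.ad would need equivalent checks against ad.Zero, and since that API is private its details may differ between JAX versions.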