How can I write a pallas kernel with no pure out-arguments? #23272

Answered by Rifur13
axelfeldmann asked this question in Q&A

One possible solution is to initialize the output to C and then accumulate A @ B into it inside the kernel.
You can initialize the output using input/output aliasing, which lets an output reuse an input's buffer, so the output ref already holds that input's values when the kernel starts.
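To see what aliasing does in isolation, here is a minimal sketch. The kernel name `add_one_kernel` is hypothetical, and `interpret=True` is an assumption added only so the example runs on CPU. Input 0 is aliased to output 0, so `o_ref` starts out holding `x` and the kernel only has to update it.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def add_one_kernel(x_ref, o_ref):
  # x_ref (input 0) and o_ref (output 0) share one buffer because of
  # input_output_aliases={0: 0}, so o_ref already contains x here.
  del x_ref  # the initial values are read through o_ref instead
  o_ref[...] += 1.0


x = jnp.arange(8, dtype=jnp.float32)
expected = np.asarray(x) + 1.0

y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    input_output_aliases={0: 0},  # input index -> output index
    interpret=True,  # assumption: CPU run; drop this on TPU/GPU
)(x)
print(y)
```

The kernel never writes a "pure" output: every output it produces starts from an input's contents.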

It will look something like this:

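import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

A = jnp.ones((16, 16))
B = jnp.ones((16, 16))
C = jnp.ones((16, 16)) + 20

# Pallas passes one ref per input (A, B, C), followed by one per output.
# Input 2 (C) is aliased to output 0, so the third ref and out_ref share a
# buffer: out_ref starts out holding C, and the kernel only adds A @ B.
def kernel(a_ref, b_ref, _, out_ref):
  out_ref[...] += a_ref[...] @ b_ref[...]

out = pl.pallas_call(kernel,
    grid=(1, ),
    out_shape=jax.ShapeDtypeStruct(C.shape, C.dtype),
    name="matmul",
    input_output_aliases={2: 0},  # reuse input 2 (C) as output 0
)(A, B, C)
# out == C + A @ B. On CPU, pass interpret=True to pallas_call.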
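One way to wrap and check the kernel above is sketched below, assuming a hypothetical `matmul_acc` wrapper, CPU execution via `interpret=True`, and `jax.jit` with `donate_argnums=(2,)`. Donating C tells XLA it may reuse C's buffer for the output; without donation, XLA generally has to copy C first so the caller's array is preserved.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def matmul_acc_kernel(a_ref, b_ref, c_ref, out_ref):
  # c_ref and out_ref alias the same buffer (input 2 -> output 0),
  # so out_ref already holds C when the kernel starts.
  del c_ref  # unused: the initial value is read through out_ref
  out_ref[...] += a_ref[...] @ b_ref[...]


# donate_argnums=(2,) lets XLA reuse C's buffer for the aliased output.
@functools.partial(jax.jit, donate_argnums=(2,))
def matmul_acc(a, b, c):
  return pl.pallas_call(
      matmul_acc_kernel,
      grid=(1,),
      out_shape=jax.ShapeDtypeStruct(c.shape, c.dtype),
      input_output_aliases={2: 0},
      interpret=True,  # assumption: CPU run; drop this on TPU/GPU
  )(a, b, c)


A = jnp.ones((16, 16))
B = jnp.ones((16, 16))
C = jnp.ones((16, 16)) + 20
expected = np.asarray(C + A @ B)  # computed before C is donated

out = matmul_acc(A, B, C)
np.testing.assert_allclose(np.asarray(out), expected)
print(float(out[0, 0]))  # 37.0 (21 from C plus 16 from A @ B)
```

Once C has been donated it should not be used again, which is why the reference result is computed before the call.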

Answer selected by axelfeldmann