
How to write in-place custom ops compatible with torch.compile using pallas #8385

soodoshll opened this issue Nov 15, 2024 · 5 comments

@soodoshll

❓ Questions and Help

I'm trying to implement an in-place operator using Pallas and wrap it as a torch custom op. However, I found it difficult to make it work with torch.compile. More specifically, I'm unclear about how to set donation, input-output aliases, and the op schema. It seems that having an output aliased with an input leads to functionalization problems in the torch compiler.

Thanks!

My script is like this:

from typing import List, Callable
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import torch
import torch_xla
from torch_xla.experimental import custom_kernel
from functools import partial


def plus_one_kernel(x_ref, o_ref):
    # In-place style: o_ref aliases x_ref via input_output_aliases below.
    o_ref[:] = o_ref[:] + 1

@partial(jax.jit, donate_argnums=[0])
def plus_one_pallas(x: jax.Array):
    return pl.pallas_call(
        plus_one_kernel,
        grid=(1, 1),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        input_output_aliases={0: 0},  # reuse the input buffer for the output
    )(x)

@torch.library.custom_op("xla::plus_one_", mutates_args=("x", ))
def plus_one_(x: torch.Tensor) -> None:
    plus_one_pt = torch_xla.experimental.custom_kernel.make_kernel_from_pallas(
        plus_one_pallas, output_shape_dtype_fn = lambda x: [(x.shape, x.dtype)]
    )
    plus_one_pt(x)

def fn(x):
    torch.ops.xla.dynamo_set_buffer_donor_(x, True)
    return plus_one_(x)

fn = torch.compile(fn, backend="openxla")

x = torch.ones(4, dtype=torch.bfloat16, device='xla')

fn(x)
print(x)

However, the call does not seem to change the value of x.
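For comparison, here is a minimal sketch of the same kernel driven directly from JAX (illustrative only, assumes a TPU is attached); it can help isolate whether the problem is in the Pallas/donation setup or in the torch wrapper:

# Hypothetical standalone check, not part of the repro above.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from functools import partial

def plus_one_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1

@partial(jax.jit, donate_argnums=0)
def plus_one_pallas(x):
    return pl.pallas_call(
        plus_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        input_output_aliases={0: 0},  # output 0 reuses input 0's buffer
    )(x)

x = jnp.ones(4, dtype=jnp.bfloat16)
# The compiled HLO should carry the alias coming from donation/input_output_aliases.
print("input_output_alias" in plus_one_pallas.lower(x).compile().as_text())
y = plus_one_pallas(x)  # x's buffer is donated; use y from here on
print(y)                # expected: [2 2 2 2]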

@JackCaoG
Collaborator

You are right. When I talked to @bdhirsh, my takeaway was that it is better to make the custom op functional. I don't think functionalization will run inside a custom op.

If you want to enable buffer aliasing with torch.compile, you can use

torch.ops.xla.dynamo_set_buffer_donor_(input, True)

to mark the input tensor as a donor. For the non-torch.compile path, the lazy tensor tracer will try to determine automatically which buffers can be aliased with the outputs.
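For concreteness, here is a minimal sketch of that functional pattern (names like plus_one are illustrative; it assumes a plus_one_pallas Pallas callable like the one above, and the jax.jit donation wrapper is not needed on this path):

import torch
import torch_xla
from torch_xla.experimental import custom_kernel

# Functional custom op: no mutates_args, returns a fresh tensor.
@torch.library.custom_op("xla::plus_one", mutates_args=())
def plus_one(x: torch.Tensor) -> torch.Tensor:
    kernel = custom_kernel.make_kernel_from_pallas(
        plus_one_pallas, output_shape_dtype_fn=lambda x: [(x.shape, x.dtype)]
    )
    return kernel(x)

@plus_one.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # The fake impl returns a new fake tensor rather than the input.
    return torch.empty_like(x)

def fn(x):
    # Mark x's buffer as donatable so XLA may alias it with the output.
    torch.ops.xla.dynamo_set_buffer_donor_(x, True)
    return plus_one(x)

fn = torch.compile(fn, backend="openxla")

The caller then uses the returned tensor instead of relying on in-place mutation of x.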

@soodoshll
Author

soodoshll commented Nov 19, 2024

Thanks for your response!

However, we observed another issue: in the compiled graph of this plus-one operator there is a redundant copy node (shown in red below). It only appears when torch.compile is used.

[screenshot: profiler view of the compiled graph, with the redundant copy node highlighted in red]

I tried to locate where this copy comes from, and according to the dumped XLA IR it does not exist in the original graph before optimization; the compiler also successfully identifies the input-output alias. (A sketch of how the IR text can be dumped follows the code below.)

XLA IR
def forward(self, arg0_1):
    dynamo_set_buffer_donor_ = torch.ops.xla.dynamo_set_buffer_donor_.default(arg0_1, True);  arg0_1 = None
    plus_one_ = torch.ops.xla.plus_one_.default(dynamo_set_buffer_donor_);  dynamo_set_buffer_donor_ = None
    return plus_one_

Number of HLO Input: 1
Number of HLO Output: 1
Number of HLO Input can be aliased with Output: 1
XLA IR Text:
IR {
  %0 = s32[4]{0} xla::device_data(), xla_shape=s32[4]{0}
  %1 = (s32[4]{0}) xla::tpu_custom_call(%0), xla_shape=(s32[4]{0}), ROOT=0
}
code
from typing import List, Callable
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import torch
import torch_xla
from torch_xla.experimental import custom_kernel
from functools import partial
import torch_xla.debug.profiler as xp

def plus_one_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1

def plus_one_pallas(x: jax.Array):
    size = x.shape
    return pl.pallas_call(
        plus_one_kernel,
        grid=[1],
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x)

@torch.library.custom_op("xla::plus_one_", mutates_args=())
def plus_one_(x: torch.Tensor) -> torch.Tensor:
    plus_one_pt = torch_xla.experimental.custom_kernel.make_kernel_from_pallas(
        plus_one_pallas, output_shape_dtype_fn = lambda x: [(x.shape, x.dtype)]
    )
    return plus_one_pt(x)

@plus_one_.register_fake
def plus_one_fake(x: torch.Tensor) -> torch.Tensor:
    return x

def fn(x):
    torch.ops.xla.dynamo_set_buffer_donor_(x, True)
    ret = plus_one_(x)
    return ret

server = xp.start_server(9012)
profile_logdir = "./profile" 
xp.trace_detached('localhost:9012', profile_logdir)

fn = torch.compile(fn, backend="openxla")
x = torch.ones(4, dtype=torch.int32, device='xla')

ret = fn(x)
print(ret)
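For reference, the "XLA IR Text" above can be dumped with torch_xla's internal _XLAC helpers; a sketch, with the caveat that these are private APIs and the exact output format may differ:

import torch
import torch_xla

x = torch.ones(4, dtype=torch.int32, device='xla')
y = x + 1  # any pending lazy computation

print(torch_xla._XLAC._get_xla_tensors_text([y]))  # lazy IR, e.g. "IR { ... }"
print(torch_xla._XLAC._get_xla_tensors_hlo([y]))   # corresponding HLO text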

@JackCaoG
Collaborator

Let me take a look...

@JackCaoG
Collaborator

I also ran your code and saw

## BEGIN_GRAPH
HloModule IrToHlo.4, entry_computation_layout={(s32[4]{0})->(s32[4]{0})}

ENTRY %IrToHlo.4 (p0.1: s32[4]) -> (s32[4]) {
  %p0.1 = s32[4]{0} parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/workspaces/dk4/pytorch/xla/torch_xla/_dynamo/dynamo_bridge.py" source_line=742}
  %custom-call.2 = s32[4]{0} custom-call(s32[4]{0} %p0.1), custom_call_target="tpu_custom_call", operand_layout_constraints={s32[4]{0}}, metadata={op_type="xla__tpu_custom_call" op_name="xla__tpu_custom_call" source_file="/workspaces/dk4/pytorch/xla/torch_xla/experimental/custom_kernel.py" source_line=172}, backend_config={"custom_call_config": {"body": "TUzvUgFNTElSMjAuMC4wZ2l0AAEjCQEDBQcBAwkDDwsNDxETFRcD370PAbcHCxMLCw8LDw8LDwsLCzMLCxMPCwsbCw9DCw8LZQszCwsLCxMbCxsLGxsPDwsPEw8PCxMPDwsXDw8LFw8PCxMPDwsXDw8LFwsPDwsTCw8PCxcPCxcLEw8LEwsLBQWRYQcDWQEPDw8bEwcbFwLCBh8FGQMDCVUFGwUdFVtfBR8drwsdtQsFIREDAQUjBSUFJyMDAxEEAAAAAAAAAAUpDQ0DAwlXHVkLBSsFLQMFLS8DEwUvEQMNAw8zNQc3Oz0/FUEVA0NFRwUxAQO3DQthZmZpbmVfbWFwPChkMCkgLT4gKGQwKT4ABTMjAwMRAQAAAAAAAAAFNQU3BTkFOwEFSU0DBRdLGx0JGQMFF08bHQkfAwUHIQMZAwUHIQMfEQEBEQkBBT0dE10XDRsBFWFnHWNlBT8XDSMBFWlvHWttBUEXJzYCARVxdx1zdQVDFyeGAgEVeX8de30FRRcNOQEVgYcdg4UFRxcpRgUBFYmRHYuNBUkXj0YLAQVLFZObHZWXBU0XmUEBBU8VnaMdn6EFURcpxgQBHaWnBVMXqWILAQVVAwMJrREBBQVXAwOzuwVZBVsjdHB1LmRpbWVuc2lvbl9zZW1hbnRpY3M8YXJiaXRyYXJ5PgAjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAjYXJpdGgub3ZlcmZsb3c8bm9uZT4AAQICAQIEF7kDEQE5JwMRAQMFBwEFBQEFAwEDAQQyAgUBEQErBwMBDQURATEHAxUnBwEBBQEFAQMDJSMDCQkGJQMHBQMHAwMPqwMBCwYPAwcDCw0HD7EDBwUJDQMDESMDCQkGEQMHBQURDwQRBw8FEQcAAQURAVEHAwcPAwEBAwMBBQMBAwMBBQMBBwQBAwMFEQFTBwMHDwMBAQMDAQUDAQMDAQUDAQcEAQMDBgMBBQEAmgxdDR0LTRcbVQ1pCRcVHxshCx0LIyEjKS1riRkdGSUhSQ0dExsXIxkZFR8PDQkdEWJ1aWx0aW4Ac3RhYmxlX21vc2FpYwB0cHUAYXJpdGgAbW9kdWxlAGFyaXRoLmNvbnN0YW50AGZ1bmMuZnVuYwBmdW5jLnJldHVybgB2ZWN0b3IubG9hZAB2ZWN0b3IuYnJvYWRjYXN0AGFyaXRoLmFkZGkAdmVjdG9yLnN0b3JlAHN5bV9uYW1lAGZ1bmN0aW9uX3R5cGUAdmFsdWUAL3dvcmtzcGFjZXMvZGs0L3B5dG9yY2gveGxhL3Rlc3QucHkAcGx1c19vbmVfa2VybmVsAHRyYW5zZm9ybV9pbmRpY2VzAHRyYW5zZm9ybV8wAHdpbmRvd19ib3VuZHMAdHJhbnNmb3JtXzEAL3dvcmtzcGFjZXMvZGs0L3B5dG9yY2gveGxhL3RvcmNoX3hsYS9leHBlcmltZW50YWwvY3VzdG9tX2tlcm5lbC5weQAvd29ya3NwYWNlcy9kazQvcHl0b3JjaC90b3JjaC9fbGlicmFyeS9jdXN0b21fb3BzLnB5AHN0YWJsZV9tb3NhaWMudmVyc2lvbgBkaW1lbnNpb25fc2VtYW50aWNzAGl0ZXJhdGlvbl9ib3VuZHMAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAbWFpbgB3aW5kb3dfcGFyYW1zAC9nZXQAcGx1c19vbmVfcGFsbGFzAHRyYWNlX3BhbGxhcwB3cmFwcGVkX2tlcm5lbABwbHVzX29uZV8Ad3JhcHBlZF9mbgBfZm4AL3dvcmtzcGFjZXMvZGs0L3B5dG9yY2gvdG9yY2gvX2R5bmFtby9ldmFsX2ZyYW1lLnB5AGlubmVyAC93b3Jrc3BhY2VzL2RrNC9weXRvcmNoL3RvcmNoL19jb21waWxlLnB5AGJhY2tlbmRfaW1wbAByZWRpc3BhdGNoAC93b3Jrc3BhY2VzL2RrNC9weXRvcmNoL3RvcmNoL19vcHMucHkAL2FkZABvdmVyZmxvd0ZsYWdzAC9zd2FwAA==", "serialization_format": 1, "needs_layout_passes": true}}
  ROOT %tuple.3 = (s32[4]{0}) tuple(s32[4]{0} %custom-call.2)
}

This is running with dynamo; at least the HLO we hand to XLA does not contain the copy. Let me see at which stage this copy is added.

@JackCaoG
Collaborator

So I ran your code with

XLA_FLAGS=--xla_dump_to=/tmp/xla_dump

to collect HLOs at different stages, and saw:

root@t1v-n-a83b02ef-w-0:/tmp/xla_dump# ls | grep module_0003
module_0003.SyncTensorsGraph.4.after_codegen.txt
module_0003.SyncTensorsGraph.4.after_optimizations-buffer-assignment.txt
module_0003.SyncTensorsGraph.4.after_optimizations-memory-usage-report.txt
module_0003.SyncTensorsGraph.4.after_optimizations.txt
module_0003.SyncTensorsGraph.4.after_optimizations_after_buffer_assignment.txt
module_0003.SyncTensorsGraph.4.after_optimizations_before_buffer_assignment.txt
module_0003.SyncTensorsGraph.4.before_optimizations.txt
module_0003.SyncTensorsGraph.4.execution_options.txt
module_0003.SyncTensorsGraph.4.flagfile
module_0003.SyncTensorsGraph.4.hlo_module_config.txt
module_0003.SyncTensorsGraph.4.target_arguments.txt
module_0003.SyncTensorsGraph.4.tpu_comp_env.txt
module_0003.SyncTensorsGraph.4.transfer_stats.txt

In /tmp/xla_dump/module_0003.SyncTensorsGraph.4.before_optimizations.txt, we can see:

HloModule SyncTensorsGraph.4, buffer_donor={ (0, {}) }, entry_computation_layout={(s32[4]{0:T(128)})->(s32[4]{0:T(128)})}, replica_count=4

ENTRY SyncTensorsGraph.4 {
  p0.1 = s32[4]{0} parameter(0)
  custom-call.2 = s32[4]{0} custom-call(p0.1), custom_call_target="tpu_custom_call", operand_layout_constraints={s32[4]{0}}, backend_config={"custom_call_config": {"body": "TUzvUgFNTElSMjAuMC4wZ2l0AAEjCQEDBQcBAwkDDwsNDxETFRcD370PAbcHCxMLCw8LDw8LDwsLCzMLCxMPCwsbCw9DCw8LZQszCwsLCxMbCxsLGxsPDwsPEw8PCxMPDwsXDw8LFw8PCxMPDwsXDw8LFwsPDwsTCw8PCxcPCxcLEw8LEwsLBQWRYQcDWQEPDw8bEwcbFwLCBh8FGQMDCVUFGwUdFVtfBR8drwsdtQsFIREDAQUjBSUFJyMDAxEEAAAAAAAAAAUpDQ0DAwlXHVkLBSsFLQMFLS8DEwUvEQMNAw8zNQc3Oz0/FUEVA0NFRwUxAQO3DQthZmZpbmVfbWFwPChkMCkgLT4gKGQwKT4ABTMjAwMRAQAAAAAAAAAFNQU3BTkFOwEFSU0DBRdLGx0JGQMFF08bHQkfAwUHIQMZAwUHIQMfEQEBEQkBBT0dE10XDRsBFWFnHWNlBT8XDSMBFWlvHWttBUEXJzYCARVxdx1zdQVDFyeGAgEVeX8de30FRRcNOQEVgYcdg4UFRxcpRgUBFYmRHYuNBUkXj0YLAQVLFZObHZWXBU0XmUEBBU8VnaMdn6EFURcpxgQBHaWnBVMXqWILAQVVAwMJrREBBQVXAwOzuwVZBVsjdHB1LmRpbWVuc2lvbl9zZW1hbnRpY3M8YXJiaXRyYXJ5PgAjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAjYXJpdGgub3ZlcmZsb3c8bm9uZT4AAQICAQIEF7kDEQE5JwMRAQMFBwEFBQEFAwEDAQQyAgUBEQErBwMBDQURATEHAxUnBwEBBQEFAQMDJSMDCQkGJQMHBQMHAwMPqwMBCwYPAwcDCw0HD7EDBwUJDQMDESMDCQkGEQMHBQURDwQRBw8FEQcAAQURAVEHAwcPAwEBAwMBBQMBAwMBBQMBBwQBAwMFEQFTBwMHDwMBAQMDAQUDAQMDAQUDAQcEAQMDBgMBBQEAmgxdDR0LTRcbVQ1pCRcVHxshCx0LIyEjKS1riRkdGSUhSQ0dExsXIxkZFR8PDQkdEWJ1aWx0aW4Ac3RhYmxlX21vc2FpYwB0cHUAYXJpdGgAbW9kdWxlAGFyaXRoLmNvbnN0YW50AGZ1bmMuZnVuYwBmdW5jLnJldHVybgB2ZWN0b3IubG9hZAB2ZWN0b3IuYnJvYWRjYXN0AGFyaXRoLmFkZGkAdmVjdG9yLnN0b3JlAHN5bV9uYW1lAGZ1bmN0aW9uX3R5cGUAdmFsdWUAL3dvcmtzcGFjZXMvZGs0L3B5dG9yY2gveGxhL3Rlc3QucHkAcGx1c19vbmVfa2VybmVsAHRyYW5zZm9ybV9pbmRpY2VzAHRyYW5zZm9ybV8wAHdpbmRvd19ib3VuZHMAdHJhbnNmb3JtXzEAL3dvcmtzcGFjZXMvZGs0L3B5dG9yY2gveGxhL3RvcmNoX3hsYS9leHBlcmltZW50YWwvY3VzdG9tX2tlcm5lbC5weQAvd29ya3NwYWNlcy9kazQvcHl0b3JjaC90b3JjaC9fbGlicmFyeS9jdXN0b21fb3BzLnB5AHN0YWJsZV9tb3NhaWMudmVyc2lvbgBkaW1lbnNpb25fc2VtYW50aWNzAGl0ZXJhdGlvbl9ib3VuZHMAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAbWFpbgB3aW5kb3dfcGFyYW1zAC9nZXQAcGx1c19vbmVfcGFsbGFzAHRyYWNlX3BhbGxhcwB3cmFwcGVkX2tlcm5lbABwbHVzX29uZV8Ad3JhcHBlZF9mbgBfZm4AL3dvcmtzcGFjZXMvZGs0L3B5dG9yY2gvdG9yY2gvX2R5bmFtby9ldmFsX2ZyYW1lLnB5AGlubmVyAC93b3Jrc3BhY2VzL2RrNC9weXRvcmNoL3RvcmNoL19jb21waWxlLnB5AGJhY2tlbmRfaW1wbAByZWRpc3BhdGNoAC93b3Jrc3BhY2VzL2RrNC9weXRvcmNoL3RvcmNoL19vcHMucHkAL2FkZABvdmVyZmxvd0ZsYWdzAC9zd2FwAA==", "serialization_format": 1, "needs_layout_passes": true}}
  ROOT tuple.3 = (s32[4]{0}) tuple(custom-call.2)
}

There is no copy, and the buffer donor was set up correctly: buffer_donor={ (0, {}) }.

In module_0003.SyncTensorsGraph.4.after_optimizations.txt, I see:

HloModule SyncTensorsGraph.4, is_scheduled=true, input_output_alias={ {0}: (0, {}, may-alias) }, entry_computation_layout={(s32[4]{0:T(128)})->(s32[4]{0:T(128)})}, replica_count=4

ENTRY SyncTensorsGraph.4 {
  p0.1 = s32[4]{0:T(128)} parameter(0)
  copy.3 = s32[4]{0:T(128)S(3)} copy(p0.1)
  name = s32[4]{0:T(128)} custom-call(copy.3), custom_call_target="tpu_custom_call", operand_layout_constraints={s32[4]{0}}, backend_config={"custom_call_config": {"body": "TUzvUgFNTElSMjAuMC4wZ2l0AAEjCQEDBQcBAwkDDwsNDxETFRcD370PAbcHCxMLCw8LDw8LDwsLCzMLCxMPCwsbCw9DCw8LZQszCwsLCxMbCxsLGxsPDwsPEw8PCxMPDwsXDw8LFw8PCxMPDwsXDw8LFwsPDwsTCw8PCxcPCxcLEw8LEwsLBQWRYQcDWQEPDw8bEwcbFwLCBh8FGQMDCVUFGwUdFVtfBR8drwsdtQsFIREDAQUjBSUFJyMDAxEEAAAAAAAAAAUpDQ0DAwlXHVkLBSsFLQMFLS8DEwUvEQMNAw8zNQc3Oz0/FUEVA0NFRwUxAQO3DQthZmZpbmVfbWFwPChkMCkgLT4gKGQwKT4ABTMjAwMRAQAAAAAAAAAFNQU3BTkFOwEFSU0DBRdLGx0JGQMFF08bHQkfAwUHIQMZAwUHIQMfEQEBEQkBBT0dE10XDRsBFWFnHWNlBT8XDSMBFWlvHWttBUEXJzYCARVxdx1zdQVDFyeGAgEVeX8de30FRRcNOQEVgYcdg4UFRxcpRgUBFYmRHYuNBUkXj0YLAQVLFZObHZWXBU0XmUEBBU8VnaMdn6EFURcpxgQBHaWnBVMXqWILAQVVAwMJrREBBQVXAwOzuwVZBVsjdHB1LmRpbWVuc2lvbl9zZW1hbnRpY3M8YXJiaXRyYXJ5PgAjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAjYXJpdGgub3ZlcmZsb3c8bm9uZT4AAQICAQIEF7kDEQE5JwMRAQMFBwEFBQEFAwEDAQQyAgUBEQErBwMBDQURATEHAxUnBwEBBQEFAQMDJSMDCQkGJQMHBQMHAwMPqwMBCwYPAwcDCw0HD7EDBwUJDQMDESMDCQkGEQMHBQURDwQRBw8FEQcAAQURAVEHAwcPAwEBAwMBBQMBAwMBBQMBBwQBAwMFEQFTBwMHDwMBAQMDAQUDAQMDAQUDAQcEAQMDBgMBBQEAmgxdDR0LTRcbVQ1pCRcVHxshCx0LIyEjKS1riRkdGSUhSQ0dExsXIxkZFR8PDQkdEWJ1aWx0aW4Ac3RhYmxlX21vc2FpYwB0cHUAYXJpdGgAbW9kdWxlAGFyaXRoLmNvbnN0YW50AGZ1bmMuZnVuYwBmdW5jLnJldHVybgB2ZWN0b3IubG9hZAB2ZWN0b3IuYnJvYWRjYXN0AGFyaXRoLmFkZGkAdmVjdG9yLnN0b3JlAHN5bV9uYW1lAGZ1bmN0aW9uX3R5cGUAdmFsdWUAL3dvcmtzcGFjZXMvZGs0L3B5dG9yY2gveGxhL3Rlc3QucHkAcGx1c19vbmVfa2VybmVsAHRyYW5zZm9ybV9pbmRpY2VzAHRyYW5zZm9ybV8wAHdpbmRvd19ib3VuZHMAdHJhbnNmb3JtXzEAL3dvcmtzcGFjZXMvZGs0L3B5dG9yY2gveGxhL3RvcmNoX3hsYS9leHBlcmltZW50YWwvY3VzdG9tX2tlcm5lbC5weQAvd29ya3NwYWNlcy9kazQvcHl0b3JjaC90b3JjaC9fbGlicmFyeS9jdXN0b21fb3BzLnB5AHN0YWJsZV9tb3NhaWMudmVyc2lvbgBkaW1lbnNpb25fc2VtYW50aWNzAGl0ZXJhdGlvbl9ib3VuZHMAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAbWFpbgB3aW5kb3dfcGFyYW1zAC9nZXQAcGx1c19vbmVfcGFsbGFzAHRyYWNlX3BhbGxhcwB3cmFwcGVkX2tlcm5lbABwbHVzX29uZV8Ad3JhcHBlZF9mbgBfZm4AL3dvcmtzcGFjZXMvZGs0L3B5dG9yY2gvdG9yY2gvX2R5bmFtby9ldmFsX2ZyYW1lLnB5AGlubmVyAC93b3Jrc3BhY2VzL2RrNC9weXRvcmNoL3RvcmNoL19jb21waWxlLnB5AGJhY2tlbmRfaW1wbAByZWRpc3BhdGNoAC93b3Jrc3BhY2VzL2RrNC9weXRvcmNoL3RvcmNoL19vcHMucHkAL2FkZABvdmVyZmxvd0ZsYWdzAC9zd2FwAA==", "serialization_format": 1, "needs_layout_passes": true}}
  ROOT tuple = (s32[4]{0:T(128)}) tuple(name)
}

So this confirms that the copy was added by the XLA compiler. The only difference I spot is the S(3) annotation; I think this means it uses a stride of 3, and that is most likely why this copy is triggered.
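To reproduce this comparison from Python, the dump flag can also be set in the environment before torch_xla initializes (a sketch; the dump directory is illustrative):

import os
# Must be set before the XLA client starts up, i.e. before importing torch_xla.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_dump_to=/tmp/xla_dump"

import torch
import torch_xla

# ... run the torch.compile repro from above, then diff
# /tmp/xla_dump/module_*.before_optimizations.txt against
# /tmp/xla_dump/module_*.after_optimizations.txt to spot compiler-inserted copies.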
