
Metadata removed when rewrite rule is applied during torch.onnx.export #2637

@nickfraser

Context

Before explaining my issue, please understand (and forgive) the following caveats:

  • I'm not sure if my issue is expected behaviour or not. If it is expected behaviour please feel free to close the issue; and
  • I'm not interacting with onnxscript directly, but I've traced my issue to a call into onnxscript, so I believe the issue belongs here. Please feel free to correct me if I'm wrong on this.

For further context, please see this issue.

The Issue

With that out of the way, let me explain the issue. It appears that onnxscript does not preserve any metadata when it applies rewrite rules to mergeable operations (in my example, consecutive transpose operations).

Consider the following code:

import torch

class DoubleTranspose(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, k):
        k = k.transpose(1,2)
        k = k.transpose(-2, -1) # Two transposes get merged and have no metadata
        return 2*k # Some other op which does have metadata

k = torch.rand(1,128,12,32)
model = DoubleTranspose()
output = model(k) # Make sure the forward path runs

with torch.no_grad():
    torch.onnx.export(
        model,
        (k,),  # example inputs as a one-element tuple; `(k)` alone is just `k`
        "minimal.onnx",
        input_names=['k'],
        opset_version=18,
        dynamo=True,
        optimize=True, # set `optimize=False` to prevent the transpose nodes from being merged
    )
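The merge itself is sound: the two transposes in `forward` compose into a single permutation, so one Transpose node can replace them. The following is a minimal pure-Python sketch of the permutation that the fused node has to carry. It uses ONNX Transpose semantics, where output axis i is input axis perm[i]. The helper names `compose_transposes` and `apply_perm` are hypothetical and illustrative only; they are not onnxscript APIs, and this sketch is not the actual rewrite rule.

```python
def compose_transposes(first: list[int], second: list[int]) -> list[int]:
    """Return the single perm equivalent to Transpose(first) then Transpose(second).

    Output axis i of the second transpose is axis second[i] of the intermediate
    tensor, which in turn is axis first[second[i]] of the original input.
    """
    n = len(first)
    if sorted(first) != list(range(n)) or sorted(second) != list(range(n)):
        raise ValueError("both arguments must be permutations of the same rank")
    return [first[p] for p in second]


def apply_perm(shape: tuple[int, ...], perm: list[int]) -> tuple[int, ...]:
    """Shape of a tensor after an ONNX-style Transpose with the given perm."""
    return tuple(shape[p] for p in perm)


# On a rank-4 tensor, k.transpose(1, 2) is perm [0, 2, 1, 3]
# and k.transpose(-2, -1) is perm [0, 1, 3, 2].
merged = compose_transposes([0, 2, 1, 3], [0, 1, 3, 2])
print(merged)                                # [0, 2, 3, 1]
print(apply_perm((1, 128, 12, 32), merged))  # (1, 12, 32, 128)
```

This matches `k.permute(0, 2, 3, 1)` in PyTorch. The optimization is therefore numerically harmless; the only thing lost is the per-node metadata.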

When optimize=True, the two transposes are merged into a single Transpose node that no longer carries any metadata, as shown:

[Image: exported graph with optimize=True, showing one merged Transpose node without metadata]

With optimize=False, you get the two transposes with metadata intact:

[Image: exported graph with optimize=False, showing both Transpose nodes with their metadata]

Tracing through the PyTorch export function with optimize=True led me to the following function in onnxscript. Is there any way to apply these optimization passes while automagically adding some metadata to the generated nodes? Either an argument to torch.onnx.export or an alternative post-export step would be of interest here.

Environment

  • OS: Ubuntu 24.04
  • PyTorch version: 2.7.1
  • ONNXScript versions: 0.4.0, 0.5.3
