
[Bug] pytorch to relay : the order of input nodes is not preserved. #14461

@sweetcocoa



Expected behavior and Actual behavior

For PyTorch modules that take multiple inputs, converting the module into a Relay graph with relay.frontend.from_pytorch gives an unexpected result.

The input nodes of the resulting Relay graph do not follow the order of the entries in input_infos passed to relay.frontend.from_pytorch.

As a result, the order of the example inputs given to torch.jit.trace differs from the order that the TVM graph executor uses when inputs are set by index with set_input. Running the model then produces incorrect output.

Example code below.


Environment

Ubuntu / miniconda python3.10

apache-tvm==0.11.1
transformers==4.27.4
torch==2.0.0+cpu
tensorflow==2.12.0

Steps to reproduce

import tvm
import tvm.relay as relay
from tvm.contrib import graph_executor
import torch
from transformers import BertForMaskedLM

torch.random.manual_seed(0)

model = BertForMaskedLM.from_pretrained("bert-base-uncased", return_dict=False).eval()

# make example inputs and jit-trace
input_info = [
    ("input_ids", ([1, 128], torch.long)),
    ("attention_mask", ([1, 128], torch.long)),
    ("token_type_ids", ([1, 128], torch.long)),
]
example_inputs = [
    torch.randint(0, 2, sz, dtype=dtype, requires_grad=False) for name, (sz, dtype) in input_info
]
scripted_model = torch.jit.trace(model, example_inputs=example_inputs).eval()

# jit module to relay
input_info = [
    ("input_ids", ([1, 128], "long")),
    ("attention_mask", ([1, 128], "long")),
    ("token_type_ids", ([1, 128], "long")),
]

# This is the problematic line
mod, params = relay.frontend.from_pytorch(scripted_model, input_info)

"""
The order of the input nodes differs from the order in input_info.

input_info (as provided):
[0] input_ids
[1] attention_mask
[2] token_type_ids

input nodes (relay)

[0] input_ids
[1] token_type_ids
[2] attention_mask

mod : 
def @main(%input_ids: Tensor[(1, 128), int64], %token_type_ids: Tensor[(1, 128), int64], %attention_mask: Tensor[(1, 128), int64], ....)
"""

# Build the graph
lib = relay.build(mod, target="llvm", params=params)
dev = tvm.device("llvm", 0)


# Execute the graph with example_inputs, setting each input by index
module = graph_executor.GraphModule(lib["default"](dev))
for i in range(3):
    module.set_input(i, example_inputs[i].numpy())
module.run()

tvm_output = module.get_output(0).numpy()
torch_output = model(*example_inputs)

# tvm_output is incorrect: it does not match torch_output.
print(tvm_output)
print(torch_output)
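
A possible workaround, sketched here under assumptions rather than confirmed as a fix, is to avoid positional set_input calls and bind each input to its name instead. The pure-Python sketch below only illustrates the idea with the names from this report: order_mismatches and bind_inputs_by_name are hypothetical helpers, not part of TVM, and the placeholder strings stand in for the real tensors.

```python
from typing import Dict, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def order_mismatches(
    declared_names: Sequence[str], graph_names: Sequence[str]
) -> List[Tuple[int, str, str]]:
    """Return (position, declared, actual) for every index where the two orders differ."""
    return [
        (i, d, g)
        for i, (d, g) in enumerate(zip(declared_names, graph_names))
        if d != g
    ]


def bind_inputs_by_name(declared_names: Sequence[str], values: Sequence[T]) -> Dict[str, T]:
    """Pair each declared input name with its value so positional order no longer matters."""
    if len(declared_names) != len(values):
        raise ValueError("names and values must have the same length")
    if len(set(declared_names)) != len(declared_names):
        raise ValueError("input names must be unique")
    return dict(zip(declared_names, values))


# Order given in input_info versus the order reported for the Relay graph in this issue.
declared = ["input_ids", "attention_mask", "token_type_ids"]
graph_order = ["input_ids", "token_type_ids", "attention_mask"]

mismatches = order_mismatches(declared, graph_order)

# Placeholder values; in the repro script these would be the example_inputs tensors.
bound = bind_inputs_by_name(declared, ["ids", "mask", "types"])
```

In the repro script, the index-based loop could then be replaced by iterating over the bound dictionary and calling module.set_input(name, value.numpy()) for each entry, since set_input also accepts a string key. The graph-side order could presumably be read from [p.name_hint for p in mod["main"].params], which would also list any weight variables; neither step has been verified against this TVM version.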


