Closed
Labels: needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug
Description
After converting the torch.unbind operation with tvm.relax.frontend.torch.from_exported_program, the result, which should be two tensors, is instead split into three, and one of the resulting tensors has a dimension of size 0. This leads to an error when running on CUDA. The reproducing Python code is as follows:
```python
import tvm
from tvm import dlight as dl
from tvm.relax.frontend.torch import from_exported_program
import torch


class Unbind(torch.nn.Module):
    def forward(self, x):
        return torch.unbind(x, dim=0)


model = Unbind()
x = torch.randn(2, 3, 4)
exported_program = torch.export.export(model, (x,))
mod = from_exported_program(exported_program)

dev = tvm.cuda(0)
target = tvm.target.Target.from_device(dev)
with target:
    mod = tvm.relax.transform.LegalizeOps()(mod)
    mod = dl.ApplyDefaultSchedule(
        dl.gpu.Fallback(),
    )(mod)
mod.show()
ex = tvm.relax.build(mod, target)
x_tvm = tvm.nd.from_dlpack(x)
vm = tvm.relax.VirtualMachine(ex, dev)
out = vm["main"](x_tvm)
print(out)
```

The TIR generated in this way has the following structure:

```python
for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(0), thread="blockIdx.x"):
    for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
```

cc @junrushao
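For reference, torch.unbind(x, dim=0) on a tensor of shape (2, 3, 4) should return exactly two tensors of shape (3, 4), one per index of the unbound dimension. A minimal sketch of the expected semantics, written with NumPy only to keep the illustration self-contained (the shapes mirror the repro above):

```python
import numpy as np

# Same shape as the repro: unbinding along dim 0 should yield
# one slice per index of that dimension, i.e. exactly 2 outputs.
x = np.random.randn(2, 3, 4)

# Equivalent of torch.unbind(x, dim=0): one slice per leading index.
slices = tuple(x[i] for i in range(x.shape[0]))

assert len(slices) == 2                         # two outputs, not three
assert all(s.shape == (3, 4) for s in slices)   # no zero-size dimension
```

The TIR above, by contrast, contains a thread_binding with extent T.int64(0), which is consistent with a spurious zero-size output tensor.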