Skip to content

Conversation

@mshr-h
Copy link
Contributor

@mshr-h mshr-h commented Apr 7, 2025

Those ops are required to support classification models from torchvision.
The only remaining unsupported op is index.Tensor but I'm leaving it for now since it's a bit complex.

Current coverage

  • Supported: 70 / 80
  • Not supported yet: 10 / 80
    • Due to the lack of index.Tensor: maxvit_t, swin_b, swin_s, swin_t, swin_v2_b, swin_v2_s, swin_v2_t
    • Output mismatch: efficientnet_b7, inception_v3, mobilenet_v3_small
coverage.py
import torch
import torchvision
import tvm
from termcolor import colored
from torch.export import export
from torchvision.models import get_model, get_model_weights, list_models
from tvm import relax
from tvm.contrib.download import download_testdata
from tvm.relax.frontend.torch import from_exported_program


def check_torch_tvm_result(torch_model, batch):
  example_args = (batch,)

  # PyTorch
  exported_program = export(torch_model, args=example_args)
  expected: torch.Tensor = exported_program.module()(*example_args)

  # Relax
  target = "llvm"
  dev = tvm.cpu()
  mod = from_exported_program(exported_program)
  # return True
  exe = tvm.compile(mod, target=target)
  vm = relax.VirtualMachine(exe, dev)
  tvm_args = [tvm.nd.from_dlpack(x.contiguous()) for x in example_args]
  tvm_output = vm["main"](*tvm_args)[0]
  actual = torch.from_numpy(tvm_output.numpy())

  # check if the outputs match
  return torch.allclose(actual, expected, rtol=1e-5, atol=1e-5, equal_nan=True)


# skip them due to high memory usage
torchvision_models_skip = [
  "regnet_y_128gf",
  "vit_b_32",
  "vit_h_14",
  "vit_l_16",
  "vit_l_32",
]


def main():
  # prepare sample image
  img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
  img_name = "cat.png"
  img_path = download_testdata(img_url, img_name, module="data")

  # torchvision classification models
  image_tensor = torchvision.io.decode_image(img_path)
  model_names = list_models(module=torchvision.models)
  length = len(model_names)
  for idx, model_name in enumerate(model_names, 1):
    if model_name in torchvision_models_skip:
      continue
    try:
      # load model from torchvision
      torch_model = get_model(model_name, weights="DEFAULT").eval()
      weights = get_model_weights(model_name).DEFAULT
      transforms = weights.transforms()
      batch = transforms(image_tensor).unsqueeze(0)
      if not check_torch_tvm_result(torch_model, batch):
        print(
          colored(f"[{idx}/{length}] Output mismatch for {model_name}", "yellow")
        )
      else:
        print(colored(f"[{idx}/{length}] Passed {model_name}", "green"))
    except Exception as e:
      print(colored(f"[{idx}/{length}] Error in {model_name}: {e}", "red"))


if __name__ == "__main__":
  main()

cc @MasterJH5574 @Hzfengsy

@mshr-h mshr-h changed the title [Relax][PyTorch] Support unflatten.int, hardtanh_.default, dropout_.default, silu_.default, add_.Tensor and relu_.default in ExportedProgram frontend [Relax][PyTorch] Improve ExportedProgram frontend by supporting unflatten.int, hardtanh_.default, dropout_.default, silu_.default, add_.Tensor and relu_.default Apr 7, 2025
@mshr-h mshr-h marked this pull request as ready for review April 7, 2025 07:34
@mshr-h mshr-h force-pushed the feat-support-torchvision-classification-models branch from 2b50ebe to ccce4eb Compare April 7, 2025 07:35
@mshr-h
Copy link
Contributor Author

mshr-h commented Apr 7, 2025

@tvm-bot rerun

1 similar comment
@mshr-h
Copy link
Contributor Author

mshr-h commented Apr 7, 2025

@tvm-bot rerun

@github-actions
Copy link
Contributor

github-actions bot commented Apr 7, 2025

Failed to re-run CI in https://github.com/apache/tvm/actions/runs/14306234310

Traceback (most recent call last):
  File "ci/scripts/github/github_tvmbot.py", line 591, in comment_failure
    raise item
  File "ci/scripts/github/github_tvmbot.py", line 697, in run
    pr.rerun_jenkins_ci()
  File "ci/scripts/github/github_tvmbot.py", line 550, in rerun_jenkins_ci
    post(url, auth=("tvm-bot", TVM_BOT_JENKINS_TOKEN))
  File "/home/runner/work/tvm/tvm/ci/scripts/jenkins/git_utils.py", line 53, in post
    with request.urlopen(req, data) as response:
  File "/usr/lib/python3.8/urllib/request.py", line 222, in urlopen
    return opener.open(url, data, timeout)
  File "/usr/lib/python3.8/urllib/request.py", line 531, in open
    response = meth(req, response)
  File "/usr/lib/python3.8/urllib/request.py", line 640, in http_response
    response = self.parent.error(
  File "/usr/lib/python3.8/urllib/request.py", line 569, in error
    return self._call_chain(*args)
  File "/usr/lib/python3.8/urllib/request.py", line 502, in _call_chain
    result = func(*args)
  File "/usr/lib/python3.8/urllib/request.py", line 649, in http_error_default
    raise HTTPError(req.full_url, code, msg, hdrs, fp)
urllib.error.HTTPError: HTTP Error 502: Bad Gateway

with response

<html>
<head><title>502 Bad Gateway</title></head>
<body>
<center><h1>502 Bad Gateway</h1></center>
</body>
</html>

@hugolatendresse
Copy link
Contributor

@mshr-h fyi I'm working on index.Tensor

@mshr-h
Copy link
Contributor Author

mshr-h commented Apr 8, 2025

@tvm-bot rerun

@Hzfengsy Hzfengsy merged commit 32a6f01 into apache:main Apr 8, 2025
16 checks passed
@mshr-h mshr-h deleted the feat-support-torchvision-classification-models branch April 8, 2025 04:37
ShiboXing pushed a commit to ShiboXing/tvm that referenced this pull request Aug 10, 2025
…atten.int`, `hardtanh_.default`, `dropout_.default`, `silu_.default`, `add_.Tensor` and `relu_.default` (apache#17813)

* support `relu_.default`

* support `add_.Tensor`

* support `silu_.default`

* support `dropout_.default`

* support `hardswish_.default`

* support `hardtanh_.default`

* support `unflatten.int`

* fix lint error
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants