Skip to content

Commit 8289509

Browse files
committed
use torch.export
1 parent 24fd037 commit 8289509

File tree

2 files changed

+15
-16
lines changed

2 files changed

+15
-16
lines changed

docs/get_started/tutorials/ir_module.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@
4040
# below.
4141

4242
import torch
43-
from torch import fx, nn
44-
from tvm.relax.frontend.torch import from_fx
43+
from torch import nn
44+
from torch.export import export
45+
from tvm.relax.frontend.torch import from_exported_program
4546

4647
######################################################################
4748
# Import from existing models
@@ -67,13 +68,13 @@ def forward(self, x):
6768
return x
6869

6970

70-
# Give the input shape and data type
71-
input_info = [((1, 784), "float32")]
71+
# Give an example argument to torch.export
72+
example_args = (torch.randn(1, 784, dtype=torch.float32),)
7273

7374
# Convert the model to IRModule
7475
with torch.no_grad():
75-
torch_fx_model = fx.symbolic_trace(TorchModel())
76-
mod_from_torch = from_fx(torch_fx_model, input_info, keep_params_as_input=True)
76+
exported_program = export(TorchModel().eval(), example_args)
77+
mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True)
7778

7879
mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch)
7980
# Print the IRModule

docs/how_to/tutorials/e2e_opt_model.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@
3434
import os
3535
import numpy as np
3636
import torch
37-
from torch import fx
37+
from torch.export import export
3838
from torchvision.models.resnet import ResNet18_Weights, resnet18
3939

40-
torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
40+
torch_model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
4141

4242
######################################################################
4343
# Review Overall Flow
@@ -63,21 +63,19 @@
6363
# Convert the model to IRModule
6464
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
6565
# Next step, we convert the model to an IRModule using the Relax frontend for PyTorch for further
66-
# optimization. Besides the model, we also need to provide the input shape and data type.
66+
# optimization.
6767

6868
import tvm
6969
from tvm import relax
70-
from tvm.relax.frontend.torch import from_fx
70+
from tvm.relax.frontend.torch import from_exported_program
7171

72-
torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
73-
74-
# Give the input shape and data type
75-
input_info = [((1, 3, 224, 224), "float32")]
72+
# Give an example argument to torch.export
73+
example_args = (torch.randn(1, 3, 224, 224, dtype=torch.float32),)
7674

7775
# Convert the model to IRModule
7876
with torch.no_grad():
79-
torch_fx_model = fx.symbolic_trace(torch_model)
80-
mod = from_fx(torch_fx_model, input_info, keep_params_as_input=True)
77+
exported_program = export(torch_model, example_args)
78+
mod = from_exported_program(exported_program, keep_params_as_input=True)
8179

8280
mod, params = relax.frontend.detach_params(mod)
8381
mod.show()

0 commit comments

Comments
 (0)