Skip to content

Commit 001d5ec

Browse files
authored
[Relax][PyTorch][Docs] Use torch.export insteamd of fx.symbolic_trace for tutorial (#17436)
* use torch.export * in order to make interface consistent, user inputs should be placed first * chore
1 parent abb901f commit 001d5ec

File tree

4 files changed

+56
-52
lines changed

4 files changed

+56
-52
lines changed

docs/get_started/tutorials/ir_module.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@
4040
# below.
4141

4242
import torch
43-
from torch import fx, nn
44-
from tvm.relax.frontend.torch import from_fx
43+
from torch import nn
44+
from torch.export import export
45+
from tvm.relax.frontend.torch import from_exported_program
4546

4647
######################################################################
4748
# Import from existing models
@@ -67,13 +68,15 @@ def forward(self, x):
6768
return x
6869

6970

70-
# Give the input shape and data type
71-
input_info = [((1, 784), "float32")]
71+
# Give an example argument to torch.export
72+
example_args = (torch.randn(1, 784, dtype=torch.float32),)
7273

7374
# Convert the model to IRModule
7475
with torch.no_grad():
75-
torch_fx_model = fx.symbolic_trace(TorchModel())
76-
mod_from_torch = from_fx(torch_fx_model, input_info, keep_params_as_input=True)
76+
exported_program = export(TorchModel().eval(), example_args)
77+
mod_from_torch = from_exported_program(
78+
exported_program, keep_params_as_input=True, unwrap_unit_return_tuple=True
79+
)
7780

7881
mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch)
7982
# Print the IRModule

docs/how_to/tutorials/e2e_opt_model.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@
3434
import os
3535
import numpy as np
3636
import torch
37-
from torch import fx
37+
from torch.export import export
3838
from torchvision.models.resnet import ResNet18_Weights, resnet18
3939

40-
torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
40+
torch_model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
4141

4242
######################################################################
4343
# Review Overall Flow
@@ -63,21 +63,19 @@
6363
# Convert the model to IRModule
6464
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
6565
# Next step, we convert the model to an IRModule using the Relax frontend for PyTorch for further
66-
# optimization. Besides the model, we also need to provide the input shape and data type.
66+
# optimization.
6767

6868
import tvm
6969
from tvm import relax
70-
from tvm.relax.frontend.torch import from_fx
70+
from tvm.relax.frontend.torch import from_exported_program
7171

72-
torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
73-
74-
# Give the input shape and data type
75-
input_info = [((1, 3, 224, 224), "float32")]
72+
# Give an example argument to torch.export
73+
example_args = (torch.randn(1, 3, 224, 224, dtype=torch.float32),)
7674

7775
# Convert the model to IRModule
7876
with torch.no_grad():
79-
torch_fx_model = fx.symbolic_trace(torch_model)
80-
mod = from_fx(torch_fx_model, input_info, keep_params_as_input=True)
77+
exported_program = export(torch_model, example_args)
78+
mod = from_exported_program(exported_program, keep_params_as_input=True)
8179

8280
mod, params = relax.frontend.detach_params(mod)
8381
mod.show()

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -34,37 +34,6 @@ class ExportedProgramImporter(BaseFXGraphImporter):
3434

3535
from torch import fx
3636

37-
def create_input_vars(
38-
self, exported_program: torch.export.ExportedProgram
39-
) -> Tuple[List[relax.Var], List[relax.Var]]:
40-
"""Create relax input vars."""
41-
parameters_buffers_constants = []
42-
user_inputs = []
43-
for spec in exported_program.graph_signature.input_specs:
44-
name_hint = spec.arg.name
45-
if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
46-
shape = exported_program.tensor_constants[spec.target].shape
47-
torch_dtype = exported_program.tensor_constants[spec.target].dtype
48-
elif spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
49-
for node in exported_program.graph.find_nodes(op="placeholder", target=spec.target):
50-
if node.name == name_hint:
51-
shape = node.meta["tensor_meta"].shape
52-
torch_dtype = node.meta["tensor_meta"].dtype
53-
break
54-
else:
55-
# PARAMETER or BUFFER
56-
shape = exported_program.state_dict[spec.target].shape
57-
torch_dtype = exported_program.state_dict[spec.target].dtype
58-
59-
dtype = self._convert_data_type(torch_dtype)
60-
relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, dtype))
61-
if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
62-
user_inputs.append(relax_var)
63-
else:
64-
parameters_buffers_constants.append(relax_var)
65-
66-
return parameters_buffers_constants, user_inputs
67-
6837
########## Unary Ops ##########
6938

7039
def _hardtanh(self, node: fx.Node) -> relax.Expr:
@@ -178,6 +147,8 @@ def _slice(self, node: fx.Node) -> relax.Var:
178147
stride = [node.args[4] if len(node.args) > 4 else 1]
179148
return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
180149

150+
########## Others ##########
151+
181152
def create_convert_map(
182153
self,
183154
) -> Dict[str, Callable[[fx.Node], relax.Var]]:
@@ -293,6 +264,37 @@ def create_convert_map(
293264
"getitem": self._getitem,
294265
}
295266

267+
def create_input_vars(
268+
self, exported_program: torch.export.ExportedProgram
269+
) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
270+
"""Create relax input vars."""
271+
parameters_buffers_constants = OrderedDict()
272+
user_inputs = OrderedDict()
273+
for spec in exported_program.graph_signature.input_specs:
274+
name_hint = spec.arg.name
275+
if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
276+
shape = exported_program.tensor_constants[spec.target].shape
277+
torch_dtype = exported_program.tensor_constants[spec.target].dtype
278+
elif spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
279+
for node in exported_program.graph.find_nodes(op="placeholder", target=spec.target):
280+
if node.name == name_hint:
281+
shape = node.meta["tensor_meta"].shape
282+
torch_dtype = node.meta["tensor_meta"].dtype
283+
break
284+
else:
285+
# PARAMETER or BUFFER
286+
shape = exported_program.state_dict[spec.target].shape
287+
torch_dtype = exported_program.state_dict[spec.target].dtype
288+
289+
dtype = self._convert_data_type(torch_dtype)
290+
relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, dtype))
291+
if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
292+
user_inputs[name_hint] = relax_var
293+
else:
294+
parameters_buffers_constants[name_hint] = relax_var
295+
296+
return parameters_buffers_constants, user_inputs
297+
296298
def from_exported_program(
297299
self,
298300
exported_program: torch.export.ExportedProgram,
@@ -305,7 +307,8 @@ def from_exported_program(
305307

306308
# Create input variables.
307309
parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program)
308-
inputs_vars = parameter_buffer_constant_vars + user_input_vars
310+
inputs_vars = user_input_vars.copy()
311+
inputs_vars.update(parameter_buffer_constant_vars)
309312

310313
# Initialize the block builder with a function and a dataflow block.
311314
self.block_builder = relax.BlockBuilder()
@@ -314,7 +317,7 @@ def from_exported_program(
314317

315318
nodes: List[fx.Node] = exported_program.graph.nodes
316319
with self.block_builder.function(
317-
name=func_name, params=inputs_vars.copy(), attrs=func_attrs
320+
name=func_name, params=list(inputs_vars.values()).copy(), attrs=func_attrs
318321
):
319322
output = None
320323
with self.block_builder.dataflow():
@@ -325,7 +328,7 @@ def from_exported_program(
325328
# Ignore sym input
326329
continue
327330

328-
self.env[node] = inputs_vars.pop(0)
331+
self.env[node] = inputs_vars[node.name]
329332
elif node.op == "output":
330333
args = self.retrieve_args(node)
331334
assert len(args) == 1

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3550,9 +3550,9 @@ def forward(self, input):
35503550
class expected1:
35513551
@R.function
35523552
def main(
3553+
input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
35533554
conv_weight: R.Tensor((6, 3, 7, 7), dtype="float32"),
35543555
conv_bias: R.Tensor((6,), dtype="float32"),
3555-
input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
35563556
) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")):
35573557
R.func_attr({"num_input": 1})
35583558
# block 0
@@ -3586,7 +3586,7 @@ def main(
35863586
params = params["main"]
35873587

35883588
assert len(params) == len(func.params) - 1
3589-
for param_var, param_ndarray in zip(func.params[:-1], params):
3589+
for param_var, param_ndarray in zip(func.params[1:], params):
35903590
assert tuple(x.value for x in param_var.struct_info.shape.values) == param_ndarray.shape
35913591
assert param_var.struct_info.dtype == param_ndarray.dtype
35923592

0 commit comments

Comments
 (0)