Skip to content

Commit cf6e61b

Browse files
[Relax] Allow ingesting Upsample module from torch.export either using Size or Scale Factor argument (#17721)
Torch's Upsample module can only accommodate the Size or Scale Factor argument but not both. Regarding that limitation, there was previously a bug in the parsing of the arguments (the wrong order of arguments was assumed). This PR fixes that and adds a unit test.
1 parent d0de906 commit cf6e61b

File tree

2 files changed

+74
-13
lines changed

2 files changed

+74
-13
lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,12 @@ def _group_norm(self, node: fx.Node) -> relax.Var:
9292
)
9393

9494
def _upsample_impl(
95-
self, x: relax.Expr, size, align_corners: bool, scale_factor, method: str
95+
self,
96+
x: relax.Expr,
97+
size,
98+
scale_factor,
99+
method: str,
100+
align_corners: bool,
96101
) -> relax.Var:
97102
coord_trans = "align_corners" if align_corners else "half_pixel"
98103

@@ -119,17 +124,39 @@ def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var:
119124
align_corners = (
120125
node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", True)
121126
)
122-
scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None)
123-
return self._upsample_impl(x, size, align_corners, scale_factor, "linear")
127+
scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", 1)
128+
return self._upsample_impl(
129+
x, size=size, scale_factor=scale_factor, method="linear", align_corners=align_corners
130+
)
124131

125132
def _upsample_nearest2d(self, node: fx.node) -> relax.Var:
126133
x = self.env[node.args[0]]
127134
size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None)
128-
align_corners = (
129-
node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", True)
135+
136+
if size:
137+
scale_factor = None # Can only define size or scale_factor, not both
138+
align_corners = (
139+
node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", None)
140+
)
141+
142+
else:
143+
# TODO figure out why pytorch export passes a list such as
144+
# [scale_factor,scale_factor] instead of just an int for
145+
# scale_factor. Using first element for now
146+
scale_factor = (
147+
node.args[2][0] if len(node.args) > 2 else node.kwargs.get("scale_factor", 1)
148+
)
149+
align_corners = (
150+
node.args[3] if len(node.args) > 3 else node.kwargs.get("align_corners", None)
151+
)
152+
153+
return self._upsample_impl(
154+
x,
155+
size=size,
156+
scale_factor=scale_factor,
157+
method="nearest_neighbor",
158+
align_corners=align_corners,
130159
)
131-
scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None)
132-
return self._upsample_impl(x, size, align_corners, scale_factor, "nearest_neighbor")
133160

134161
########## Manipulation ##########
135162

tests/python/relax/test_from_exported_to_cuda.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import tvm
19+
from tvm import relax
20+
import tvm.testing
1821
import numpy as np
1922
import torch
2023
from torch.export import export
21-
22-
import tvm
23-
import tvm.testing
24-
from tvm import relax
2524
from tvm.relax.frontend.torch import from_exported_program
25+
from torch.nn import Softmax, Upsample
2626

2727

2828
def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev):
@@ -42,8 +42,6 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
4242
tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch)
4343

4444
relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda()))
45-
# TODO try pipeline below?
46-
# releax_pipeline = relax.backend.cuda.pipeline.get_default_pipeline(target)
4745
ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline)
4846
vm = relax.VirtualMachine(ex, dev)
4947

@@ -57,6 +55,42 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
5755
np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
5856

5957

58+
@tvm.testing.parametrize_targets("cuda")
59+
def test_upsample_with_size(target, dev):
60+
"""
61+
The Upsample module can be used with the size argument or the scale
62+
factor argument but not both. This tests the former.
63+
"""
64+
batch_size = 1
65+
channels = 3
66+
height, width = 8, 8
67+
68+
torch_module = Upsample(size=(64, 64), mode="nearest", recompute_scale_factor=None)
69+
70+
raw_data = np.random.rand(batch_size, channels, height, width).astype("float32")
71+
72+
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
73+
74+
75+
@tvm.testing.parametrize_targets("cuda")
76+
def test_upsample_with_scale_factor(target, dev):
77+
"""
78+
The Upsample module can be used with the size arugment or the scale
79+
factor argument but not both. This tests the latter.
80+
"""
81+
batch_size = 2
82+
channels = 3
83+
height, width = 32, 32
84+
85+
torch_module = Upsample(
86+
size=None, scale_factor=7, mode="nearest", align_corners=None, recompute_scale_factor=True
87+
)
88+
89+
raw_data = np.random.rand(batch_size, channels, height, width).astype("float32")
90+
91+
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
92+
93+
6094
@tvm.testing.parametrize_targets("cuda")
6195
def test_linalg_vector_norm(target, dev):
6296
class VectorNorm0(torch.nn.Module):

0 commit comments

Comments
 (0)