Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,29 @@ def _upsample_nearest2d(self, node: fx.node) -> relax.Var:
align_corners=align_corners,
)

def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
x = self.env[node.args[0]]
size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None)
align_corners = (
node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", None)
)
if size is not None:
scale_factor = None
else:
scale_arg = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", 1)
if isinstance(scale_arg, (list, tuple)):
scale_factor = scale_arg[0]
else:
scale_factor = scale_arg

return self._upsample_impl(
x,
size=size,
scale_factor=scale_factor,
method="cubic",
align_corners=align_corners,
)

########## Manipulation ##########

def _narrow(self, node: fx.Node) -> relax.Var:
Expand Down Expand Up @@ -426,6 +449,7 @@ def create_convert_map(
"unbind.int": self._unbind,
"upsample_bilinear2d.vec": self._upsample_bilinear2d,
"upsample_nearest2d.vec": self._upsample_nearest2d,
"upsample_bicubic2d.vec": self._upsample_bicubic2d,
# statistical
"mean.dim": self._mean,
"prod.default": self._prod,
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/op/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def resize2d(
method: str = "linear",
coordinate_transformation_mode: str = "half_pixel",
rounding_method: str = "round",
cubic_alpha: float = -0.5,
cubic_alpha: float = -0.75,
cubic_exclude: int = 0,
extrapolation_value: float = 0.0,
out_dtype: Optional[Union[str, DataType]] = None,
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/topi/image/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def resize1d(
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="",
bicubic_alpha=-0.5,
bicubic_alpha=-0.75,
bicubic_exclude=0,
extrapolation_value=0.0,
out_dtype=None,
Expand Down Expand Up @@ -748,7 +748,7 @@ def resize2d(
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="",
bicubic_alpha=-0.5,
bicubic_alpha=-0.75,
bicubic_exclude=0,
extrapolation_value=0.0,
out_dtype=None,
Expand Down Expand Up @@ -1217,7 +1217,7 @@ def resize3d(
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="",
bicubic_alpha=-0.5,
bicubic_alpha=-0.75,
bicubic_exclude=0,
extrapolation_value=0.0,
out_dtype=None,
Expand Down
34 changes: 32 additions & 2 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -2935,7 +2935,7 @@ def main(
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="round",
cubic_alpha=-0.5,
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0.0,
out_dtype="void",
Expand Down Expand Up @@ -2964,7 +2964,36 @@ def main(
method="nearest_neighbor",
coordinate_transformation_mode="half_pixel",
rounding_method="round",
cubic_alpha=-0.5,
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0.0,
out_dtype="void",
)
gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,)
R.output(gv)
return gv

class InterpolateBicubic(Module):
def forward(self, input):
return torch.nn.functional.interpolate(input, (224, 224), mode="bicubic")

@tvm.script.ir_module
class expected_bicubic:
@R.function
def main(
input: R.Tensor((1, 3, 112, 112), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d(
input,
R.shape([224, 224]),
roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)],
layout="NCHW",
method="cubic",
coordinate_transformation_mode="half_pixel",
rounding_method="round",
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0.0,
out_dtype="void",
Expand All @@ -2976,6 +3005,7 @@ def main(
example_args = (torch.randn(1, 3, 112, 112, dtype=torch.float32),)
verify_model(InterpolateBilinear(), example_args, {}, expected_bilinear)
verify_model(InterpolateNearest(), example_args, {}, expected_nearest)
verify_model(InterpolateBicubic(), example_args, {}, expected_bicubic)


def test_mean():
Expand Down
43 changes: 40 additions & 3 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3250,7 +3250,7 @@ def main(
method="nearest_neighbor",
coordinate_transformation_mode="asymmetric",
rounding_method="round",
cubic_alpha=-0.5,
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="",
Expand Down Expand Up @@ -3287,7 +3287,7 @@ def main(
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="round",
cubic_alpha=-0.5,
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="",
Expand Down Expand Up @@ -3324,7 +3324,7 @@ def main(
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="round",
cubic_alpha=-0.5,
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="",
Expand All @@ -3335,6 +3335,43 @@ def main(

verify_model(Interpolate3(), input_info, {}, expected3)

class Interpolate4(Module):
def forward(self, input):
return torch.nn.functional.interpolate(
input,
size=None,
scale_factor=(2.0, 1.0),
mode="bicubic",
align_corners=False,
)

@tvm.script.ir_module
class expected4:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 20, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 20, 10), dtype="float32") = R.image.resize2d(
input_1,
(20, 10),
roi=[0.000000, 0.000000, 0.000000, 0.000000],
layout="NCHW",
method="cubic",
coordinate_transformation_mode="half_pixel",
rounding_method="round",
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="",
)
gv: R.Tensor((1, 3, 20, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Interpolate4(), input_info, {}, expected4)


def test_addmm():
input_info = [
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relax/test_frontend_nn_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def test(
method="nearest_neighbor",
coordinate_transformation_mode="asymmetric",
rounding_method="round",
cubic_alpha=-0.5,
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="void",
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relax/test_transform_convert_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,7 +1434,7 @@ def main(
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="round",
cubic_alpha=-0.5,
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="void",
Expand Down Expand Up @@ -1477,7 +1477,7 @@ def main(
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="round",
cubic_alpha=-0.5,
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="void",
Expand Down