diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index a0f00e1f4b9d..6bbc9d5de618 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -19,6 +19,7 @@
 # pylint: disable=import-outside-toplevel
 """Base class for PyTorch FX Graph importer."""
 import abc
+import math
 from typing import Callable, Dict, Optional, Tuple, Union
 
 from tvm import relax
@@ -141,19 +142,103 @@ def _celu(self, node: fx.Node) -> relax.Var:
 
     def _clamp(self, node: fx.Node) -> relax.Expr:
+        """Convert ``torch.clamp``/``clip`` with scalar, tensor or missing bounds.
+
+        A missing or ``None`` bound becomes an infinite bound. A non-scalar bound
+        is broadcast to the shape of the input and applied with ``maximum`` /
+        ``minimum``; its slot in the final ``relax.op.clip`` is then infinite.
+        """
         args = self.retrieve_args(node)
-        a_min = args[1] if len(args) > 1 else node.kwargs["min"]
-        a_max = args[2] if len(args) > 2 else node.kwargs["max"]
+        x = args[0]
+        a_min = args[1] if len(args) > 1 else node.kwargs.get("min", -math.inf)
+        a_max = args[2] if len(args) > 2 else node.kwargs.get("max", math.inf)
+
+        a_min = -math.inf if a_min is None else a_min
+        a_max = math.inf if a_max is None else a_max
+
+        # Handle the case where a_min is a tensor
         if not isinstance(a_min, (int, float)):
-            raise ValueError(
-                f"TVM only supports constant min value for torch.clamp/clip, "
-                f"but got {a_min} with type {type(a_min)}"
+            from torch import fx
+
+            if isinstance(a_min, fx.Node):
+                # Extract relax Expr (needed for fx.tracer)
+                a_min = self.env[a_min]
+            assert isinstance(a_min, relax.Expr), (
+                f"Unexpected argument type "
+                f"passed to torch.clamp/clip: {a_min} with type {type(a_min)}"
             )
+            a_min = self.block_builder.emit(relax.op.broadcast_to(a_min, self.shape_of(x)))
+            x = self.block_builder.emit(relax.op.maximum(x, a_min))
+            a_min = -math.inf
+
+        # Handle the case where a_max is a tensor
         if not isinstance(a_max, (int, float)):
-            raise ValueError(
-                f"TVM only supports constant max value for torch.clamp/clip, "
-                f"but got {a_max} with type {type(a_max)}"
+            from torch import fx
+
+            if isinstance(a_max, fx.Node):
+                # Extract relax Expr (needed for fx.tracer)
+                a_max = self.env[a_max]
+            assert isinstance(a_max, relax.Expr), (
+                f"Unexpected argument type "
+                f"passed to torch.clamp/clip: {a_max} with type {type(a_max)}"
+            )
+            a_max = self.block_builder.emit(relax.op.broadcast_to(a_max, self.shape_of(x)))
+            x = self.block_builder.emit(relax.op.minimum(x, a_max))
+            a_max = math.inf
+
+        return self.block_builder.emit(relax.op.clip(x, a_min, a_max))
+
+    def _clamp_min(self, node: fx.Node) -> relax.Expr:
+        """Convert ``clamp_min``: clamp from below only, with a scalar or tensor bound."""
+        args = self.retrieve_args(node)
+        x = args[0]
+        a_min = args[1] if len(args) > 1 else node.kwargs.get("min", -math.inf)
+        a_max = math.inf
+
+        a_min = -math.inf if a_min is None else a_min
+
+        # Handle the case where a_min is a tensor
+        if not isinstance(a_min, (int, float)):
+            from torch import fx
+
+            if isinstance(a_min, fx.Node):
+                # Extract relax Expr (needed for fx.tracer)
+                a_min = self.env[a_min]
+            assert isinstance(a_min, relax.Expr), (
+                f"Unexpected argument type "
+                f"passed to torch.clamp/clip: {a_min} with type {type(a_min)}"
             )
-        return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
+            a_min = self.block_builder.emit(relax.op.broadcast_to(a_min, self.shape_of(x)))
+            x = self.block_builder.emit(relax.op.maximum(x, a_min))
+            a_min = -math.inf
+
+        return self.block_builder.emit(relax.op.clip(x, a_min, a_max))
+
+    def _clamp_max(self, node: fx.Node) -> relax.Expr:
+        """Convert ``clamp_max``: clamp from above only, with a scalar or tensor bound."""
+        args = self.retrieve_args(node)
+        x = args[0]
+        a_min = -math.inf
+        # The bound is the second positional argument: clamp_max(input, max).
+        a_max = args[1] if len(args) > 1 else node.kwargs.get("max", math.inf)
+
+        a_max = math.inf if a_max is None else a_max
+
+        # Handle the case where a_max is a tensor
+        if not isinstance(a_max, (int, float)):
+            from torch import fx
+
+            if isinstance(a_max, fx.Node):
+                # Extract relax Expr (needed for fx.tracer)
+                a_max = self.env[a_max]
+            assert isinstance(a_max, relax.Expr), (
+                f"Unexpected argument type "
+                f"passed to torch.clamp/clip: {a_max} with type {type(a_max)}"
+            )
+            a_max = self.block_builder.emit(relax.op.broadcast_to(a_max, self.shape_of(x)))
+            x = self.block_builder.emit(relax.op.minimum(x, a_max))
+            a_max = math.inf
+
+        return self.block_builder.emit(relax.op.clip(x, a_min, a_max))
 
     def _elu(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
@@ -696,8 +781,8 @@ def _embedding_impl(
         return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size]))
 
     def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var:
-        from torch.fx.immutable_collections import immutable_list
         import numpy as np  # type: ignore
+        from torch.fx.immutable_collections import immutable_list
 
         if isinstance(normalized_shape, (immutable_list, tuple)):
             normalized_shape = tuple(normalized_shape)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2103365c6c60..71a3d13aa1e4 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -193,6 +193,8 @@ def create_convert_map(
             "bitwise_not.default": self._unary_op(relax.op.bitwise_not),
             "ceil.default": self._unary_op(relax.op.ceil),
             "clamp.default": self._clamp,
+            "clamp_min.default": self._clamp_min,
+            "clamp_max.default": self._clamp_max,
             "cos.default": self._unary_op(relax.op.cos),
             "cosh.default": self._unary_op(relax.op.cosh),
             "dropout.default": lambda node: self.env[node.args[0]],
@@ -294,6 +296,7 @@ def create_convert_map(
             "argmin.default": self._argmax_argmin(relax.op.argmin),
             # tensor manipulation
             "cat.default": self._cat,
+            "clamp.Tensor": self._clamp,
             "concat.default": self._cat,
             "copy_.default": self._copy_,
             "cumsum.default": self._cumsum,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index abda5088db4d..952fb6f97111 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -18,8 +18,8 @@
 # pylint: disable=invalid-name, 
inconsistent-return-statements, unidiomatic-typecheck
 # pylint: disable=import-outside-toplevel
 """PyTorch FX frontend of Relax."""
-from typing import Callable, Dict, List, Tuple, Union
 from functools import partial, reduce
+from typing import Callable, Dict, List, Tuple, Union
 
 import tvm
 from tvm import relax
@@ -598,6 +598,7 @@ def create_convert_map(
         self,
     ) -> Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]]:
         import operator
+
         from torch import nn
 
         return {
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index e8b5da0dc2ab..6cc12370d648 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -56,6 +56,87 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
     np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_tensor_clamp(target, dev):
+    """Compare ``Tensor.clamp`` on CUDA with PyTorch for tensor and scalar bounds."""
+
+    class ClampBothTensor(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.register_buffer("min_val", torch.tensor(-1.0))
+            self.register_buffer("max_val", torch.tensor(1.0))
+
+        def forward(self, x):
+            return x.clamp(min=self.min_val, max=self.max_val)
+
+    class ClampBothInt(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.min_val = -1
+            self.max_val = 1
+
+        def forward(self, x):
+            return x.clamp(min=self.min_val, max=self.max_val)
+
+    class ClampMinOnlyTensor(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.register_buffer("min_val", torch.tensor(0.0))
+
+        def forward(self, x):
+            return x.clamp(min=self.min_val)
+
+    class ClampMinOnlyInt(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.min_val = 0
+
+        def forward(self, x):
+            return x.clamp(min=self.min_val)
+
+    class ClampMaxOnlyTensor(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.register_buffer("max_val", torch.tensor(0.5))
+
+        def forward(self, x):
+            return x.clamp(max=self.max_val)
+
+    class ClampMaxOnlyInt(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.max_val = 0.5
+
+        def forward(self, x):
+            return x.clamp(max=self.max_val)
+
+    class ClampDifferentValues(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.min_val = -2
+            self.max_val = 2
+
+        def forward(self, x):
+            return x.clamp(min=self.min_val, max=self.max_val)
+
+    # Input values span [-3, 3) so they fall outside every clamp range above
+    raw_data = np.random.uniform(-3.0, 3.0, (2, 3, 4, 5)).astype(np.float32)
+
+    module_classes = (
+        ClampBothTensor,
+        ClampBothInt,
+        ClampMinOnlyTensor,
+        ClampMinOnlyInt,
+        ClampMaxOnlyTensor,
+        ClampMaxOnlyInt,
+        ClampDifferentValues,
+    )
+    # Build every module up front, then compare each one against TVM in turn.
+    torch_modules = [module_cls().eval() for module_cls in module_classes]
+    for torch_module in torch_modules:
+        assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_tensor_expand_as(target, dev):
     class ExpandAs0(torch.nn.Module):
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 6406610bf53e..8b0a711a52ee 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -135,18 +135,70 @@ def forward(self, input):
     class expected_clamp:
@R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): - # block 0 with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.1, 0.5) + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + input, + R.prim_value(T.float64(0.10000000000000001)), + R.prim_value(T.float64(0.5)), + ) gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) R.output(gv) return gv verify_model(Clamp(), example_args, {}, expected_clamp) + class ClampMinOnly(Module): + def forward(self, input): + return torch.clamp(input, min=0.5, max=None) + + @tvm.script.ir_module + class expected_clamp_min_only: + @R.function + def main( + input: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + input, R.prim_value(T.float64(0.5)), R.prim_value(T.float64("inf")) + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(ClampMinOnly(), example_args, {}, expected_clamp_min_only) + + class ClampTensors(Module): + def forward(self, input): + return torch.clamp(input, min=input, max=input) + + @tvm.script.ir_module + class expected_clamp_tensors: + @R.function + def main( + input: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to( + input, R.shape([1, 3, 10, 10]) + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.maximum(input, lv) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to( + input, R.shape([1, 3, 10, 10]) + ) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(lv1, lv2) + lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + lv3, R.prim_value(T.float64("-inf")), 
R.prim_value(T.float64("inf")) + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv4,) + R.output(gv) + return gv + + verify_model(ClampTensors(), example_args, {}, expected_clamp_tensors) + # dropout + class Dropout1(Module): def __init__(self): super().__init__() @@ -3248,3 +3300,7 @@ def main( exported_program = export(Identity(), args=example_args) mod = from_exported_program(exported_program, no_bind_return_tuple=True) tvm.ir.assert_structural_equal(mod, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 020fc8f5b3c2..fbea8b7388ed 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + import operator import pytest import torch @@ -21,6 +22,7 @@ from torch import fx from torch.nn import Module import torchvision +import math import tvm from tvm import relax @@ -1970,7 +1972,7 @@ def forward(self, input): class expected_clamp: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): # block 0 with R.dataflow(): @@ -1981,29 +1983,53 @@ def main( verify_model(Clamp(), input_info, {}, expected_clamp) - from tvm.relax.frontend.torch import from_fx - - with pytest.raises( - ValueError, match="TVM only supports constant max value for torch.clamp/clip" - ): + class ClampMinOnly(Module): + def forward(self, input): + return torch.clamp(input, min=0.5, max=None) - class Clamp_Error(Module): - def forward(self, input): - return torch.clamp(input, min=0.5, max=None) + @tvm.script.ir_module + class expected_clamp_min_only: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> 
R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.5, math.inf) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv - gm = fx.symbolic_trace(Clamp_Error()) - from_fx(gm, input_info) + verify_model(ClampMinOnly(), input_info, {}, expected_clamp_min_only) - with pytest.raises( - ValueError, match="TVM only supports constant min value for torch.clamp/clip" - ): + class ClampTensors(Module): + def forward(self, input): + return torch.clamp(input, min=input, max=input) - class Clamp_Error(Module): - def forward(self, input): - return torch.clamp(input, min=input, max=input) + @tvm.script.ir_module + class expected_clamp_tensors: + @R.function + def main( + inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to( + inp_0, R.shape([1, 3, 10, 10]) + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.maximum(inp_0, lv) + lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to( + inp_0, R.shape([1, 3, 10, 10]) + ) + lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(lv1, lv2) + lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( + lv3, R.prim_value(T.float64("-inf")), R.prim_value(T.float64("inf")) + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv4 + R.output(gv) + return gv - gm = fx.symbolic_trace(Clamp_Error()) - from_fx(gm, input_info) + verify_model(ClampTensors(), input_info, {}, expected_clamp_tensors) # dropout class Dropout1(Module):