python/tvm/relax/frontend/torch/base_fx_graph_translator.py (86 additions, 10 deletions)
@@ -19,6 +19,7 @@
# pylint: disable=import-outside-toplevel
"""Base class for PyTorch FX Graph importer."""
import abc
import math
from typing import Callable, Dict, Optional, Tuple, Union

from tvm import relax
@@ -141,19 +142,94 @@ def _celu(self, node: fx.Node) -> relax.Var:

def _clamp(self, node: fx.Node) -> relax.Expr:
args = self.retrieve_args(node)
a_min = args[1] if len(args) > 1 else node.kwargs["min"]
a_max = args[2] if len(args) > 2 else node.kwargs["max"]
x = args[0]
a_min = args[1] if len(args) > 1 else node.kwargs.get("min", -math.inf)
a_max = args[2] if len(args) > 2 else node.kwargs.get("max", math.inf)

a_min = -math.inf if a_min is None else a_min
a_max = math.inf if a_max is None else a_max

# Handle the case where a_min is a tensor
if not isinstance(a_min, (int, float)):
- raise ValueError(
- f"TVM only supports constant min value for torch.clamp/clip, "
- f"but got {a_min} with type {type(a_min)}"
from torch import fx

if isinstance(a_min, fx.Node):
# Extract relax Expr (needed for fx.tracer)
a_min = self.env[a_min]
assert isinstance(a_min, relax.Expr), (
f"Unexpected argument type "
f"passed to torch.clamp/clip: {a_min} with type {type(a_min)}"
)
a_min = self.block_builder.emit(relax.op.broadcast_to(a_min, self.shape_of(x)))
x = self.block_builder.emit(relax.op.maximum(x, a_min))
a_min = -math.inf

# Handle the case where a_max is a tensor
if not isinstance(a_max, (int, float)):
- raise ValueError(
- f"TVM only supports constant max value for torch.clamp/clip, "
- f"but got {a_max} with type {type(a_max)}"
from torch import fx

if isinstance(a_max, fx.Node):
# Extract relax Expr (needed for fx.tracer)
a_max = self.env[a_max]
assert isinstance(a_max, relax.Expr), (
f"Unexpected argument type "
f"passed to torch.clamp/clip: {a_max} with type {type(a_max)}"
)
a_max = self.block_builder.emit(relax.op.broadcast_to(a_max, self.shape_of(x)))
x = self.block_builder.emit(relax.op.minimum(x, a_max))
a_max = math.inf

return self.block_builder.emit(relax.op.clip(x, a_min, a_max))

def _clamp_min(self, node: fx.Node) -> relax.Expr:
args = self.retrieve_args(node)
x = args[0]
a_min = args[1] if len(args) > 1 else node.kwargs.get("min", -math.inf)
a_max = math.inf

a_min = -math.inf if a_min is None else a_min

# Handle the case where a_min is a tensor
if not isinstance(a_min, (int, float)):
from torch import fx

if isinstance(a_min, fx.Node):
# Extract relax Expr (needed for fx.tracer)
a_min = self.env[a_min]
assert isinstance(a_min, relax.Expr), (
f"Unexpected argument type "
f"passed to torch.clamp/clip: {a_min} with type {type(a_min)}"
)
- return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
a_min = self.block_builder.emit(relax.op.broadcast_to(a_min, self.shape_of(x)))
x = self.block_builder.emit(relax.op.maximum(x, a_min))
a_min = -math.inf

return self.block_builder.emit(relax.op.clip(x, a_min, a_max))

def _clamp_max(self, node: fx.Node) -> relax.Expr:
args = self.retrieve_args(node)
x = args[0]
a_min = -math.inf
a_max = args[2] if len(args) > 2 else node.kwargs.get("max", math.inf)

a_max = math.inf if a_max is None else a_max

# Handle the case where a_max is a tensor
if not isinstance(a_max, (int, float)):
from torch import fx

if isinstance(a_max, fx.Node):
# Extract relax Expr (needed for fx.tracer)
a_max = self.env[a_max]
assert isinstance(a_max, relax.Expr), (
f"Unexpected argument type "
f"passed to torch.clamp/clip: {a_max} with type {type(a_max)}"
)
a_max = self.block_builder.emit(relax.op.broadcast_to(a_max, self.shape_of(x)))
x = self.block_builder.emit(relax.op.minimum(x, a_max))
a_max = math.inf

return self.block_builder.emit(relax.op.clip(x, a_min, a_max))

def _elu(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
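
For intuition, the tensor-bound path above rewrites clamp as an elementwise maximum/minimum before the final clip, which then receives infinite scalar bounds and becomes a no-op on that side. A minimal PyTorch check of that equivalence (the tensor names here are illustrative, not taken from the patch):

import torch

x = torch.randn(2, 3, 4, 5)
lo = torch.tensor(-1.0)  # tensor min bound (illustrative)
hi = torch.tensor(1.0)   # tensor max bound (illustrative)

# clamp with tensor bounds == broadcast + maximum, then minimum
decomposed = torch.minimum(torch.maximum(x, lo.expand_as(x)), hi.expand_as(x))
assert torch.equal(x.clamp(min=lo, max=hi), decomposed)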
@@ -696,8 +772,8 @@ def _embedding_impl(
return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size]))

def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var:
- from torch.fx.immutable_collections import immutable_list
import numpy as np # type: ignore
from torch.fx.immutable_collections import immutable_list

if isinstance(normalized_shape, (immutable_list, tuple)):
normalized_shape = tuple(normalized_shape)
python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -193,6 +193,8 @@ def create_convert_map(
"bitwise_not.default": self._unary_op(relax.op.bitwise_not),
"ceil.default": self._unary_op(relax.op.ceil),
"clamp.default": self._clamp,
"clamp_min.default": self._clamp_min,
"clamp_max.default": self._clamp_max,
"cos.default": self._unary_op(relax.op.cos),
"cosh.default": self._unary_op(relax.op.cosh),
"dropout.default": lambda node: self.env[node.args[0]],
@@ -294,6 +296,7 @@ def create_convert_map(
"argmin.default": self._argmax_argmin(relax.op.argmin),
# tensor manipulation
"cat.default": self._cat,
"clamp.Tensor": self._clamp,
"concat.default": self._cat,
"copy_.default": self._copy_,
"cumsum.default": self._cumsum,
3 changes: 2 additions & 1 deletion python/tvm/relax/frontend/torch/fx_translator.py
@@ -18,8 +18,8 @@
# pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck
# pylint: disable=import-outside-toplevel
"""PyTorch FX frontend of Relax."""
- from typing import Callable, Dict, List, Tuple, Union
from functools import partial, reduce
from typing import Callable, Dict, List, Tuple, Union

import tvm
from tvm import relax
@@ -598,6 +598,7 @@ def create_convert_map(
self,
) -> Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]]:
import operator

from torch import nn

return {
81 changes: 81 additions & 0 deletions tests/python/relax/test_from_exported_to_cuda.py
@@ -56,6 +56,87 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)


@tvm.testing.parametrize_targets("cuda")
def test_tensor_clamp(target, dev):
class ClampBothTensor(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("min_val", torch.tensor(-1.0))
self.register_buffer("max_val", torch.tensor(1.0))

def forward(self, x):
return x.clamp(min=self.min_val, max=self.max_val)

class ClampBothInt(torch.nn.Module):
def __init__(self):
super().__init__()
self.min_val = -1
self.max_val = 1

def forward(self, x):
return x.clamp(min=self.min_val, max=self.max_val)

class ClampMinOnlyTensor(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("min_val", torch.tensor(0.0))

def forward(self, x):
return x.clamp(min=self.min_val)

class ClampMinOnlyInt(torch.nn.Module):
def __init__(self):
super().__init__()
self.min_val = 0

def forward(self, x):
return x.clamp(min=self.min_val)

class ClampMaxOnlyTensor(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("max_val", torch.tensor(0.5))

def forward(self, x):
return x.clamp(max=self.max_val)

class ClampMaxOnlyInt(torch.nn.Module):
def __init__(self):
super().__init__()
self.max_val = 0.5

def forward(self, x):
return x.clamp(max=self.max_val)

class ClampDifferentValues(torch.nn.Module):
def __init__(self):
super().__init__()
self.min_val = -2
self.max_val = 2

def forward(self, x):
return x.clamp(min=self.min_val, max=self.max_val)

# Create random data with values outside our clamp ranges
raw_data = np.random.uniform(-3.0, 3.0, (2, 3, 4, 5)).astype(np.float32)

torch_module0 = ClampBothTensor().eval()
torch_module1 = ClampBothInt().eval()
torch_module2 = ClampMinOnlyTensor().eval()
torch_module3 = ClampMinOnlyInt().eval()
torch_module4 = ClampMaxOnlyTensor().eval()
torch_module5 = ClampMaxOnlyInt().eval()
torch_module6 = ClampDifferentValues().eval()

assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1, target, dev)
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev)
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module4, target, dev)
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module5, target, dev)
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module6, target, dev)
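
For reference, roughly the path these modules take through the importer, as a sketch of what the assert helper above wraps (minus the CUDA build and numerical comparison; the helper defined earlier in this file is the authoritative version):

import torch
from tvm.relax.frontend.torch import from_exported_program

model = ClampMinOnlyTensor().eval()
exported = torch.export.export(model, (torch.from_numpy(raw_data),))
# Relax IRModule; the tensor min bound lowers to broadcast_to + maximum + clip
mod = from_exported_program(exported)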


@tvm.testing.parametrize_targets("cuda")
def test_tensor_expand_as(target, dev):
class ExpandAs0(torch.nn.Module):
62 changes: 59 additions & 3 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -135,18 +135,70 @@ def forward(self, input):
class expected_clamp:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
input: R.Tensor((1, 3, 10, 10), dtype="float32"),
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.1, 0.5)
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
input,
R.prim_value(T.float64(0.10000000000000001)),
R.prim_value(T.float64(0.5)),
)
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
R.output(gv)
return gv

verify_model(Clamp(), example_args, {}, expected_clamp)

class ClampMinOnly(Module):
def forward(self, input):
return torch.clamp(input, min=0.5, max=None)

@tvm.script.ir_module
class expected_clamp_min_only:
@R.function
def main(
input: R.Tensor((1, 3, 10, 10), dtype="float32"),
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
input, R.prim_value(T.float64(0.5)), R.prim_value(T.float64("inf"))
)
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
R.output(gv)
return gv

verify_model(ClampMinOnly(), example_args, {}, expected_clamp_min_only)

class ClampTensors(Module):
def forward(self, input):
return torch.clamp(input, min=input, max=input)

@tvm.script.ir_module
class expected_clamp_tensors:
@R.function
def main(
input: R.Tensor((1, 3, 10, 10), dtype="float32"),
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to(
input, R.shape([1, 3, 10, 10])
)
lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.maximum(input, lv)
lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to(
input, R.shape([1, 3, 10, 10])
)
lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(lv1, lv2)
lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
lv3, R.prim_value(T.float64("-inf")), R.prim_value(T.float64("inf"))
)
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv4,)
R.output(gv)
return gv

verify_model(ClampTensors(), example_args, {}, expected_clamp_tensors)

# dropout

class Dropout1(Module):
def __init__(self):
super().__init__()
@@ -3248,3 +3300,7 @@ def main(
exported_program = export(Identity(), args=example_args)
mod = from_exported_program(exported_program, no_bind_return_tuple=True)
tvm.ir.assert_structural_equal(mod, Expected)


if __name__ == "__main__":
tvm.testing.main()