From 93c8329277a401fc391e376c54abc0312941af6c Mon Sep 17 00:00:00 2001
From: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Date: Fri, 27 Jan 2023 15:18:19 -0500
Subject: [PATCH 1/4] [Relax][frontend] torch fx importer

Implements the Relax importer from PyTorch, using torch FX.

An example use of the importer is:

```python
import torch
from torch import fx

import tvm
from tvm.relax.frontend.torch import from_fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=10, out_features=7, bias=True)

    def forward(self, input):
        return self.linear(input)

torch_model = MyModule()
input_info = [((128, 10), "float32")]
mod: tvm.IRModule = from_fx(fx.symbolic_trace(torch_model), input_info)
```

---------

Co-authored-by: Ruihang Lai 
---
 python/tvm/relax/block_builder.py             |    2 +-
 python/tvm/relax/frontend/__init__.py         |   19 +
 python/tvm/relax/frontend/torch/__init__.py   |   20 +
 .../tvm/relax/frontend/torch/fx_translator.py |  825 ++++++++
 python/tvm/relax/transform/legalize_ops.py    |  865 +++++++++
 tests/python/relax/test_frontend_from_fx.py   | 1716 +++++++++++++++++
 6 files changed, 3446 insertions(+), 1 deletion(-)
 create mode 100644 python/tvm/relax/frontend/__init__.py
 create mode 100644 python/tvm/relax/frontend/torch/__init__.py
 create mode 100644 python/tvm/relax/frontend/torch/fx_translator.py
 create mode 100644 python/tvm/relax/transform/legalize_ops.py
 create mode 100644 tests/python/relax/test_frontend_from_fx.py

diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py
index 1b63e7deed..7eae5ec889 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -601,7 +601,7 @@ def match_cast(self, value: Expr, struct_info: StructInfo) -> Var:
         """
         return _ffi_api.BlockBuilderEmitMatchCast(self, value, struct_info)  # type: ignore
 
-    def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None:
+    def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> Var:
         """Emit output for the current dataflow block or function.
 
         Parameters
diff --git a/python/tvm/relax/frontend/__init__.py b/python/tvm/relax/frontend/__init__.py
new file mode 100644
index 0000000000..6c9c188aaa
--- /dev/null
+++ b/python/tvm/relax/frontend/__init__.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Frontends for constructing Relax programs, with the model importers.
+"""
diff --git a/python/tvm/relax/frontend/torch/__init__.py b/python/tvm/relax/frontend/torch/__init__.py
new file mode 100644
index 0000000000..1eb4bc0e8c
--- /dev/null
+++ b/python/tvm/relax/frontend/torch/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +PyTorch Frontends for constructing Relax programs, with the model importers +""" +from .fx_translator import from_fx diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py new file mode 100644 index 0000000000..1af821c849 --- /dev/null +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -0,0 +1,825 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck +# pylint: disable=import-outside-toplevel +"""PyTorch FX frontend of Relax.""" +from typing import Callable, Dict, Mapping, Tuple, Union, List +from functools import reduce + +import tvm +from tvm import relax + + +class TorchFXImporter: + """An importer from PyTorch FX to Relax.""" + + import torch # type: ignore + from torch import fx + + def __init__(self) -> None: + import torch # type: ignore + from torch import fx + + self.env: Dict[fx.node.Node, relax.Expr] = {} + self.params: Dict[torch.Tensor, relax.Constant] = {} + self.params_transpose: Dict[torch.Tensor, relax.Constant] = {} + self.named_modules: Dict[str, torch.Module] = None + self.block_builder: relax.BlockBuilder = None + self.create_convert_map() + + ########## Utilities ########## + @staticmethod + def _fetch_attr(model, target: str): + import torch # type: ignore + + target_atoms = target.split(".") + attr_itr = model + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced non existing target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + if isinstance(attr_itr, torch.Tensor): + return TorchFXImporter._convert_torch_tensor_to_relax(attr_itr) + return attr_itr + + @staticmethod + def _convert_data_type(input_type): + """converts the PyTorch scalar type input_type to a TVM dtype.""" + import torch # type: ignore + + input_type = input_type.lower() + if input_type in ["float", "float32", "torch.float32", torch.float32]: + return "float32" + elif input_type in ["float16", "torch.float16", torch.float16]: + return "float16" + elif input_type in ["int64", "torch.int64", torch.int64]: + return "int64" + else: + raise NotImplementedError("input_type {} is not handled yet".format(input_type)) + + @staticmethod + def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var: + shape = tensor.data.shape + dtype = TorchFXImporter._convert_data_type(str(tensor.data.dtype)) + return relax.const(tensor.data.numpy(), relax.TensorStructInfo(shape, dtype)) + + @staticmethod + def shape_of(tensor): + """Get the shape of a tensor.""" + import torch # type: ignore + + if isinstance(tensor, relax.Expr): + if not isinstance(tensor.struct_info, relax.TensorStructInfo): + raise TypeError("The input Expr of shape_of should be a Tensor") + return tensor.struct_info.shape + elif isinstance(tensor, torch.Tensor): + return tensor.shape + raise ValueError("Unsupported type: {}".format(type(tensor))) + + def retrieve_args(self, node): + return self._retrieve_args(node.args) + + def _retrieve_args(self, node): + from torch import fx + + if isinstance(node, fx.node.Node): + return self.env[node] + elif isinstance(node, tuple): + return tuple(self._retrieve_args(x) for x in node) + elif isinstance(node, list): + return [self._retrieve_args(x) for x in node] + elif isinstance(node, dict): + return {self._retrieve_args(k): self._retrieve_args(v) for k, v in node.items()} + else: + return node + + @staticmethod + def _promote_binary_op_args(lhs, rhs): + if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): + return lhs, rhs + elif isinstance(lhs, relax.Expr): + assert isinstance(lhs.struct_info, relax.TensorStructInfo) + return lhs, relax.const(rhs, lhs.struct_info.dtype) + elif isinstance(rhs, relax.Expr): + assert isinstance(rhs.struct_info, relax.TensorStructInfo) + return relax.const(lhs, rhs.struct_info.dtype), rhs + else: + assert False + + def 
_call_binary_op(self, op, lhs, rhs): + lhs, rhs = TorchFXImporter._promote_binary_op_args(lhs, rhs) + return self.block_builder.emit(op(lhs, rhs)) + + ########## Arithmetic ########## + + def _cos(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.cos(self.env[node.args[0]])) + + def _sin(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.sin(self.env[node.args[0]])) + + def _sqrt(self, node: fx.node.Node) -> relax.Expr: + arg = self.env[node.args[0]] + if isinstance(arg, (int, float)): + arg = relax.const(arg, "float32") + return self.block_builder.emit(relax.op.sqrt(arg)) + + def _add(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.add, lhs, rhs) + return lhs + rhs + + def _floordiv(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.floor_divide, lhs, rhs) + return lhs // rhs + + def _mul(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.multiply, lhs, rhs) + return lhs * rhs + + def _sub(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.subtract, lhs, rhs) + return lhs - rhs + + def _truediv(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.divide, lhs, rhs) + return lhs / rhs + + def _clamp(self, node: fx.node.Node) -> relax.Expr: + args = self.retrieve_args(node) + a_min = node.kwargs["min"] + a_max = node.kwargs["max"] + if not isinstance(a_min, (int, float)): + raise ValueError( + f"TVM only supports constant min value for torch.clamp/clip, " + f"but got {a_min} with type {type(a_min)}" + ) + if not isinstance(a_max, (int, float)): + raise ValueError( + f"TVM only supports constant max value for torch.clamp/clip, " + f"but got {a_max} with type {type(a_max)}" + ) + return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max)) + + ########## Compare ########## + + def _lt(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + return self._call_binary_op(relax.op.less, lhs, rhs) + + ########## Creation ########## + + def _tril(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + k = node.args[1] if len(node.args) > 1 else 0 + assert isinstance(k, int) + return self.block_builder.emit(relax.op.create.tril(x, k)) + + def _new_ones(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + self_var = args[0] + size = args[1:] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(1, self_var.struct_info.dtype), + self_var.struct_info.dtype, + ) + ) + + ########## Statistical ########## + + def _sum(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + if len(args) == 1: + return self.block_builder.emit(relax.op.sum(args[0])) + return self.block_builder.emit(relax.op.sum(args[0], args[1])) + + ########## DataType ########## + + def _float(self, node: fx.node.Node) -> relax.Var: + return 
self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) + + def _half(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) + + def _type(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.astype(args[0], args[1])) + + ########## Linear Algebra ########## + + def _matmul_impl(self, a: relax.Expr, b: relax.Expr): + return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32")) + + def _matmul(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + res = self._matmul_impl( + args[0], + args[1], + ) + return res + + def _addmm(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + y = self.env[node.args[1]] + z = self.env[node.args[2]] + matmul = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32")) + return self.block_builder.emit(relax.op.add(x, matmul)) + + ########## Manipulation ########## + + def _cat(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.concat(args[0], axis=node.kwargs["dim"])) + + def _expand(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.broadcast_to(args[0], args[1:])) + + def _flatten(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + start_dim = module.start_dim + end_dim = module.end_dim + else: + start_dim = node.args[1] if len(node.args) >= 2 else 0 + end_dim = node.args[2] if len(node.args) == 3 else -1 + shape = self.shape_of(x) + start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim + end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim + flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)]) + new_shape = ( + [shape[i] for i in range(0, start_dim)] + + [flattened] + + [shape[i] for i in range(end_dim + 1, len(shape))] + ) + return self.block_builder.emit(relax.op.reshape(x, new_shape)) + + def _permute(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:])) + + def _reshape(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.reshape(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.reshape(args[0], args[1:])) + + def _split(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + split_size = node.args[1] + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + else: + dim = 0 + n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size + return self.block_builder.emit(relax.op.split(x, n_section, dim)) + + def _transpose(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + full_idx = list(range(len(self.shape_of(args[0])))) + full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]] + return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx)) + + ########## Neural Network ########## + + def _linear(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + if module.weight not in self.params_transpose: + 
self.params_transpose[module.weight] = self._convert_torch_tensor_to_relax( + module.weight.T + ) + + weight_T = self.params_transpose[module.weight] + dense = self._matmul_impl(x, weight_T) + + if module.bias is None: + return dense + + bias = self.params[module.bias] + return self.block_builder.emit(relax.op.add(dense, bias)) + + def _conv2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + + conv2d = self.block_builder.emit( + relax.op.nn.conv2d( + x, + weight, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + ) + + if module.bias is None: + return conv2d + + bias = self.params[module.bias] + if len(bias.data.shape) == 1: + bias_data = bias.data.numpy().reshape(1, -1, 1, 1) + reshaped_bias = relax.const( + bias_data, relax.TensorStructInfo(bias_data.shape, bias.data.dtype) + ) + bias = self.params[module.bias] = reshaped_bias + + return self.block_builder.emit(relax.op.add(conv2d, bias)) + + def _max_pool2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + kernel = module.kernel_size + stride = module.stride + padding = module.padding + dilation = module.dilation + ceil_mode = module.ceil_mode + else: + nargs = len(node.args) + kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"] + stride = node.args[2] if nargs > 2 else node.kwargs["stride"] + padding = node.args[3] if nargs > 3 else node.kwargs["padding"] + dilation = node.args[4] if nargs > 4 else node.kwargs["dilation"] + ceil_mode = node.args[5] if nargs > 5 else node.kwargs["ceil_mode"] + + stride = kernel if stride is None else stride + + return self.block_builder.emit( + relax.op.nn.max_pool2d( + x, + pool_size=kernel, + strides=stride, + padding=padding, + dilation=dilation, + layout="NCHW", + ceil_mode=ceil_mode, + ) + ) + + def _adaptive_avg_pool2d(self, node: fx.node.Node) -> relax.Var: + module = self.named_modules[node.target] + x = self.env[node.args[0]] + + return self.block_builder.emit( + relax.op.nn.adaptive_avg_pool2d(x, module.output_size, layout="NCHW") + ) + + def _softmax(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + dim = module.dim + else: + nargs = len(node.args) + dim = node.args[1] if nargs > 1 else node.kwargs["dim"] + assert dim is not None + return self.block_builder.emit(relax.op.nn.softmax(x, dim)) + + def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params[module.bias] + dtype = self._convert_data_type(str(module.running_mean.dtype)) + running_mean = relax.const(module.running_mean.cpu().detach().numpy(), dtype) + running_var = relax.const(module.running_var.cpu().detach().numpy(), dtype) + eps = module.eps + + res_tuple = self.block_builder.emit( + relax.op.nn.batch_norm( + x, + weight, + bias, + running_mean, + running_var, + axis=1, + epsilon=eps, + ) + ) + + return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) + + def _layer_norm(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + + x = self.env[node.args[0]] + module = self.named_modules[node.target] + + if module.elementwise_affine: + 
gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.normalized_shape), x.checked_type) + beta = relax.const(torch.zeros_like(module.normalized_shape), x.checked_type) + dim_num = len(module.normalized_shape) + axes = list(range(-dim_num, 0)) + + return self.block_builder.emit( + relax.op.nn.layer_norm( + x, + gamma, + beta, + axes=axes, + epsilon=module.eps, + ) + ) + + def _group_norm(self, node: fx.node.Node) -> relax.Var: + # torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, + # affine=True, device=None, dtype=None) + x = self.env[node.args[0]] + module = self.named_modules[node.target] + num_groups = module.num_groups + num_channels = module.num_channels + eps = module.eps + affine = module.affine + + shape = self.shape_of(x) + assert len(shape) == 4 + N, C, H, W = shape[0], shape[1], shape[2], shape[3] + assert C == num_channels + assert C % num_groups == 0 + grouped_x = self.block_builder.emit( + relax.op.reshape(x, [N, num_groups, C // num_groups, H, W]) + ) + mean_x = self.block_builder.emit(relax.op.mean(grouped_x, [2, 3, 4], keepdims=True)) + sub_x = self.block_builder.emit(relax.op.subtract(grouped_x, mean_x)) + square_x = self.block_builder.emit(relax.op.multiply(sub_x, sub_x)) + sum_square_x = self.block_builder.emit(relax.op.sum(square_x, [2, 3, 4], keepdims=True)) + var_x = self._call_binary_op(relax.op.divide, sum_square_x, (C // num_groups * H * W).value) + var_x_eps = self._call_binary_op(relax.op.add, var_x, eps) + std_x = self.block_builder.emit(relax.op.sqrt(var_x_eps)) + norm_x = self.block_builder.emit(relax.op.divide(sub_x, std_x)) + + if affine: + weight = self.params[module.weight] + bias = self.params[module.bias] + weight_reshape = self.block_builder.emit( + relax.op.reshape(weight, (1, num_groups, C // num_groups, 1, 1)) + ) + bias_reshape = self.block_builder.emit( + relax.op.reshape(bias, (1, num_groups, C // num_groups, 1, 1)) + ) + norm_x = self.block_builder.emit(relax.op.multiply(norm_x, weight_reshape)) + norm_x = self.block_builder.emit(relax.op.add(norm_x, bias_reshape)) + return self.block_builder.emit(relax.op.reshape(norm_x, (N, C, H, W))) + + def _embedding(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + x = self.block_builder.emit(relax.op.astype(x, "int32")) + return self.block_builder.emit(relax.op.take(weight, x, axis=0)) + + def _interpolate(self, node: fx.node.Node) -> relax.Var: + # torch.nn.functional.interpolate( + # input, size=None, scale_factor=None, mode='nearest', align_corners=None, + # recompute_scale_factor=None, antialias=False) + # (TODO) this is a temporary implementation for interpolate that only considers NCHW layout + # it basically replicates the implementation in tvm.relay.frontend.pytorch + data = self.env[node.args[0]] + size = node.kwargs["size"] + scale_factor = node.kwargs["scale_factor"] + method = node.kwargs["mode"] + align_corners = node.kwargs["align_corners"] + recompute_scale_factor = node.kwargs["recompute_scale_factor"] + antialias = node.kwargs["antialias"] + + assert recompute_scale_factor is None + assert antialias is False + + if size is None: + shape = self.shape_of(data) + assert isinstance(shape, relax.ShapeExpr) + size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape))) + + if method.startswith("nearest"): + method = "nearest_neighbor" + elif method[0:2] == "bi": + method = method[2:] + + if method == 
"nearest_neighbor": + coord_trans = "asymmetric" + elif align_corners: + coord_trans = "align_corners" + else: + coord_trans = "half_pixel" + + return self.block_builder.emit( + relax.op.image.resize2d( + data, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans + ) + ) + + ########## Others ########## + + def _size(self, node: fx.node.Node) -> relax.Expr: + x = self.env[node.args[0]] + shape = self.shape_of(x) + if len(node.args) == 1: + assert isinstance(shape, relax.ShapeExpr) + return shape + assert len(node.args) == 2 + idx = node.args[1] + return self.shape_of(x)[idx].value + + def _getattr(self, node: fx.node.Node) -> relax.Var: + if isinstance(self.env[node.args[0]], relax.Expr): + if node.args[1] == "dtype": + return self.env[node.args[0]].struct_info.dtype + elif node.args[1] == "shape": + return self.shape_of(self.env[node.args[0]]) + return getattr(self.env[node.args[0]], node.args[1]) + + def _getitem(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)): + return x[node.args[1]] + elif isinstance(x, relax.Var): + if isinstance(x.struct_info, relax.TupleStructInfo): + return self.block_builder.emit(relax.TupleGetItem(x, node.args[1])) + + assert isinstance(x.struct_info, relax.TensorStructInfo) + begin = [] + end = [] + stride = [] + axes = [] + expand_dim = [] + i = 0 + shape = self.shape_of(x) + for index in node.args[1]: + if isinstance(index, int): + begin.append(index) + end.append(index + 1) + stride.append(1) + axes.append(i) + i = i + 1 + elif isinstance(index, slice): + begin.append(0 if index.start is None else index.start) + end.append(shape[i] if index.stop is None else index.stop) + stride.append(1 if index.step is None else index.step) + axes.append(i) + i = i + 1 + elif index is None: + expand_dim.append(i) + i = i + 1 + else: + raise ValueError("Unsupported index type: " + str(type(index))) + while i < len(shape): + begin.append(0) + end.append(shape[i]) + axes.append(i) + i = i + 1 + sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride)) + sliced_shape = list(self.shape_of(sliced)) + for i in expand_dim: + sliced_shape.insert(i, 1) + return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape)) + else: + assert False + + def create_convert_map(self): + from torch import nn + + self.convert_map = { + # call_module + nn.Linear: self._linear, + nn.Conv2d: self._conv2d, + nn.MaxPool2d: self._max_pool2d, + nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d, + nn.Softmax: self._softmax, + nn.ReLU: lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])), + nn.ReLU6: lambda node: self.block_builder.emit( + relax.op.clip(self.env[node.args[0]], 0, 6) + ), + nn.SiLU: lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])), + nn.Flatten: self._flatten, + nn.BatchNorm2d: self._batch_norm_2d, + nn.LayerNorm: self._layer_norm, + nn.GroupNorm: self._group_norm, + nn.Dropout: lambda node: self.env[node.args[0]], + nn.modules.sparse.Embedding: self._embedding, + # call_function and call_method + "cos": self._cos, + "sin": self._sin, + "add": self._add, + "floordiv": self._floordiv, + "mul": self._mul, + "sub": self._sub, + "sqrt": self._sqrt, + "lt": self._lt, + "truediv": self._truediv, + "new_ones": self._new_ones, + "tril": self._tril, + "sum": self._sum, + "float": self._float, + "half": self._half, + "type": self._type, + "matmul": self._matmul, + "addmm": self._addmm, + "cat": 
self._cat,
+            "expand": self._expand,
+            "flatten": self._flatten,
+            "permute": self._permute,
+            "reshape": self._reshape,
+            "split": self._split,
+            "transpose": self._transpose,
+            "unsqueeze": lambda node: self.block_builder.emit(
+                relax.op.expand_dims(self.env[node.args[0]], node.args[1])
+            ),
+            "view": self._reshape,
+            "softmax": self._softmax,
+            "clamp": self._clamp,
+            "relu": lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
+            "gelu": lambda node: self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])),
+            "interpolate": self._interpolate,
+            "size": self._size,
+            "getattr": self._getattr,
+            "getitem": self._getitem,
+            "contiguous": lambda node: self.env[node.args[0]],
+        }
+
+    def from_fx(self, model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule:
+        """Convert a PyTorch FX GraphModule to a Relax program.
+
+        Parameters
+        ----------
+        model : fx.GraphModule
+            The PyTorch FX GraphModule to convert.
+
+        input_info : List[Tuple[Tuple[int], str]]
+            A list of shapes and data types of input tensors.
+
+        Returns
+        -------
+        module : tvm.IRModule
+            The converted Relax program.
+
+        Examples
+        --------
+        Users can use the FX tracer or dynamo.export() to extract
+        an fx.GraphModule from a PyTorch model. The following code shows
+        how to convert a PyTorch model to a Relax program.
+
+        .. code-block:: python
+
+            # Import the importer.
+            import numpy as np
+            import torch
+            from torch import fx
+            from torch import _dynamo as dynamo
+            from tvm.relax.frontend.torch import from_fx
+
+            # Define the module
+            class MyModule(torch.nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.linear = torch.nn.Linear(in_features=10, out_features=7, bias=True)
+
+                def forward(self, input):
+                    return self.linear(input)
+
+            # Instantiate the model and create the input info list.
+            torch_model = MyModule()
+            input_info = [((128, 10), "float32")]
+            input_tensors = [
+                torch.as_tensor(np.random.randn(*shape).astype(dtype))
+                for shape, dtype in input_info
+            ]
+
+            # Use the FX tracer to trace the PyTorch model.
+            graph_module = fx.symbolic_trace(torch_model)
+
+            # Alternatively, use dynamo.export() to export the PyTorch model to FX.
+            try:
+                graph_module = dynamo.export(torch_model, *input_tensors)
+            except Exception as exc:
+                raise RuntimeError("Failed to export the PyTorch model to FX.") from exc
+
+            # Use the importer to import the PyTorch model to Relax.
+            mod: tvm.IRModule = from_fx(graph_module, input_info)
+
+            # Print out the imported model.
+            print(mod.script())
+
+        Notes
+        -----
+        For a given PyTorch model, to look up the names of the model inputs in
+        FX, one can use
+
+        .. code-block:: python
+
+            fx.symbolic_trace(model).graph.print_tabular()
+
+        to print out the tabular representation of the PyTorch module, and then
+        check the placeholder rows at the beginning of the table.
+        """
+        from torch import fx
+
+        self.named_modules = dict(model.named_modules())
+
+        graph: fx.Graph = model.graph
+        # Create input variables.
+        inputs = list()
+        for idx, (shape, dtype) in enumerate(input_info):
+            inputs.append(
+                relax.Var(
+                    f"inp_{idx}", relax.TensorStructInfo(shape, self._convert_data_type(dtype))
+                )
+            )
+
+        # Initialize the block builder with a function and a dataflow block.
+        self.block_builder = relax.BlockBuilder()
+        with self.block_builder.function(name="main", params=inputs.copy()):
+            output = None
+            with self.block_builder.dataflow():
+                # Translate model parameters.
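+                # Each torch parameter is converted to a `relax.const` and
+                # recorded in `self.params`, keyed by the parameter object, so
+                # that module converters such as `_linear` and `_conv2d` can
+                # look it up later. Only float32/float16 parameters are handled.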
+ for _, param in model.named_parameters(): + shape = param.data.shape + dtype = self._convert_data_type(str(param.data.dtype)) + if dtype in ("float32", "float16"): + self.params[param] = relax.const( + param.data.cpu().numpy(), relax.TensorStructInfo(shape, dtype) + ) + else: + raise ValueError("Unsupported data type for model parameters: %s" % dtype) + # Translate the model. + for node in graph.nodes: + if node.op == "placeholder": + assert len(inputs) > 0, "Provided inputs is less than actual inputs" + self.env[node] = inputs.pop(0) + elif node.op == "output": + args = self.retrieve_args(node) + output = self.block_builder.emit_output(args[0]) + break + elif node.op == "get_attr": + self.env[node] = TorchFXImporter._fetch_attr(model, node.target) + elif node.op == "call_module": + module = self.named_modules[node.target] + assert ( + type(module) in self.convert_map + ), f"Unsupported module type {type(module)}" + self.env[node] = self.convert_map[type(module)](node) + elif node.op == "call_function": + func_name = node.name.rstrip("0123456789_") + assert ( + func_name in self.convert_map + ), f"Unsupported function type {func_name}" + self.env[node] = self.convert_map[func_name](node) + elif node.op == "call_method": + assert ( + node.target in self.convert_map + ), f"Unsupported function target {node.target}" + self.env[node] = self.convert_map[node.target](node) + else: + raise ValueError(f"Unsupported op {node.op}") + assert output is not None + self.block_builder.emit_func_output(output) + + return self.block_builder.get() + + +def from_fx(model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule: + """The public interface of PyTorch FX importer for Relax. + See `TorchFXImporter.from_fx` for full documentation. + """ + return TorchFXImporter().from_fx(model, input_info) diff --git a/python/tvm/relax/transform/legalize_ops.py b/python/tvm/relax/transform/legalize_ops.py new file mode 100644 index 0000000000..420445d39c --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops.py @@ -0,0 +1,865 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=abstract-method,invalid-name,missing-class-docstring,missing-function-docstring,missing-module-docstring,unused-argument +import logging +from typing import Callable, Dict, List, Optional, Union + +import tvm +from tvm import te, tir, topi, relax +from tvm.relax import struct_info +from tvm.ir.module import IRModule + +from ..analysis import remove_all_unused +from ..expr import Call, Constant, Expr, Function, ShapeExpr, Tuple, TupleGetItem, Var +from ..expr_functor import mutator, PyExprMutator +from ..block_builder import BlockBuilder + + +##################### Commons ##################### + +# The function type of a TE function, which accepts TE Tensors and +# other attributes, and returns the output TE Tensor. +TEFunc = Callable[..., te.Tensor] + +# The function type of a legalization function, which takes a +# BlockBuilder and the Call to be legalized, and outputs the legalization +# result Expr. +LegalizeFunc = Callable[[BlockBuilder, Call], Expr] + + +def has_known_shape_value(sinfo: struct_info.StructInfo) -> bool: + """Check if a given Tensor/Shape/TupleStructInfo contains + shapes whose values are all known. + + Parameters + ---------- + sinfo : struct_info.StructInfo + The struct info to be checked. + + Returns + ------- + ret : bool + A boolean indicating if the given struct info contains shape + values that are all known. + """ + if isinstance(sinfo, struct_info.TensorStructInfo): + return isinstance(sinfo.shape, ShapeExpr) + elif isinstance(sinfo, struct_info.ShapeStructInfo): + return sinfo.values is not None + elif isinstance(sinfo, struct_info.TupleStructInfo): + return all([has_known_shape_value(field_sinfo) for field_sinfo in sinfo.fields]) + elif isinstance(sinfo, struct_info.PrimStructInfo): + return True + else: + return False + + +def try_convert_to_scalar_const(expr: Expr) -> Union[Expr, bool, float, int]: + """Check if the input Expr is a scalar constant. + If it is, return its plain value. + If it is not, return the input expr. + + Parameters + ---------- + expr : Expr + The expr to be checked and converted. + + Returns + --–---- + ret : Union[Expr, bool, float, int] + Return a Python native value (int/float/bool) if the given + expr is a scalar constant. Or return the input itself + if it is not. + """ + if isinstance(expr, Constant) and expr.struct_info.ndim == 0: + return expr.data.numpy()[()].item() + else: + return expr + + +def _unary(te_func: TEFunc) -> LegalizeFunc: + def unary_call_te(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(te_func, call.args[0]) + + return unary_call_te + + +def _binary(te_func: TEFunc) -> LegalizeFunc: + def binary_call_te(bb: BlockBuilder, call: Call) -> Expr: + # To simplify the created PrimFunc, we first check if arg1 is a constant scalar. + # If it is not, we then check if arg0 is a constant scalar. 
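+        # A plain Python scalar is inlined into the generated PrimFunc
+        # instead of being passed in as an extra tensor parameter.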
+ arg0 = call.args[0] + arg1 = try_convert_to_scalar_const(call.args[1]) + if isinstance(arg1, Expr): # type: ignore + arg0 = try_convert_to_scalar_const(arg0) + return bb.call_te(te_func, arg0, arg1) + + return binary_call_te + + +##################### Creation ##################### + + +def _full(is_like: bool, fill_value: Optional[float], primfunc_name: str) -> LegalizeFunc: + def full_call_te(bb: BlockBuilder, call: Call) -> Expr: + _fill_value = ( + try_convert_to_scalar_const(call.args[1]) if fill_value is None else fill_value + ) + + return bb.call_te( + topi.full, + call.args[0].struct_info.shape if is_like else call.args[0], + call.struct_info.dtype, + _fill_value, + primfunc_name_hint=primfunc_name, + ) + + return full_call_te + + +def _tril_triu(is_upper: bool, primfunc_name: str) -> LegalizeFunc: + def tril_triu_call_te(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.trilu, + call.args[0], + tir.const(call.attrs.k, "int32"), + upper=is_upper, + primfunc_name_hint=primfunc_name, + ) + + return tril_triu_call_te + + +##################### Datatype ##################### + + +def _astype(bb: BlockBuilder, call: Call) -> Expr: + arg = try_convert_to_scalar_const(call.args[0]) + if isinstance(arg, Expr): # type: ignore + return bb.call_te(topi.cast, arg, call.attrs.dtype) + else: + return relax.const(arg, call.attrs.dtype) + + +##################### Indexing ##################### + + +def _take(bb: BlockBuilder, call: Call) -> Expr: + # Currently Relax `take` operator doesn't provide the mode choices and + # requires input indices to be in range. + # We use fast mode, which leads to runtime error whenever some index is + # out of bound. + return bb.call_te(topi.take, call.args[0], call.args[1], call.attrs.axis, mode="fast") + + +def _strided_slice(bb: BlockBuilder, call: Call) -> Expr: + if not all( + [ + isinstance(call.args[0].struct_info.shape.values[i.value], tir.IntImm) + for i in call.attrs.axes + ] + ): + logging.info( + "Cases where an axis with symbolic length is sliced are not able " + "to be legalized through TOPI" + ) + return call + + return bb.call_te( + topi.strided_slice, + call.args[0], + call.attrs.begin, + call.attrs.end, + call.attrs.strides, + call.attrs.axes, + slice_mode="end", + ) + + +##################### Linear algebra ##################### + + +def _matmul(bb: BlockBuilder, call: Call) -> Expr: + def te_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: + a_shape = list(a.shape) + b_shape = list(b.shape) + a_prepended = False + b_appended = False + if len(a_shape) == 1: + a_prepended = True + a_shape.insert(0, 1) + if len(b_shape) == 1: + b_appended = True + b_shape.append(1) + + is_a_larger = len(a_shape) > len(b_shape) + offset = len(a_shape) - len(b_shape) if is_a_larger else len(b_shape) - len(a_shape) + + a_relax = relax.Var("a", relax.TensorStructInfo(a.shape)) + b_relax = relax.Var("b", relax.TensorStructInfo(b.shape)) + f_infer_sinfo = call.op.get_attr("FInferStructInfo") + output_shape = f_infer_sinfo(relax.op.matmul(a_relax, b_relax), bb).shape + + def matmul_compute(*idx_spatial): + k = te.reduce_axis((0, a_shape[-1]), name="k") + + def multiply_compute(idx_reduce): + a_indices = [] + b_indices = [] + + for i in range(offset): + if is_a_larger: + a_indices.append(idx_spatial[i]) + else: + b_indices.append(idx_spatial[i]) + for i in range(offset, len(output_shape) - (2 - a_prepended - b_appended)): + a_dim = a_shape[i if is_a_larger else i - offset] + b_dim = b_shape[i if not is_a_larger else i - offset] + a_dim_is_one = 
isinstance(a_dim, tir.IntImm) and a_dim == 1 + b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1 + a_indices.append(0 if a_dim_is_one else idx_spatial[i]) + b_indices.append(0 if b_dim_is_one else idx_spatial[i]) + if not a_prepended: + a_indices.append(idx_spatial[-2 + b_appended]) + a_indices.append(idx_reduce) + b_indices.append(idx_reduce) + if not b_appended: + b_indices.append(idx_spatial[-1]) + + dtype = call.attrs.out_dtype + if dtype != "": + return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype) + else: + return a(*a_indices) * b(*b_indices) + + return te.sum(multiply_compute(k), axis=k) + + return te.compute( + output_shape, + lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda + name="matmul", + ) + + return bb.call_te(te_matmul, call.args[0], call.args[1], primfunc_name_hint="matmul") + + +##################### Manipulation ##################### + + +def _reshape( + te_func: TEFunc, primfunc_name: str, is_collapse_sum_like: bool = False +) -> LegalizeFunc: + def reshape_call_te(bb: BlockBuilder, call: Call): + tgt_shape = call.args[1].struct_info.shape if is_collapse_sum_like else call.args[1] + return bb.call_te(te_func, call.args[0], tgt_shape, primfunc_name_hint=primfunc_name) + + return reshape_call_te + + +def _concat(bb: BlockBuilder, call: Call) -> Expr: + t = call.args[0] + n_field = len(t.struct_info.fields) + while isinstance(t, Var): + binding = bb.lookup_binding(t) + if not isinstance(binding, (Tuple, Var)): + break + t = binding + + assert isinstance(t, (Tuple, Var)) + fields = ( + t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] + ) + return bb.call_te( + topi.concatenate, fields, None if call.attrs.axis is None else call.attrs.axis.value + ) + + +def _expand_dims(bb: BlockBuilder, call: Call) -> Expr: + def te_expand_dims(data, axis): + data_relax = relax.Var("data", relax.TensorStructInfo(data.shape)) + f_infer_sinfo = call.op.get_attr("FInferStructInfo") + output_shape = f_infer_sinfo(relax.op.expand_dims(data_relax, axis), bb).shape + output_ndim = len(output_shape) + + data_dims = [] + for i in range(output_ndim): + if i not in axis and (i - output_ndim) not in axis: + data_dims.append(i) + return te.compute( + output_shape, + lambda *idx: data(*[idx[dim] for dim in data_dims]), + name="expand_dims", + ) + + return bb.call_te( + te_expand_dims, call.args[0], call.attrs.axis, primfunc_name_hint="expand_dims" + ) + + +def _flatten(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.reshape, call.args[0], call.struct_info.shape.values) + + +def _permute_dims(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.transpose, call.args[0], call.attrs.axes) + + +def _split(bb: BlockBuilder, call: Call) -> Expr: + if isinstance(call.attrs.indices_or_sections, tir.IntImm): + indices_or_sections = call.attrs.indices_or_sections.value + modulo = tvm.arith.Analyzer().simplify( + call.args[0].struct_info.shape.values[call.attrs.axis] % indices_or_sections + ) + if modulo != 0: + logging.info( + "Split cannot be legalized by TOPI when the axis being split has " + "length that not divisible by the input number of section." 
+ ) + return call + else: + indices_or_sections = call.attrs.indices_or_sections + return bb.call_te(topi.split, call.args[0], indices_or_sections, call.attrs.axis) + + +def _squeeze(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.squeeze, call.args[0], call.attrs.axis) + + +##################### Search ##################### + + +def _where(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.where, call.args[0], call.args[1], call.args[2]) + + +##################### Statistical ##################### + + +def _statistical(te_func: TEFunc) -> LegalizeFunc: + def statistical_call_te(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(te_func, call.args[0], call.attrs.axis, call.attrs.keepdims) + + return statistical_call_te + + +def _compute_shape_prod(x: te.Tensor, axis: List[tir.IntImm]) -> tir.PrimExpr: + shape_prod = tir.const(1, "int32") + axes = [_axis.value for _axis in axis] if axis is not None else range(0, len(x.shape)) + for dim in axes: + shape_prod = shape_prod * x.shape[dim] + return shape_prod + + +def _te_mean(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Tensor: + shape_prod = _compute_shape_prod(x, axis) + res_sum = topi.sum(x, axis, keepdims) + return topi.divide(res_sum, shape_prod) + + +def _te_variance(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Tensor: + dev = x - _te_mean(x, axis, keepdims) + return _te_mean(dev * dev, axis, keepdims) + + +def _mean(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + _te_mean, call.args[0], call.attrs.axis, call.attrs.keepdims, primfunc_name_hint="mean" + ) + + +def _std(bb: BlockBuilder, call: Call) -> Expr: + def te_std(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Tensor: + return topi.sqrt(_te_variance(x, axis, keepdims)) + + return bb.call_te( + te_std, call.args[0], call.attrs.axis, call.attrs.keepdims, primfunc_name_hint="std" + ) + + +def _variance(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + _te_variance, + call.args[0], + call.attrs.axis, + call.attrs.keepdims, + primfunc_name_hint="variance", + ) + + +##################### Neural network ##################### + + +def _nn_conv2d(bb: BlockBuilder, call: Call) -> Expr: + if call.attrs.out_layout != call.attrs.data_layout: + logging.info( + "TOPI conv2d does not support different input-output " + "layouts, and thus cannot be legalized by TOPI" + ) + return call + if len(call.attrs.data_layout) != 4 or len(call.attrs.kernel_layout) != 4: + logging.info( + "Conv2D where data layout or kernel layout have channel chunk " + "cannot be legalized by TOPI at this moment." + ) + return call + if call.attrs.groups != 1: + data_layout = tir.layout(call.attrs.data_layout) + kernel_layout = tir.layout(call.attrs.kernel_layout) + ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")] + oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")] + if not isinstance(ic, tir.IntImm) or not isinstance(oc, tir.IntImm): + logging.info( + "Conv2D where number of groups is more than one and input or output " + "channel size is symbolic cannot be legalized by TOPI at this moment." 
+ ) + return call + + return bb.call_te( + topi.nn.conv, + inp=call.args[0], + filt=call.args[1], + stride=call.attrs.strides, + padding=call.attrs.padding, + dilation=call.attrs.dilation, + groups=call.attrs.groups, + data_layout=call.attrs.data_layout, + kernel_layout=call.attrs.kernel_layout, + out_dtype=call.attrs.out_dtype if call.attrs.out_dtype != "" else None, + primfunc_name_hint="conv2d", + ) + + +def _nn_max_pool2d(bb: BlockBuilder, call: Call) -> Expr: + if call.attrs.out_layout != call.attrs.layout: + logging.info( + "TOPI max_pool2d does not support different input-output " + "layouts, and thus cannot be legalized by TOPI" + ) + return call + + return bb.call_te( + topi.nn.pool2d, + call.args[0], + kernel=call.attrs.pool_size, + stride=call.attrs.strides, + dilation=call.attrs.dilation, + padding=call.attrs.padding, + pool_type="max", + ceil_mode=call.attrs.ceil_mode, + layout=call.attrs.layout, + primfunc_name_hint="max_pool2d", + ) + + +def _nn_adaptive_max_pool2d(bb: BlockBuilder, call: Call) -> Expr: + if call.attrs.out_layout != call.attrs.layout: + logging.info( + "TOPI adaptive_max_pool2d does not support different input-output " + "layouts, and thus cannot be legalized by TOPI" + ) + return call + + def te_adaptive_avg_pool2d(data, output_size, layout_str): + if output_size is None: + layout = tir.layout(layout_str) + idx_H = layout.index_of("H") + idx_W = layout.index_of("W") + assert idx_H != -1 and idx_W != -1 + output_size = (data.shape[idx_H], data.shape[idx_W]) + + return topi.nn.adaptive_pool(data, output_size, "avg", layout_str) + + return bb.call_te( + te_adaptive_avg_pool2d, + call.args[0], + call.attrs.output_size, + call.attrs.layout, + primfunc_name_hint="adaptive_avg_pool2d", + ) + + +def _nn_relu(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.nn.relu, call.args[0]) + + +def _nn_gelu(bb: BlockBuilder, call: Call) -> Expr: + def gelu(x: te.Tensor): + dtype = x.dtype + return x * ( + tir.const(0.5, dtype) + + topi.erf(x * tir.const(0.5**0.5, dtype)) * tir.const(0.5, dtype) + ) + + return bb.call_te(gelu, call.args[0], primfunc_name_hint="gelu") + + +def _nn_silu(bb: BlockBuilder, call: Call) -> Expr: + def te_silu(x: te.Tensor): + return topi.multiply(x, topi.sigmoid(x)) + + return bb.call_te(te_silu, call.args[0], primfunc_name_hint="silu") + + +def _nn_softmax(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.nn.softmax, call.args[0], call.attrs.axis) + + +def _nn_log_softmax(bb: BlockBuilder, call: Call): + return bb.call_te(topi.nn.log_softmax, call.args[0], call.attrs.axis) + + +def _nn_cross_entropy_without_logits(bb: BlockBuilder, call: Call): + def te_cross_entropy_without_logits(x, y): + if len(x.shape) > 1: + return -topi.sum(topi.log(x) * y) / x.shape[0] + return -topi.sum(topi.log(x) * y) + + return bb.call_te( + te_cross_entropy_without_logits, + call.args[0], + call.args[1], + primfunc_name_hint="cross_entropy_without_logits", + ) + + +def _nn_cross_entropy_with_logits(bb: BlockBuilder, call: Call): + def te_cross_entropy_with_logits(x, y): + if len(x.shape) > 1: + return -topi.sum(x * y) / x.shape[0] + return -topi.sum(x * y) + + return bb.call_te( + te_cross_entropy_with_logits, + call.args[0], + call.args[1], + primfunc_name_hint="cross_entropy_with_logits", + ) + + +def _nn_batch_norm(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.nn.batch_norm, + data=call.args[0], + gamma=call.args[1], + beta=call.args[2], + moving_mean=call.args[3], + moving_var=call.args[4], + axis=call.attrs.axis, + 
epsilon=call.attrs.epsilon, + center=call.attrs.center, + scale=call.attrs.scale, + ) + + +def _nn_layer_norm(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.nn.layer_norm, + call.args[0], + call.args[1], + call.args[2], + axis=call.attrs.axes, + epsilon=call.attrs.epsilon, + ) + + +def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr: + logging.info("Dropout is handled by frontend translator at this moment and is not legalized.") + return call + + +def _nn_nll_loss(bb: BlockBuilder, call: Call) -> Expr: + if len(call.args) == 2: + # TODO(relax-team): handle optional arugment weight of NLLLoss + logging.info( + "Can not legalize it now, because don't know how to set " + "the default value of the optional argument 'weight' of NLLLoss." + ) + return call + return bb.call_te( + topi.nn.nll_loss, + call.args[0], + call.args[1], + call.args[2], + reduction=call.attrs.reduction, + ignore_index=call.attrs.ignore_index, + ) + + +##################### Image ##################### + + +def _image_resize2d(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te( + topi.image.resize2d, + call.args[0], + roi=call.attrs.roi, + size=call.args[1], + layout=call.attrs.layout, + method=call.attrs.method, + coordinate_transformation_mode=call.attrs.coordinate_transformation_mode, + rounding_method=call.attrs.rounding_method, + bicubic_alpha=call.attrs.cubic_alpha, + bicubic_exclude=call.attrs.cubic_exclude, + extrapolation_value=call.attrs.extrapolation_value, + ) + + +##################### Common ##################### + + +def _call_topi(te_func: TEFunc) -> LegalizeFunc: + return lambda bb, call: bb.call_te(te_func, *call.args) + + +########################################################## + + +# Todo(relax-team): Introduce cumsum for GPT-2 +# def _cumsum(bb: BlockBuilder, call: Call): +# return bb.call_te(topi.cumsum, args[0], attrs.axis) + + +DEFAULT_OP_LEGALIZE_MAP: Dict[str, LegalizeFunc] = { + # Arithmetic and comparison + "relax.abs": _unary(topi.abs), + "relax.cos": _unary(topi.cos), + "relax.log": _unary(topi.log), + "relax.exp": _unary(topi.exp), + "relax.negative": _unary(topi.negative), + "relax.sigmoid": _unary(topi.sigmoid), + "relax.sin": _unary(topi.sin), + "relax.sqrt": _unary(topi.sqrt), + "relax.tanh": _unary(topi.tanh), + "relax.clip": _call_topi(topi.clip), + "relax.add": _binary(topi.add), + "relax.divide": _binary(topi.divide), + "relax.floor_divide": _binary(topi.floor_divide), + "relax.multiply": _binary(topi.multiply), + "relax.subtract": _binary(topi.subtract), + "relax.equal": _binary(topi.equal), + "relax.greater": _binary(topi.greater), + "relax.greater_equal": _binary(topi.greater_equal), + "relax.less": _binary(topi.less), + "relax.less_equal": _binary(topi.less_equal), + "relax.not_equal": _binary(topi.not_equal), + # Creation + "relax.full": _full(is_like=False, fill_value=None, primfunc_name="full"), + "relax.full_like": _full(is_like=True, fill_value=None, primfunc_name="full"), + "relax.ones": _full(is_like=False, fill_value=1.0, primfunc_name="ones"), + "relax.ones_like": _full(is_like=True, fill_value=1.0, primfunc_name="ones"), + "relax.zeros": _full(is_like=False, fill_value=0.0, primfunc_name="zeros"), + "relax.zeros_like": _full(is_like=True, fill_value=0.0, primfunc_name="zeros"), + "relax.tril": _tril_triu(is_upper=False, primfunc_name="tril"), + "relax.triu": _tril_triu(is_upper=True, primfunc_name="triu"), + # Datatype + "relax.astype": _astype, + # Indexing + "relax.take": _take, + "relax.strided_slice": _strided_slice, + # Linear algebra 
+ "relax.matmul": _matmul, + # Manipulation + "relax.broadcast_to": _reshape(topi.broadcast_to, "broadcast_to"), + "relax.concat": _concat, + "relax.expand_dims": _expand_dims, + "relax.flatten": _flatten, + "relax.permute_dims": _permute_dims, + "relax.reshape": _reshape(topi.reshape, "reshape"), + "relax.split": _split, + "relax.squeeze": _squeeze, + # TODO(relax-team): collapse_sum support symbolic shape + "relax.collapse_sum_like": _reshape( + topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True + ), + "relax.collapse_sum_to": _reshape(topi.collapse_sum, "collapse_sum"), + # Search + "relax.where": _where, + # Statistical + "relax.max": _statistical(topi.max), + "relax.mean": _mean, + "relax.min": _statistical(topi.min), + "relax.prod": _statistical(topi.prod), + "relax.std": _std, + "relax.sum": _statistical(topi.sum), + "relax.variance": _variance, + # Neural network + "relax.nn.conv2d": _nn_conv2d, + "relax.nn.max_pool2d": _nn_max_pool2d, + "relax.nn.adaptive_avg_pool2d": _nn_adaptive_max_pool2d, + "relax.nn.relu": _nn_relu, + "relax.nn.gelu": _nn_gelu, + "relax.nn.silu": _nn_silu, + "relax.nn.softmax": _nn_softmax, + "relax.nn.log_softmax": _nn_log_softmax, + "relax.nn.cross_entropy_without_logits": _nn_cross_entropy_without_logits, + "relax.nn.cross_entropy_with_logits": _nn_cross_entropy_with_logits, + "relax.nn.batch_norm": _nn_batch_norm, + "relax.nn.layer_norm": _nn_layer_norm, + "relax.nn.dropout": _nn_dropout, + "relax.nn.nll_loss": _nn_nll_loss, + # Image + "relax.image.resize2d": _image_resize2d, + # Todo(relax-team): Introduce cumsum for GPT-2 + # "relax.cumsum": _cumsum, +} + + +@tvm.transform.module_pass(opt_level=0, name="LegalizeOps") +class LegalizeOps: + """Legalize high-level operator calls in Relax functions to call_tir + with corresponding low-level TIR PrimFuncs. + + For each high-level operator, we register the way of legalizing it as a + function, which takes a context BlockBuilder and the Call being legalized + as input, and returns the legalized call. Here the input BlockBuilder is + mainly used for adding the PrimFunc created by call_te into the context + IRModule. + + The legalization function for each operator is registered in a map, + where the operator name is the key. The default legalization functions + are in the map `DEFAULT_OP_LEGALIZE_MAP`. + + This pass provides customizability for users to use their own legalization + function for operators. The pass takes an optional customized map, + which has the same key/value type as the default map (see `LegalizeFunc`), + from users. When an operator is contained in both the default map and the + customized map, the default legalization function will be overridden, and + only the customized one will be used. + + Parameters + ---------- + customize_legalize_map : Optional[Dict[str, LegalizeFunc]] + The customized operator legalization function map. + If not specified, it will be a fresh empty dict. + If an op has legalization function in both the default map and the + customized map, the customized function will override the default + one. + + Examples + -------- + The following code shows how to use this pass: + + .. 
code-block:: python + + # Define the pass input IRModule + @tvm.script.ir_module + class Module: + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + z: R.Tensor((2, 3), "float32") = R.add(x, y) + r: R.Tensor((2, 3), "float32") = R.multiply(y, z) + return r + + # Define the customized legalization function for "relax.add" + def customize_legalize_add(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr: + from tvm import topi + return bb.call_te(topi.add, call.args[1], call.args[0]) + + # Apply the pass with the customized function to the module. + mod = LegalizeOps({"relax.add": customize_legalize_add})(Module) + + Print out the result by `mod.show()`, we can see the IRModule after + legalization becomes + + .. code-block:: python + + @tvm.script.ir_module + class Module: + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + z = R.call_tir(add, (y, x), (2, 3), dtype="float32") + r = R.call_tir(multiply, (y, z), (2, 3), dtype="float32") + return r + + @T.prim_func + def add( + A: T.Buffer[(2, 3), "float32"], + B: T.Buffer[(2, 3), "float32"], + T_add: T.Buffer[(2, 3), "float32"], + ): + T.func_attr({"tir.noalias": True}) + for ax0, ax1 in T.grid(2, 3): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1] + + @T.prim_func + def multiply( + A: T.Buffer[(2, 3), "float32"], + B: T.Buffer[(2, 3), "float32"], + T_multiply: T.Buffer[(2, 3), "float32"], + ): + T.func_attr({"tir.noalias": True}) + for ax0, ax1 in T.grid(2, 3): + with T.block("T_multiply"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) + T.writes(T_multiply[v_ax0, v_ax1]) + T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[v_ax0, v_ax1] + """ + + def __init__(self, customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None): + if customize_legalize_map is None: + self.customize_legalize_map = dict() + else: + self.customize_legalize_map = customize_legalize_map + + def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule: + @mutator + class OperatorLegalizer(PyExprMutator): + def __init__(self, mod: IRModule, customize_legalize_map: Dict[str, LegalizeFunc]): + super().__init__(mod) + self.mod = mod + self.legalize_map = DEFAULT_OP_LEGALIZE_MAP.copy() + for name, func in customize_legalize_map.items(): + self.legalize_map[name] = func + + def _convert_op(self, call: Call) -> Expr: + if call.op.name in self.legalize_map: + # We only transform the op calls with known shape values + if not all( + [has_known_shape_value(arg.struct_info) for arg in call.args] + ) or not has_known_shape_value(call.struct_info): + return call + return self.legalize_map[call.op.name](self.builder_, call) + if call.op.name != "relax.call_tir": + logging.warning("No legalization func for %s is found.", call.op.name) + return call + + def transform(self) -> IRModule: + for global_var, func in self.mod.functions.items(): + if not isinstance(func, Function): + continue + updated_func = self.visit_expr(func) + updated_func = remove_all_unused(updated_func) + self.builder_.update_func(global_var, updated_func) + + return self.builder_.get() + + def visit_call_(self, call): # pylint: disable=arguments-differ + call = self.visit_expr_post_order(call) + if not isinstance(call.op, 
tir.op.Op): + return call + return self._convert_op(call) + + return OperatorLegalizer(mod, self.customize_legalize_map).transform() diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py new file mode 100644 index 0000000000..f51c99fc15 --- /dev/null +++ b/tests/python/relax/test_frontend_from_fx.py @@ -0,0 +1,1716 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm import relax +import tvm.testing +from tvm.script.parser import relax as R, tir as T + + +def verify_model(torch_model, input_info, binding, expected): + from torch import fx + from tvm.relax.frontend.torch import from_fx + + graph_model = fx.symbolic_trace(torch_model) + mod = from_fx(graph_model, input_info) + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + expected = relax.transform.BindParams("main", binding)(expected) + tvm.ir.assert_structural_equal(mod, expected) + + +@tvm.testing.requires_gpu +def test_conv(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + class Conv2D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7), dtype="float32"), + w2: R.Tensor((1, 6, 1, 1), dtype="float32"), + ) -> R.Tensor((1, 6, 4, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, w2) + gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3 + R.output(gv) + return gv + + class Conv2D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7), dtype="float32"), + ) -> R.Tensor((1, 6, 4, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv1 + R.output(gv) + return gv + + input_info = [([1, 3, 10, 10], "float32")] + + model = Conv2D1() + binding = {"w1": 
model.conv.weight.numpy(), "w2": model.conv.bias.numpy().reshape(1, 6, 1, 1)} + verify_model(model, input_info, binding, expected1) + + model = Conv2D2() + binding = {"w1": model.conv.weight.numpy()} + verify_model(model, input_info, binding, expected2) + + +@tvm.testing.requires_gpu +def test_linear(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + # nn.Linear + class Dense1(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=True) + + def forward(self, input): + return self.linear(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((10, 7), dtype="float32"), + w2: R.Tensor((1, 7), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 7), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul( + input_1, w1, out_dtype="float32" + ) + lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.add(lv, w2) + gv: R.Tensor((1, 3, 10, 7), dtype="float32") = lv1 + R.output(gv) + return gv + + class Dense2(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=False) + + def forward(self, input): + return self.linear(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((10, 7), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 7), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul( + input_1, w1, out_dtype="float32" + ) + gv: R.Tensor((1, 3, 10, 7), dtype="float32") = lv1 + R.output(gv) + return gv + + input_info = [([1, 3, 10, 10], "float32")] + + model = Dense1() + binding = {"w1": model.linear.weight.numpy().T, "w2": model.linear.bias.numpy()} + verify_model(model, input_info, binding, expected1) + + model = Dense2() + binding = {"w1": model.linear.weight.numpy().T} + verify_model(model, input_info, binding, expected2) + + # matmul + class MatMul1(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32"), + input_2: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul( + input_1, input_2, out_dtype="float32" + ) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model( + MatMul1(), + [([10, 10], "float32"), ([10, 10], "float32")], + {}, + expected3, + ) + + +@tvm.testing.requires_gpu +def test_relu(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + + class ReLU(Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, input): + return self.relu(input) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.nn.relu(input_1) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + input_info = [([10, 10], "float32")] + verify_model(ReLU(), input_info, {}, expected) + + +@tvm.testing.requires_gpu +def test_relu6(): + import torch + from torch.nn import Module + + 
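+    # Note: nn.ReLU6 has no dedicated Relax op in this importer; the expected
+    # module below checks that it is translated into a single R.clip(input, 0, 6).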
torch.set_grad_enabled(False) + + class ReLU6(Module): + def __init__(self): + super().__init__() + self.relu6 = torch.nn.ReLU6() + + def forward(self, input): + return self.relu6(input) + + @tvm.script.ir_module + class expected: + @R.function + def main(input: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.clip(input, 0, 6) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + input_info = [([10, 10], "float32")] + verify_model(ReLU6(), input_info, {}, expected) + + +@tvm.testing.requires_gpu +def test_maxpool2d(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class MaxPool2d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[1, 1], + strides=[1, 1], + dilation=[1, 1], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class MaxPool2d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 4, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 4, 4), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[2, 2], + strides=[2, 2], + dilation=[2, 3], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tensor((1, 3, 4, 4), dtype="float32") = lv + R.output(gv) + return gv + + class MaxPool2d3(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 6, 6), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 6, 6), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[4, 4], + strides=[2, 2], + dilation=[1, 1], + padding=[2, 2, 2, 2], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tensor((1, 3, 6, 6), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(MaxPool2d(), input_info, {}, expected1) + verify_model(MaxPool2d2(), input_info, {}, expected2) + verify_model(MaxPool2d3(), input_info, {}, expected3) + + +@tvm.testing.requires_gpu +def test_adaptive_avgpool2d(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class AdaptiveAvgPool2d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d([10, 10]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> 
R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.adaptive_avg_pool2d( + input_1, output_size=[10, 10], layout="NCHW", out_layout="NCHW" + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(AdaptiveAvgPool2d(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_flatten(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Flatten(Module): + def __init__(self): + super().__init__() + self.f = torch.nn.Flatten(2, -1) + + def forward(self, input): + return self.f(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 100), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 100), dtype="float32") = R.reshape(input_1, (1, 3, 100)) + gv: R.Tensor((1, 3, 100), dtype="float32") = lv + R.output(gv) + return gv + + # call_module + verify_model(Flatten(), input_info, {}, expected1) + # call_method + verify_model(torch.nn.Flatten(2, -1), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_batchnorm2d(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class BatchNorm2d(Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, input): + return self.bn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + w3: R.Tensor((3,), dtype="float32"), + w4: R.Tensor((3,), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((3,), dtype="float32"), + R.Tensor((3,), dtype="float32"), + ) = R.nn.batch_norm( + input_1, + w1, + w2, + w3, + w4, + axis=1, + epsilon=1e-05, + center=True, + scale=True, + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv1 + R.output(gv) + return gv + + model = BatchNorm2d() + binding = { + "w1": model.bn.weight.numpy(), + "w2": model.bn.bias.numpy(), + "w3": model.bn.running_mean.numpy(), + "w4": model.bn.running_var.numpy(), + } + verify_model(BatchNorm2d(), input_info, binding, expected1) + + +@tvm.testing.requires_gpu +def test_embedding(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([4], "int64")] + + class Embedding(Module): + def __init__(self): + super().__init__() + self.embedding = torch.nn.Embedding(10, 3) + + def forward(self, input): + return self.embedding(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((4,), dtype="int64"), w1: R.Tensor((10, 3), dtype="float32") + ) -> R.Tensor((4, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((4,), dtype="int32") = R.astype(input_1, dtype="int32") + lv1: R.Tensor((4, 3), dtype="float32") = R.take(w1, lv, axis=0) + gv: R.Tensor((4, 3), dtype="float32") = lv1 + R.output(gv) + return gv + + model = Embedding() + binding = {"w1": model.embedding.weight.numpy()} + verify_model(model, input_info, binding, 
expected1) + + +@tvm.testing.requires_gpu +def test_dropout(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Dropout(Module): + def __init__(self): + super().__init__() + self.dropout = torch.nn.Dropout(0.5) + + def forward(self, input): + return self.dropout(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = input_1 + R.output(gv) + return gv + + verify_model(Dropout(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_layernorm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class LayerNorm(Module): + def __init__(self): + super().__init__() + self.ln = torch.nn.LayerNorm((10, 10)) + + def forward(self, input): + return self.ln(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((10, 10), dtype="float32"), + w2: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm( + input_1, + w1, + w2, + axes=[-2, -1], + epsilon=1e-05, + center=True, + scale=True, + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + model = LayerNorm() + binding = { + "w1": model.ln.weight.numpy(), + "w2": model.ln.bias.numpy(), + } + verify_model(LayerNorm(), input_info, binding, expected1) + + +@tvm.testing.requires_gpu +def test_silu(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class SiLU(Module): + def __init__(self): + super().__init__() + self.silu = torch.nn.SiLU() + + def forward(self, input): + return self.silu(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.silu(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(SiLU(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_groupnorm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class GroupNorm(Module): + def __init__(self): + super().__init__() + self.gn = torch.nn.GroupNorm(3, 3) + + def forward(self, input): + return self.gn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.reshape( + input_1, (1, 3, 1, 10, 10) + ) + lv1: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.mean( + lv, axis=[2, 3, 4], keepdims=True + ) + lv2: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.subtract(lv, lv1) + lv3: R.Tensor((1, 3, 1, 10, 10), 
dtype="float32") = R.multiply(lv2, lv2) + lv4: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sum( + lv3, axis=[2, 3, 4], keepdims=True + ) + lv5: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.divide(lv4, R.const(100.0)) + lv6: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.add(lv5, R.const(1e-05)) + lv7: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sqrt(lv6) + lv8: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.divide(lv2, lv7) + lv9: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w1, (1, 3, 1, 1, 1)) + lv10: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w2, (1, 3, 1, 1, 1)) + lv11: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv8, lv9) + lv12: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.add(lv11, lv10) + lv13: R.Tensor((1, 3, 10, 10), dtype="float32") = R.reshape(lv12, (1, 3, 10, 10)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv13 + R.output(gv) + return gv + + model = GroupNorm() + binding = { + "w1": model.gn.weight.numpy(), + "w2": model.gn.bias.numpy(), + } + verify_model(model, input_info, binding, expected1) + + +@tvm.testing.requires_gpu +def test_softmax(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Softmax(Module): + def __init__(self): + super().__init__() + self.sm = torch.nn.Softmax(dim=1) + + def forward(self, input): + return self.sm(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softmax(input_1, axis=1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Softmax(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_binary(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")] + input_info2 = [([1, 3, 10, 10], "float32")] + # Add + class Add1(Module): + def forward(self, lhs, rhs): + return lhs + rhs + + @tvm.script.ir_module + class expected1: + @R.function + def main( + lhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lhs, rhs) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Add2(Module): + def forward(self, lhs): + return lhs + 1.0 + + @tvm.script.ir_module + class expected2: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Add1(), input_info1, {}, expected1) + verify_model(Add2(), input_info2, {}, expected2) + + # Sub + class Sub1(Module): + def forward(self, lhs, rhs): + return lhs - rhs + + @tvm.script.ir_module + class expected3: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), 
dtype="float32") = R.subtract(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Sub2(Module): + def forward(self, lhs): + return lhs - 1.0 + + @tvm.script.ir_module + class expected4: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sub1(), input_info1, {}, expected3) + verify_model(Sub2(), input_info2, {}, expected4) + + # Mul + class Mul1(Module): + def forward(self, lhs, rhs): + return lhs * rhs + + @tvm.script.ir_module + class expected5: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Mul2(Module): + def forward(self, lhs): + return lhs * 1.0 + + @tvm.script.ir_module + class expected6: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Mul1(), input_info1, {}, expected5) + verify_model(Mul2(), input_info2, {}, expected6) + + # True div + class TrueDiv1(Module): + def forward(self, lhs, rhs): + return lhs / rhs + + @tvm.script.ir_module + class expected7: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class TrueDiv2(Module): + def forward(self, lhs): + return lhs / 1.0 + + @tvm.script.ir_module + class expected8: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(TrueDiv1(), input_info1, {}, expected7) + verify_model(TrueDiv2(), input_info2, {}, expected8) + + # Floor div + class FloorDiv1(Module): + def forward(self, lhs, rhs): + return lhs // rhs + + @tvm.script.ir_module + class expected9: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.floor_divide(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class FloorDiv2(Module): + def forward(self, lhs): + return lhs // 1.0 + + @tvm.script.ir_module + class expected10: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: 
R.Tensor((1, 3, 10, 10), dtype="float32") = R.floor_divide(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(FloorDiv1(), input_info1, {}, expected9) + verify_model(FloorDiv2(), input_info2, {}, expected10) + + # LT + class LT1(Module): + def forward(self, lhs, rhs): + return lhs < rhs + + @tvm.script.ir_module + class expected11: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="bool"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv + R.output(gv) + return gv + + class LT2(Module): + def forward(self, lhs): + return lhs < 1.0 + + @tvm.script.ir_module + class expected12: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="bool"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv + R.output(gv) + return gv + + verify_model(LT1(), input_info1, {}, expected11) + verify_model(LT2(), input_info2, {}, expected12) + + +@tvm.testing.requires_gpu +def test_size(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Size(Module): + def forward(self, input): + return input.size() + + @tvm.script.ir_module + class expected1: + @R.function + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Shape([1, 3, 10, 10]): + # block 0 + with R.dataflow(): + gv: R.Shape([1, 3, 10, 10]) = (1, 3, 10, 10) + R.output(gv) + return gv + + verify_model(Size(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_unsqueeze(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Unsqueeze1(Module): + def forward(self, input): + return input.unsqueeze(1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = R.expand_dims(input_1, 1) + gv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Unsqueeze2(Module): + def forward(self, input): + return input.unsqueeze(-1) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10, 1), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = R.expand_dims(input_1, -1) + gv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Unsqueeze1(), input_info, {}, expected1) + verify_model(Unsqueeze2(), input_info, {}, expected2) + + +@tvm.testing.requires_gpu +def test_getattr(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class GetAttr1(Module): + def forward(self, input): + return input.shape + + @tvm.script.ir_module + class expected1: + @R.function + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Shape([1, 3, 10, 10]): + # 
block 0 + with R.dataflow(): + gv: R.Shape([1, 3, 10, 10]) = (1, 3, 10, 10) + R.output(gv) + return gv + + verify_model(GetAttr1(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_getitem(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Slice1(Module): + def forward(self, x): + return x[0, 1::2, :, :3] + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 1, 10, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 1, 10, 3), dtype="float32") = R.strided_slice( + x, + axes=[0, 1, 2, 3], + begin=[0, 1, 0, 0], + end=[1, T.int64(3), T.int64(10), 3], + strides=[1, 2, 1, 1], + ) + lv1: R.Tensor((1, 1, 10, 3), dtype="float32") = R.reshape(lv, (1, 1, 10, 3)) + gv: R.Tensor((1, 1, 10, 3), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_model(Slice1(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_unary(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + # sin + class Sin(Module): + def forward(self, input): + return torch.sin(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sin(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sin(), input_info, {}, expected1) + + # cos + class Cos(Module): + def forward(self, input): + return torch.cos(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cos(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Cos(), input_info, {}, expected2) + + # sqrt + class Sqrt(Module): + def forward(self, input): + return torch.sqrt(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sqrt(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sqrt(), input_info, {}, expected3) + + +@tvm.testing.requires_gpu +def test_gelu(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Gelu(Module): + def forward(self, input): + return torch.nn.functional.gelu(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.gelu(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Gelu(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_clamp(): + import torch + from torch import fx + from torch.nn import Module + + 
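+    # torch.clamp with constant bounds is expected to map to R.clip below, while
+    # the two pytest.raises cases at the end of this test check that from_fx
+    # rejects non-constant min/max values with a ValueError.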
torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Clamp(Module): + def forward(self, input): + return torch.clamp(input, min=0.1, max=0.5) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.1, 0.5) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Clamp(), input_info, {}, expected1) + + from tvm.relax.frontend.torch import from_fx + + with pytest.raises( + ValueError, match="TVM only supports constant max value for torch.clamp/clip" + ): + + class Clamp_Error(Module): + def forward(self, input): + return torch.clamp(input, min=0.5, max=None) + + gm = fx.symbolic_trace(Clamp_Error()) + from_fx(gm, input_info) + + with pytest.raises( + ValueError, match="TVM only supports constant min value for torch.clamp/clip" + ): + + class Clamp_Error(Module): + def forward(self, input): + return torch.clamp(input, min=input, max=input) + + gm = fx.symbolic_trace(Clamp_Error()) + from_fx(gm, input_info) + + +@tvm.testing.requires_gpu +def test_interpolate(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Interpolate(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, size=(5, 5)) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 5, 5), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 5, 5), dtype="float32") = R.image.resize2d( + input_1, + (5, 5), + roi=[0.000000, 0.000000, 0.000000, 0.000000], + layout="NCHW", + method="nearest_neighbor", + coordinate_transformation_mode="asymmetric", + rounding_method="round", + cubic_alpha=-0.5, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 3, 5, 5), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Interpolate(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_addmm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [ + ([10, 10], "float32"), + ([10, 10], "float32"), + ([10, 10], "float32"), + ] + + class Addmm(Module): + def forward(self, x1, x2, x3): + return torch.addmm(x1, x2, x3) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x1: R.Tensor((10, 10), dtype="float32"), + x2: R.Tensor((10, 10), dtype="float32"), + x3: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32") + lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv) + gv: R.Tensor((10, 10), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_model(Addmm(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_split(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Split(Module): + def forward(self, input): + return torch.split(input, 1, dim=1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), 
dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=3, axis=1) + gv: R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ) = lv + R.output(gv) + return gv + + verify_model(Split(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_tril(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([10, 10], "float32")] + + class Tril(Module): + def forward(self, input): + return torch.tril(input, 1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.tril(input_1, 1) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Tril(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_new_ones(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3], "float32")] + + class NewOnes(Module): + def forward(self, x): + return x.new_ones(1, 2, 3) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((1, 2, 3), dtype="float32")) -> R.Tensor((1, 2, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3), dtype="float32") = R.full( + (1, 2, 3), R.const(1, "float32"), dtype="float32" + ) + gv: R.Tensor((1, 2, 3), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(NewOnes(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_expand(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class Expand(Module): + def forward(self, x): + return x.expand(4, 2, 3, 4) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((4, 2, 3, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 2, 3, 4), dtype="float32") = R.broadcast_to(x, (4, 2, 3, 4)) + gv: R.Tensor((4, 2, 3, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Expand(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_reduce(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + # sum + class Sum(Module): + def forward(self, x): + return torch.sum(x, (2, 1)) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4), dtype="float32") = R.sum(inp_0, axis=[2, 1], keepdims=False) + gv: R.Tensor((1, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sum(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_to(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + 
torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + # float + class ToFloat(Module): + def forward(self, x): + return x.float() + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 2, 3, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32") + gv: R.Tensor((1, 2, 3, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(ToFloat(), input_info, {}, expected1) + + # half + class ToHalf(Module): + def forward(self, x): + return x.half() + + @tvm.script.ir_module + class expected2: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 2, 3, 4), dtype="float16"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(x, dtype="float16") + gv: R.Tensor((1, 2, 3, 4), dtype="float16") = lv + R.output(gv) + return gv + + verify_model(ToHalf(), input_info, {}, expected2) + + +@tvm.testing.requires_gpu +def test_permute(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class Permute(Module): + def forward(self, x): + return x.permute(0, 3, 2, 1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 4, 3, 2), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1]) + gv: R.Tensor((1, 4, 3, 2), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Permute(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_reshape(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class Reshape(Module): + def forward(self, x): + return x.reshape(2, 12) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12)) + gv: R.Tensor((2, 12), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Reshape(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_transpose(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class Transpose(Module): + def forward(self, x): + return x.transpose(1, 3) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 4, 3, 2), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1]) + gv: R.Tensor((1, 4, 3, 2), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Transpose(), input_info, {}, expected1) + + +@tvm.testing.requires_gpu +def test_view(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class View(Module): + def forward(self, x): + return x.view(2, 12) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype="float32"): + # block 0 + 
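+            # Tensor.view is handled like Tensor.reshape above: both are
+            # expected to lower to a single R.reshape call.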
with R.dataflow(): + lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12)) + gv: R.Tensor((2, 12), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(View(), input_info, {}, expected1) + + +if __name__ == "__main__": + tvm.testing.main() From de522bb22e44bf661ff8891a27eb1260eac65f5e Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Thu, 9 Feb 2023 13:42:46 -0800 Subject: [PATCH 2/4] . --- python/tvm/relax/frontend/torch/fx_translator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 1af821c849..04d146ff39 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -18,7 +18,7 @@ # pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck # pylint: disable=import-outside-toplevel """PyTorch FX frontend of Relax.""" -from typing import Callable, Dict, Mapping, Tuple, Union, List +from typing import Dict, Tuple, List from functools import reduce import tvm From 5eea2e648d7e39000f0c055884d877dc8a9ae6f6 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Thu, 9 Feb 2023 14:20:09 -0800 Subject: [PATCH 3/4] . --- python/tvm/relax/transform/legalize_ops.py | 865 --------------------- 1 file changed, 865 deletions(-) delete mode 100644 python/tvm/relax/transform/legalize_ops.py diff --git a/python/tvm/relax/transform/legalize_ops.py b/python/tvm/relax/transform/legalize_ops.py deleted file mode 100644 index 420445d39c..0000000000 --- a/python/tvm/relax/transform/legalize_ops.py +++ /dev/null @@ -1,865 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=abstract-method,invalid-name,missing-class-docstring,missing-function-docstring,missing-module-docstring,unused-argument -import logging -from typing import Callable, Dict, List, Optional, Union - -import tvm -from tvm import te, tir, topi, relax -from tvm.relax import struct_info -from tvm.ir.module import IRModule - -from ..analysis import remove_all_unused -from ..expr import Call, Constant, Expr, Function, ShapeExpr, Tuple, TupleGetItem, Var -from ..expr_functor import mutator, PyExprMutator -from ..block_builder import BlockBuilder - - -##################### Commons ##################### - -# The function type of a TE function, which accepts TE Tensors and -# other attributes, and returns the output TE Tensor. -TEFunc = Callable[..., te.Tensor] - -# The function type of a legalization function, which takes a -# BlockBuilder and the Call to be legalized, and outputs the legalization -# result Expr. 
-LegalizeFunc = Callable[[BlockBuilder, Call], Expr] - - -def has_known_shape_value(sinfo: struct_info.StructInfo) -> bool: - """Check if a given Tensor/Shape/TupleStructInfo contains - shapes whose values are all known. - - Parameters - ---------- - sinfo : struct_info.StructInfo - The struct info to be checked. - - Returns - ------- - ret : bool - A boolean indicating if the given struct info contains shape - values that are all known. - """ - if isinstance(sinfo, struct_info.TensorStructInfo): - return isinstance(sinfo.shape, ShapeExpr) - elif isinstance(sinfo, struct_info.ShapeStructInfo): - return sinfo.values is not None - elif isinstance(sinfo, struct_info.TupleStructInfo): - return all([has_known_shape_value(field_sinfo) for field_sinfo in sinfo.fields]) - elif isinstance(sinfo, struct_info.PrimStructInfo): - return True - else: - return False - - -def try_convert_to_scalar_const(expr: Expr) -> Union[Expr, bool, float, int]: - """Check if the input Expr is a scalar constant. - If it is, return its plain value. - If it is not, return the input expr. - - Parameters - ---------- - expr : Expr - The expr to be checked and converted. - - Returns - --–---- - ret : Union[Expr, bool, float, int] - Return a Python native value (int/float/bool) if the given - expr is a scalar constant. Or return the input itself - if it is not. - """ - if isinstance(expr, Constant) and expr.struct_info.ndim == 0: - return expr.data.numpy()[()].item() - else: - return expr - - -def _unary(te_func: TEFunc) -> LegalizeFunc: - def unary_call_te(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te(te_func, call.args[0]) - - return unary_call_te - - -def _binary(te_func: TEFunc) -> LegalizeFunc: - def binary_call_te(bb: BlockBuilder, call: Call) -> Expr: - # To simplify the created PrimFunc, we first check if arg1 is a constant scalar. - # If it is not, we then check if arg0 is a constant scalar. - arg0 = call.args[0] - arg1 = try_convert_to_scalar_const(call.args[1]) - if isinstance(arg1, Expr): # type: ignore - arg0 = try_convert_to_scalar_const(arg0) - return bb.call_te(te_func, arg0, arg1) - - return binary_call_te - - -##################### Creation ##################### - - -def _full(is_like: bool, fill_value: Optional[float], primfunc_name: str) -> LegalizeFunc: - def full_call_te(bb: BlockBuilder, call: Call) -> Expr: - _fill_value = ( - try_convert_to_scalar_const(call.args[1]) if fill_value is None else fill_value - ) - - return bb.call_te( - topi.full, - call.args[0].struct_info.shape if is_like else call.args[0], - call.struct_info.dtype, - _fill_value, - primfunc_name_hint=primfunc_name, - ) - - return full_call_te - - -def _tril_triu(is_upper: bool, primfunc_name: str) -> LegalizeFunc: - def tril_triu_call_te(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te( - topi.trilu, - call.args[0], - tir.const(call.attrs.k, "int32"), - upper=is_upper, - primfunc_name_hint=primfunc_name, - ) - - return tril_triu_call_te - - -##################### Datatype ##################### - - -def _astype(bb: BlockBuilder, call: Call) -> Expr: - arg = try_convert_to_scalar_const(call.args[0]) - if isinstance(arg, Expr): # type: ignore - return bb.call_te(topi.cast, arg, call.attrs.dtype) - else: - return relax.const(arg, call.attrs.dtype) - - -##################### Indexing ##################### - - -def _take(bb: BlockBuilder, call: Call) -> Expr: - # Currently Relax `take` operator doesn't provide the mode choices and - # requires input indices to be in range. 
- # We use fast mode, which leads to runtime error whenever some index is - # out of bound. - return bb.call_te(topi.take, call.args[0], call.args[1], call.attrs.axis, mode="fast") - - -def _strided_slice(bb: BlockBuilder, call: Call) -> Expr: - if not all( - [ - isinstance(call.args[0].struct_info.shape.values[i.value], tir.IntImm) - for i in call.attrs.axes - ] - ): - logging.info( - "Cases where an axis with symbolic length is sliced are not able " - "to be legalized through TOPI" - ) - return call - - return bb.call_te( - topi.strided_slice, - call.args[0], - call.attrs.begin, - call.attrs.end, - call.attrs.strides, - call.attrs.axes, - slice_mode="end", - ) - - -##################### Linear algebra ##################### - - -def _matmul(bb: BlockBuilder, call: Call) -> Expr: - def te_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: - a_shape = list(a.shape) - b_shape = list(b.shape) - a_prepended = False - b_appended = False - if len(a_shape) == 1: - a_prepended = True - a_shape.insert(0, 1) - if len(b_shape) == 1: - b_appended = True - b_shape.append(1) - - is_a_larger = len(a_shape) > len(b_shape) - offset = len(a_shape) - len(b_shape) if is_a_larger else len(b_shape) - len(a_shape) - - a_relax = relax.Var("a", relax.TensorStructInfo(a.shape)) - b_relax = relax.Var("b", relax.TensorStructInfo(b.shape)) - f_infer_sinfo = call.op.get_attr("FInferStructInfo") - output_shape = f_infer_sinfo(relax.op.matmul(a_relax, b_relax), bb).shape - - def matmul_compute(*idx_spatial): - k = te.reduce_axis((0, a_shape[-1]), name="k") - - def multiply_compute(idx_reduce): - a_indices = [] - b_indices = [] - - for i in range(offset): - if is_a_larger: - a_indices.append(idx_spatial[i]) - else: - b_indices.append(idx_spatial[i]) - for i in range(offset, len(output_shape) - (2 - a_prepended - b_appended)): - a_dim = a_shape[i if is_a_larger else i - offset] - b_dim = b_shape[i if not is_a_larger else i - offset] - a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1 - b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1 - a_indices.append(0 if a_dim_is_one else idx_spatial[i]) - b_indices.append(0 if b_dim_is_one else idx_spatial[i]) - if not a_prepended: - a_indices.append(idx_spatial[-2 + b_appended]) - a_indices.append(idx_reduce) - b_indices.append(idx_reduce) - if not b_appended: - b_indices.append(idx_spatial[-1]) - - dtype = call.attrs.out_dtype - if dtype != "": - return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype) - else: - return a(*a_indices) * b(*b_indices) - - return te.sum(multiply_compute(k), axis=k) - - return te.compute( - output_shape, - lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda - name="matmul", - ) - - return bb.call_te(te_matmul, call.args[0], call.args[1], primfunc_name_hint="matmul") - - -##################### Manipulation ##################### - - -def _reshape( - te_func: TEFunc, primfunc_name: str, is_collapse_sum_like: bool = False -) -> LegalizeFunc: - def reshape_call_te(bb: BlockBuilder, call: Call): - tgt_shape = call.args[1].struct_info.shape if is_collapse_sum_like else call.args[1] - return bb.call_te(te_func, call.args[0], tgt_shape, primfunc_name_hint=primfunc_name) - - return reshape_call_te - - -def _concat(bb: BlockBuilder, call: Call) -> Expr: - t = call.args[0] - n_field = len(t.struct_info.fields) - while isinstance(t, Var): - binding = bb.lookup_binding(t) - if not isinstance(binding, (Tuple, Var)): - break - t = binding - - assert isinstance(t, (Tuple, Var)) - fields = ( - t.fields if isinstance(t, Tuple) 
else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)] - ) - return bb.call_te( - topi.concatenate, fields, None if call.attrs.axis is None else call.attrs.axis.value - ) - - -def _expand_dims(bb: BlockBuilder, call: Call) -> Expr: - def te_expand_dims(data, axis): - data_relax = relax.Var("data", relax.TensorStructInfo(data.shape)) - f_infer_sinfo = call.op.get_attr("FInferStructInfo") - output_shape = f_infer_sinfo(relax.op.expand_dims(data_relax, axis), bb).shape - output_ndim = len(output_shape) - - data_dims = [] - for i in range(output_ndim): - if i not in axis and (i - output_ndim) not in axis: - data_dims.append(i) - return te.compute( - output_shape, - lambda *idx: data(*[idx[dim] for dim in data_dims]), - name="expand_dims", - ) - - return bb.call_te( - te_expand_dims, call.args[0], call.attrs.axis, primfunc_name_hint="expand_dims" - ) - - -def _flatten(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te(topi.reshape, call.args[0], call.struct_info.shape.values) - - -def _permute_dims(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te(topi.transpose, call.args[0], call.attrs.axes) - - -def _split(bb: BlockBuilder, call: Call) -> Expr: - if isinstance(call.attrs.indices_or_sections, tir.IntImm): - indices_or_sections = call.attrs.indices_or_sections.value - modulo = tvm.arith.Analyzer().simplify( - call.args[0].struct_info.shape.values[call.attrs.axis] % indices_or_sections - ) - if modulo != 0: - logging.info( - "Split cannot be legalized by TOPI when the axis being split has " - "length that not divisible by the input number of section." - ) - return call - else: - indices_or_sections = call.attrs.indices_or_sections - return bb.call_te(topi.split, call.args[0], indices_or_sections, call.attrs.axis) - - -def _squeeze(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te(topi.squeeze, call.args[0], call.attrs.axis) - - -##################### Search ##################### - - -def _where(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te(topi.where, call.args[0], call.args[1], call.args[2]) - - -##################### Statistical ##################### - - -def _statistical(te_func: TEFunc) -> LegalizeFunc: - def statistical_call_te(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te(te_func, call.args[0], call.attrs.axis, call.attrs.keepdims) - - return statistical_call_te - - -def _compute_shape_prod(x: te.Tensor, axis: List[tir.IntImm]) -> tir.PrimExpr: - shape_prod = tir.const(1, "int32") - axes = [_axis.value for _axis in axis] if axis is not None else range(0, len(x.shape)) - for dim in axes: - shape_prod = shape_prod * x.shape[dim] - return shape_prod - - -def _te_mean(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Tensor: - shape_prod = _compute_shape_prod(x, axis) - res_sum = topi.sum(x, axis, keepdims) - return topi.divide(res_sum, shape_prod) - - -def _te_variance(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Tensor: - dev = x - _te_mean(x, axis, keepdims) - return _te_mean(dev * dev, axis, keepdims) - - -def _mean(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te( - _te_mean, call.args[0], call.attrs.axis, call.attrs.keepdims, primfunc_name_hint="mean" - ) - - -def _std(bb: BlockBuilder, call: Call) -> Expr: - def te_std(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Tensor: - return topi.sqrt(_te_variance(x, axis, keepdims)) - - return bb.call_te( - te_std, call.args[0], call.attrs.axis, call.attrs.keepdims, primfunc_name_hint="std" - ) - - -def _variance(bb: BlockBuilder, call: Call) 
-> Expr: - return bb.call_te( - _te_variance, - call.args[0], - call.attrs.axis, - call.attrs.keepdims, - primfunc_name_hint="variance", - ) - - -##################### Neural network ##################### - - -def _nn_conv2d(bb: BlockBuilder, call: Call) -> Expr: - if call.attrs.out_layout != call.attrs.data_layout: - logging.info( - "TOPI conv2d does not support different input-output " - "layouts, and thus cannot be legalized by TOPI" - ) - return call - if len(call.attrs.data_layout) != 4 or len(call.attrs.kernel_layout) != 4: - logging.info( - "Conv2D where data layout or kernel layout have channel chunk " - "cannot be legalized by TOPI at this moment." - ) - return call - if call.attrs.groups != 1: - data_layout = tir.layout(call.attrs.data_layout) - kernel_layout = tir.layout(call.attrs.kernel_layout) - ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")] - oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")] - if not isinstance(ic, tir.IntImm) or not isinstance(oc, tir.IntImm): - logging.info( - "Conv2D where number of groups is more than one and input or output " - "channel size is symbolic cannot be legalized by TOPI at this moment." - ) - return call - - return bb.call_te( - topi.nn.conv, - inp=call.args[0], - filt=call.args[1], - stride=call.attrs.strides, - padding=call.attrs.padding, - dilation=call.attrs.dilation, - groups=call.attrs.groups, - data_layout=call.attrs.data_layout, - kernel_layout=call.attrs.kernel_layout, - out_dtype=call.attrs.out_dtype if call.attrs.out_dtype != "" else None, - primfunc_name_hint="conv2d", - ) - - -def _nn_max_pool2d(bb: BlockBuilder, call: Call) -> Expr: - if call.attrs.out_layout != call.attrs.layout: - logging.info( - "TOPI max_pool2d does not support different input-output " - "layouts, and thus cannot be legalized by TOPI" - ) - return call - - return bb.call_te( - topi.nn.pool2d, - call.args[0], - kernel=call.attrs.pool_size, - stride=call.attrs.strides, - dilation=call.attrs.dilation, - padding=call.attrs.padding, - pool_type="max", - ceil_mode=call.attrs.ceil_mode, - layout=call.attrs.layout, - primfunc_name_hint="max_pool2d", - ) - - -def _nn_adaptive_max_pool2d(bb: BlockBuilder, call: Call) -> Expr: - if call.attrs.out_layout != call.attrs.layout: - logging.info( - "TOPI adaptive_max_pool2d does not support different input-output " - "layouts, and thus cannot be legalized by TOPI" - ) - return call - - def te_adaptive_avg_pool2d(data, output_size, layout_str): - if output_size is None: - layout = tir.layout(layout_str) - idx_H = layout.index_of("H") - idx_W = layout.index_of("W") - assert idx_H != -1 and idx_W != -1 - output_size = (data.shape[idx_H], data.shape[idx_W]) - - return topi.nn.adaptive_pool(data, output_size, "avg", layout_str) - - return bb.call_te( - te_adaptive_avg_pool2d, - call.args[0], - call.attrs.output_size, - call.attrs.layout, - primfunc_name_hint="adaptive_avg_pool2d", - ) - - -def _nn_relu(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te(topi.nn.relu, call.args[0]) - - -def _nn_gelu(bb: BlockBuilder, call: Call) -> Expr: - def gelu(x: te.Tensor): - dtype = x.dtype - return x * ( - tir.const(0.5, dtype) - + topi.erf(x * tir.const(0.5**0.5, dtype)) * tir.const(0.5, dtype) - ) - - return bb.call_te(gelu, call.args[0], primfunc_name_hint="gelu") - - -def _nn_silu(bb: BlockBuilder, call: Call) -> Expr: - def te_silu(x: te.Tensor): - return topi.multiply(x, topi.sigmoid(x)) - - return bb.call_te(te_silu, call.args[0], primfunc_name_hint="silu") - - -def _nn_softmax(bb: 
BlockBuilder, call: Call) -> Expr:
-    return bb.call_te(topi.nn.softmax, call.args[0], call.attrs.axis)
-
-
-def _nn_log_softmax(bb: BlockBuilder, call: Call):
-    return bb.call_te(topi.nn.log_softmax, call.args[0], call.attrs.axis)
-
-
-def _nn_cross_entropy_without_logits(bb: BlockBuilder, call: Call):
-    def te_cross_entropy_without_logits(x, y):
-        if len(x.shape) > 1:
-            return -topi.sum(topi.log(x) * y) / x.shape[0]
-        return -topi.sum(topi.log(x) * y)
-
-    return bb.call_te(
-        te_cross_entropy_without_logits,
-        call.args[0],
-        call.args[1],
-        primfunc_name_hint="cross_entropy_without_logits",
-    )
-
-
-def _nn_cross_entropy_with_logits(bb: BlockBuilder, call: Call):
-    def te_cross_entropy_with_logits(x, y):
-        if len(x.shape) > 1:
-            return -topi.sum(x * y) / x.shape[0]
-        return -topi.sum(x * y)
-
-    return bb.call_te(
-        te_cross_entropy_with_logits,
-        call.args[0],
-        call.args[1],
-        primfunc_name_hint="cross_entropy_with_logits",
-    )
-
-
-def _nn_batch_norm(bb: BlockBuilder, call: Call) -> Expr:
-    return bb.call_te(
-        topi.nn.batch_norm,
-        data=call.args[0],
-        gamma=call.args[1],
-        beta=call.args[2],
-        moving_mean=call.args[3],
-        moving_var=call.args[4],
-        axis=call.attrs.axis,
-        epsilon=call.attrs.epsilon,
-        center=call.attrs.center,
-        scale=call.attrs.scale,
-    )
-
-
-def _nn_layer_norm(bb: BlockBuilder, call: Call) -> Expr:
-    return bb.call_te(
-        topi.nn.layer_norm,
-        call.args[0],
-        call.args[1],
-        call.args[2],
-        axis=call.attrs.axes,
-        epsilon=call.attrs.epsilon,
-    )
-
-
-def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
-    logging.info("Dropout is handled by the frontend translator at this moment and is not legalized.")
-    return call
-
-
-def _nn_nll_loss(bb: BlockBuilder, call: Call) -> Expr:
-    if len(call.args) == 2:
-        # TODO(relax-team): handle the optional argument 'weight' of NLLLoss
-        logging.info(
-            "Cannot legalize it now, because we do not know how to set "
-            "the default value of the optional argument 'weight' of NLLLoss."
- ) - return call - return bb.call_te( - topi.nn.nll_loss, - call.args[0], - call.args[1], - call.args[2], - reduction=call.attrs.reduction, - ignore_index=call.attrs.ignore_index, - ) - - -##################### Image ##################### - - -def _image_resize2d(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te( - topi.image.resize2d, - call.args[0], - roi=call.attrs.roi, - size=call.args[1], - layout=call.attrs.layout, - method=call.attrs.method, - coordinate_transformation_mode=call.attrs.coordinate_transformation_mode, - rounding_method=call.attrs.rounding_method, - bicubic_alpha=call.attrs.cubic_alpha, - bicubic_exclude=call.attrs.cubic_exclude, - extrapolation_value=call.attrs.extrapolation_value, - ) - - -##################### Common ##################### - - -def _call_topi(te_func: TEFunc) -> LegalizeFunc: - return lambda bb, call: bb.call_te(te_func, *call.args) - - -########################################################## - - -# Todo(relax-team): Introduce cumsum for GPT-2 -# def _cumsum(bb: BlockBuilder, call: Call): -# return bb.call_te(topi.cumsum, args[0], attrs.axis) - - -DEFAULT_OP_LEGALIZE_MAP: Dict[str, LegalizeFunc] = { - # Arithmetic and comparison - "relax.abs": _unary(topi.abs), - "relax.cos": _unary(topi.cos), - "relax.log": _unary(topi.log), - "relax.exp": _unary(topi.exp), - "relax.negative": _unary(topi.negative), - "relax.sigmoid": _unary(topi.sigmoid), - "relax.sin": _unary(topi.sin), - "relax.sqrt": _unary(topi.sqrt), - "relax.tanh": _unary(topi.tanh), - "relax.clip": _call_topi(topi.clip), - "relax.add": _binary(topi.add), - "relax.divide": _binary(topi.divide), - "relax.floor_divide": _binary(topi.floor_divide), - "relax.multiply": _binary(topi.multiply), - "relax.subtract": _binary(topi.subtract), - "relax.equal": _binary(topi.equal), - "relax.greater": _binary(topi.greater), - "relax.greater_equal": _binary(topi.greater_equal), - "relax.less": _binary(topi.less), - "relax.less_equal": _binary(topi.less_equal), - "relax.not_equal": _binary(topi.not_equal), - # Creation - "relax.full": _full(is_like=False, fill_value=None, primfunc_name="full"), - "relax.full_like": _full(is_like=True, fill_value=None, primfunc_name="full"), - "relax.ones": _full(is_like=False, fill_value=1.0, primfunc_name="ones"), - "relax.ones_like": _full(is_like=True, fill_value=1.0, primfunc_name="ones"), - "relax.zeros": _full(is_like=False, fill_value=0.0, primfunc_name="zeros"), - "relax.zeros_like": _full(is_like=True, fill_value=0.0, primfunc_name="zeros"), - "relax.tril": _tril_triu(is_upper=False, primfunc_name="tril"), - "relax.triu": _tril_triu(is_upper=True, primfunc_name="triu"), - # Datatype - "relax.astype": _astype, - # Indexing - "relax.take": _take, - "relax.strided_slice": _strided_slice, - # Linear algebra - "relax.matmul": _matmul, - # Manipulation - "relax.broadcast_to": _reshape(topi.broadcast_to, "broadcast_to"), - "relax.concat": _concat, - "relax.expand_dims": _expand_dims, - "relax.flatten": _flatten, - "relax.permute_dims": _permute_dims, - "relax.reshape": _reshape(topi.reshape, "reshape"), - "relax.split": _split, - "relax.squeeze": _squeeze, - # TODO(relax-team): collapse_sum support symbolic shape - "relax.collapse_sum_like": _reshape( - topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True - ), - "relax.collapse_sum_to": _reshape(topi.collapse_sum, "collapse_sum"), - # Search - "relax.where": _where, - # Statistical - "relax.max": _statistical(topi.max), - "relax.mean": _mean, - "relax.min": _statistical(topi.min), - "relax.prod": 
_statistical(topi.prod), - "relax.std": _std, - "relax.sum": _statistical(topi.sum), - "relax.variance": _variance, - # Neural network - "relax.nn.conv2d": _nn_conv2d, - "relax.nn.max_pool2d": _nn_max_pool2d, - "relax.nn.adaptive_avg_pool2d": _nn_adaptive_max_pool2d, - "relax.nn.relu": _nn_relu, - "relax.nn.gelu": _nn_gelu, - "relax.nn.silu": _nn_silu, - "relax.nn.softmax": _nn_softmax, - "relax.nn.log_softmax": _nn_log_softmax, - "relax.nn.cross_entropy_without_logits": _nn_cross_entropy_without_logits, - "relax.nn.cross_entropy_with_logits": _nn_cross_entropy_with_logits, - "relax.nn.batch_norm": _nn_batch_norm, - "relax.nn.layer_norm": _nn_layer_norm, - "relax.nn.dropout": _nn_dropout, - "relax.nn.nll_loss": _nn_nll_loss, - # Image - "relax.image.resize2d": _image_resize2d, - # Todo(relax-team): Introduce cumsum for GPT-2 - # "relax.cumsum": _cumsum, -} - - -@tvm.transform.module_pass(opt_level=0, name="LegalizeOps") -class LegalizeOps: - """Legalize high-level operator calls in Relax functions to call_tir - with corresponding low-level TIR PrimFuncs. - - For each high-level operator, we register the way of legalizing it as a - function, which takes a context BlockBuilder and the Call being legalized - as input, and returns the legalized call. Here the input BlockBuilder is - mainly used for adding the PrimFunc created by call_te into the context - IRModule. - - The legalization function for each operator is registered in a map, - where the operator name is the key. The default legalization functions - are in the map `DEFAULT_OP_LEGALIZE_MAP`. - - This pass provides customizability for users to use their own legalization - function for operators. The pass takes an optional customized map, - which has the same key/value type as the default map (see `LegalizeFunc`), - from users. When an operator is contained in both the default map and the - customized map, the default legalization function will be overridden, and - only the customized one will be used. - - Parameters - ---------- - customize_legalize_map : Optional[Dict[str, LegalizeFunc]] - The customized operator legalization function map. - If not specified, it will be a fresh empty dict. - If an op has legalization function in both the default map and the - customized map, the customized function will override the default - one. - - Examples - -------- - The following code shows how to use this pass: - - .. code-block:: python - - # Define the pass input IRModule - @tvm.script.ir_module - class Module: - @R.function - def main( - x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") - ) -> R.Tensor((2, 3), "float32"): - z: R.Tensor((2, 3), "float32") = R.add(x, y) - r: R.Tensor((2, 3), "float32") = R.multiply(y, z) - return r - - # Define the customized legalization function for "relax.add" - def customize_legalize_add(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr: - from tvm import topi - return bb.call_te(topi.add, call.args[1], call.args[0]) - - # Apply the pass with the customized function to the module. - mod = LegalizeOps({"relax.add": customize_legalize_add})(Module) - - Print out the result by `mod.show()`, we can see the IRModule after - legalization becomes - - .. 
code-block:: python - - @tvm.script.ir_module - class Module: - @R.function - def main( - x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") - ) -> R.Tensor((2, 3), "float32"): - z = R.call_tir(add, (y, x), (2, 3), dtype="float32") - r = R.call_tir(multiply, (y, z), (2, 3), dtype="float32") - return r - - @T.prim_func - def add( - A: T.Buffer[(2, 3), "float32"], - B: T.Buffer[(2, 3), "float32"], - T_add: T.Buffer[(2, 3), "float32"], - ): - T.func_attr({"tir.noalias": True}) - for ax0, ax1 in T.grid(2, 3): - with T.block("T_add"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) - T.writes(T_add[v_ax0, v_ax1]) - T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1] - - @T.prim_func - def multiply( - A: T.Buffer[(2, 3), "float32"], - B: T.Buffer[(2, 3), "float32"], - T_multiply: T.Buffer[(2, 3), "float32"], - ): - T.func_attr({"tir.noalias": True}) - for ax0, ax1 in T.grid(2, 3): - with T.block("T_multiply"): - v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1]) - T.writes(T_multiply[v_ax0, v_ax1]) - T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[v_ax0, v_ax1] - """ - - def __init__(self, customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None): - if customize_legalize_map is None: - self.customize_legalize_map = dict() - else: - self.customize_legalize_map = customize_legalize_map - - def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule: - @mutator - class OperatorLegalizer(PyExprMutator): - def __init__(self, mod: IRModule, customize_legalize_map: Dict[str, LegalizeFunc]): - super().__init__(mod) - self.mod = mod - self.legalize_map = DEFAULT_OP_LEGALIZE_MAP.copy() - for name, func in customize_legalize_map.items(): - self.legalize_map[name] = func - - def _convert_op(self, call: Call) -> Expr: - if call.op.name in self.legalize_map: - # We only transform the op calls with known shape values - if not all( - [has_known_shape_value(arg.struct_info) for arg in call.args] - ) or not has_known_shape_value(call.struct_info): - return call - return self.legalize_map[call.op.name](self.builder_, call) - if call.op.name != "relax.call_tir": - logging.warning("No legalization func for %s is found.", call.op.name) - return call - - def transform(self) -> IRModule: - for global_var, func in self.mod.functions.items(): - if not isinstance(func, Function): - continue - updated_func = self.visit_expr(func) - updated_func = remove_all_unused(updated_func) - self.builder_.update_func(global_var, updated_func) - - return self.builder_.get() - - def visit_call_(self, call): # pylint: disable=arguments-differ - call = self.visit_expr_post_order(call) - if not isinstance(call.op, tir.op.Op): - return call - return self._convert_op(call) - - return OperatorLegalizer(mod, self.customize_legalize_map).transform() From bde40ca0fc0ffd94ea6cec649754cfb3f393951b Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Fri, 10 Feb 2023 12:33:51 -0800 Subject: [PATCH 4/4] . 
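
This revision moves the full user-facing documentation from TorchFXImporter.from_fx
onto the public from_fx wrapper. For quick reference, below is a minimal usage sketch
of that public entry point. It is not taken from the patch itself: the Activation
module and its input shape are made up for illustration, and it assumes the traced
operator (torch.nn.ReLU) is covered by the importer's convert map.

```python
import torch
from torch import fx

import tvm
from tvm.relax.frontend.torch import from_fx


# Hypothetical single-op module used only for illustration.
class Activation(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(x)


# Trace the module with FX, then import the traced GraphModule into Relax.
graph_module = fx.symbolic_trace(Activation())
input_info = [((1, 3, 10, 10), "float32")]
mod: tvm.IRModule = from_fx(graph_module, input_info)
mod.show()
```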
--- .../tvm/relax/frontend/torch/fx_translator.py | 148 +++++++++--------- 1 file changed, 73 insertions(+), 75 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 04d146ff39..910a5deffb 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -680,79 +680,7 @@ def create_convert_map(self): } def from_fx(self, model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule: - """Convert a PyTorch FX GraphModule to a Relax program - - Parameters - ---------- - model : fx.GraphModule - The PyTorch FX GraphModule to convert. - - input_info : List[Tuple[Tuple[int], str]] - A list of shapes and data types of input tensors. - - Returns - ------- - module : tvm.IRModule - The converted Relax program. - - Examples - -------- - Users can use the FX tracer or dynamo.export() to extract - a fx.GraphModule from a PyTorch model. The following codes show - how to convert a PyTorch model to a Relax program. - - .. code-block:: python - - # Import the importer. - import numpy as np - import torch - from tvm.relax.frontend.torch_fx import from_fx - from torch import _dynamo as dynamo - - # Define the module - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(in_features=10, out_features=7, bias=True) - - def forward(self, input): - return self.linear(input) - - # Instantiate the model and create the input info dict. - torch_model = MyModule() - input_info = [((128, 10), "float32")] - input_tensors = [ - torch.astensor(np.random.randn(*shape).astype(dtype)) - for shape, dtype in input_info - ] - - # Use FX tracer to trace the PyTorch model. - graph_module = fx.symbolic_trace(torch_model) - - # Use the dynamo.export() to export the PyTorch model to FX. - try: - graph_module = dynamo.export(torch_model, *input_tensors) - except: - raise RuntimeError("Failed to export the PyTorch model to FX.") - - # Use the importer to import the PyTorch model to Relax. - mod: tvm.IRModule = from_pytorch(graph_module, input_info) - - # Print out the imported model. - print(mod.script()) - - Notes - ----- - For a given PyTorch model, to lookup the names of the model inputs in - FX, one can use - - .. code-block:: python - - fx.symbolic_trace(model).graph.print_tabular() - - to print out the tabular representation of the PyTorch module, and then - check the placeholder rows in the beginning of the tabular. - """ + """Convert a PyTorch FX GraphModule to a Relax program.""" from torch import fx self.named_modules = dict(model.named_modules()) @@ -819,7 +747,77 @@ def forward(self, input): def from_fx(model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule: - """The public interface of PyTorch FX importer for Relax. - See `TorchFXImporter.from_fx` for full documentation. + """Convert a PyTorch FX GraphModule to a Relax program + + Parameters + ---------- + model : fx.GraphModule + The PyTorch FX GraphModule to convert. + + input_info : List[Tuple[Tuple[int], str]] + A list of shapes and data types of input tensors. + + Returns + ------- + module : tvm.IRModule + The converted Relax program. + + Examples + -------- + Users can use the FX tracer or dynamo.export() to extract + a fx.GraphModule from a PyTorch model. The following codes show + how to convert a PyTorch model to a Relax program. + + .. code-block:: python + + # Import the importer. 
+        import numpy as np
+        import torch
+        import tvm
+        from torch import fx
+        from torch import _dynamo as dynamo
+        from tvm.relax.frontend.torch import from_fx
+
+        # Define the module
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_features=10, out_features=7, bias=True)
+
+            def forward(self, input):
+                return self.linear(input)
+
+        # Instantiate the model and create the input info list.
+        torch_model = MyModule()
+        input_info = [((128, 10), "float32")]
+        input_tensors = [
+            torch.as_tensor(np.random.randn(*shape).astype(dtype))
+            for shape, dtype in input_info
+        ]
+
+        # Use the FX tracer to trace the PyTorch model.
+        graph_module = fx.symbolic_trace(torch_model)
+
+        # Alternatively, use dynamo.export() to export the PyTorch model to FX.
+        try:
+            graph_module = dynamo.export(torch_model, *input_tensors)
+        except Exception as exc:
+            raise RuntimeError("Failed to export the PyTorch model to FX.") from exc
+
+        # Use the importer to import the PyTorch model to Relax.
+        mod: tvm.IRModule = from_fx(graph_module, input_info)
+
+        # Print out the imported model.
+        print(mod.script())
+
+    Notes
+    -----
+    For a given PyTorch model, to look up the names of the model inputs in
+    FX, one can use
+
+    .. code-block:: python
+
+        fx.symbolic_trace(model).graph.print_tabular()
+
+    to print out the tabular representation of the PyTorch module, and then
+    check the placeholder rows at the beginning of the table.
     """
     return TorchFXImporter().from_fx(model, input_info)
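
As a companion to the Notes above, here is a small sketch, not part of the patch,
showing how the placeholder names surfaced by print_tabular() can also be collected
programmatically. The TwoInputs module and the (4, 8) shapes are hypothetical, and it
is assumed that input_info provides one (shape, dtype) pair per placeholder, in the
same order.

```python
import torch
from torch import fx


# Hypothetical two-input module used only for illustration.
class TwoInputs(torch.nn.Module):
    def forward(self, lhs, rhs):
        return lhs + rhs


graph_module = fx.symbolic_trace(TwoInputs())

# Placeholder nodes in the FX graph correspond to the model inputs, in order.
placeholder_names = [
    node.name for node in graph_module.graph.nodes if node.op == "placeholder"
]
print(placeholder_names)  # e.g. ['lhs', 'rhs']

# Build input_info to match the placeholder order: one (shape, dtype) per input.
input_info = [((4, 8), "float32"), ((4, 8), "float32")]
```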