Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,42 @@ def call_binary_op(op, lhs, rhs):

return convert

########## Linear Algebra ##########

def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:

args = self.retrieve_args(node)

data = args[0]
# Default ord=2 if not supplied
ord_val = args[1] if len(args) > 1 else 2.0
dim = args[2] if len(args) > 2 else None
keepdim = args[3] if len(args) > 3 else False

# If ord_val is a Python float/int, wrap it in a Relax const
# so that it matches data's dtype.
dtype = data.struct_info.dtype
ord_expr = (
ord_val if isinstance(ord_val, relax.Expr) else relax.const(float(ord_val), dtype)
)
# Reciprocal
reci_expr = (
relax.op.divide(relax.const(1.0, dtype), ord_expr)
if isinstance(ord_val, relax.Expr)
else relax.const(1.0 / float(ord_val), dtype)
)

# abs(data)
abs_data = self.block_builder.emit(relax.op.abs(data))
# abs_data^ord
abs_data_pow = self.block_builder.emit(relax.op.power(abs_data, ord_expr))
# sum over dim
reduced = self.block_builder.emit(relax.op.sum(abs_data_pow, dim, keepdims=keepdim))
# (sum(...))^(1/ord)
norm_val = self.block_builder.emit(relax.op.power(reduced, reci_expr))

return norm_val

########## Neural Network ##########

def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ def create_convert_map(
"__or__.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_),
"__xor__.Tensor": self._binary_op(relax.op.bitwise_xor, operator.xor),
"__xor__.Scalar": self._binary_op(relax.op.bitwise_xor, operator.xor),
# linear algebra
"linalg_vector_norm.default": self._linalg_vector_norm,
# neural network
"_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training,
"adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
Expand Down
92 changes: 92 additions & 0 deletions tests/python/relax/test_from_exported_to_cuda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np
import torch
from torch.export import export

import tvm
import tvm.testing
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program


def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev):
"""
This util ensures that a torch module can successfully be exported to TVM
using torch.export and that the resuling IR program gives the same result
as PyTorch when ran on CUDA.
"""
raw_data_for_tvm = raw_data.copy() # In case the data is modified
torch_data = torch.from_numpy(raw_data)
example_args = (torch_data,)

with torch.no_grad():
exported_program = export(torch_module, example_args)
mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True)

tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch)

relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda()))
# TODO try pipeline below?
# releax_pipeline = relax.backend.cuda.pipeline.get_default_pipeline(target)
ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline)
vm = relax.VirtualMachine(ex, dev)

gpu_data = tvm.nd.array(raw_data_for_tvm, dev)
gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]]
gpu_out = vm["main"](gpu_data, *gpu_params)

pytorch_out = torch_module(torch_data).detach().numpy()
actual = gpu_out[0].numpy()
desired = pytorch_out
np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)


@tvm.testing.parametrize_targets("cuda")
def test_linalg_vector_norm(target, dev):
class VectorNorm0(torch.nn.Module):
def forward(self, x):
return torch.linalg.vector_norm(x, ord=1, dim=-1)

class VectorNorm1(torch.nn.Module):
def forward(self, x):
return torch.linalg.vector_norm(x, ord=2, dim=2)

class VectorNorm2(torch.nn.Module):
def forward(self, x):
return torch.linalg.vector_norm(x, ord=1, dim=-1)

class VectorNorm3(torch.nn.Module):
def forward(self, x):
return torch.linalg.vector_norm(x, ord=2, dim=2)

raw_data = np.random.randn(2, 3, 4, 10).astype(np.float32)

torch_module0 = VectorNorm0().eval()
torch_module1 = VectorNorm1().eval()
torch_module2 = VectorNorm2().eval()
torch_module3 = VectorNorm3().eval()

assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1, target, dev)
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev)
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)


if __name__ == "__main__":
tvm.testing.main()