Commit d0de906

[Relax] Allow ingesting vector_norm from torch.export (#17722)
- Implement torch's vector_norm as a composition of existing Relax ops
- Add a unit test
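For context, torch.linalg.vector_norm(x, ord=p, dim=d) computes the vector p-norm (sum(|x|^p))^(1/p) over the given dimensions, which is exactly the op chain this commit emits. A minimal PyTorch sketch of that identity (illustration for this writeup, not code from the commit):

    import torch

    x = torch.tensor([[3.0, 4.0], [6.0, 8.0]])
    p = 2.0

    # Built-in op: 2-norm over the last dimension
    print(torch.linalg.vector_norm(x, ord=p, dim=-1))  # tensor([ 5., 10.])

    # The decomposition the commit lowers to: abs -> power -> sum -> power
    print((x.abs() ** p).sum(dim=-1) ** (1.0 / p))  # tensor([ 5., 10.])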
1 parent 2601733; commit d0de906

3 files changed: +130 −0

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 36 additions & 0 deletions

@@ -306,6 +306,42 @@ def call_binary_op(op, lhs, rhs):
 
         return convert
 
+    ########## Linear Algebra ##########
+
+    def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:
+
+        args = self.retrieve_args(node)
+
+        data = args[0]
+        # Default ord=2 if not supplied
+        ord_val = args[1] if len(args) > 1 else 2.0
+        dim = args[2] if len(args) > 2 else None
+        keepdim = args[3] if len(args) > 3 else False
+
+        # If ord_val is a Python float/int, wrap it in a Relax const
+        # so that it matches data's dtype.
+        dtype = data.struct_info.dtype
+        ord_expr = (
+            ord_val if isinstance(ord_val, relax.Expr) else relax.const(float(ord_val), dtype)
+        )
+        # Reciprocal
+        reci_expr = (
+            relax.op.divide(relax.const(1.0, dtype), ord_expr)
+            if isinstance(ord_val, relax.Expr)
+            else relax.const(1.0 / float(ord_val), dtype)
+        )
+
+        # abs(data)
+        abs_data = self.block_builder.emit(relax.op.abs(data))
+        # abs_data^ord
+        abs_data_pow = self.block_builder.emit(relax.op.power(abs_data, ord_expr))
+        # sum over dim
+        reduced = self.block_builder.emit(relax.op.sum(abs_data_pow, dim, keepdims=keepdim))
+        # (sum(...))^(1/ord)
+        norm_val = self.block_builder.emit(relax.op.power(reduced, reci_expr))
+
+        return norm_val
+
     ########## Neural Network ##########
 
     def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
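The lowering above never calls a dedicated norm kernel; it builds the result from relax.op.abs, relax.op.power, and relax.op.sum. A NumPy mirror of the same chain, checked against numpy.linalg.norm, makes the decomposition easy to verify (a verification sketch for this writeup, not code from the commit):

    import numpy as np

    def vector_norm_decomposed(data, ord_val=2.0, dim=None, keepdim=False):
        # Same chain as _linalg_vector_norm: abs -> power(ord) -> sum(dim) -> power(1/ord)
        reduced = np.sum(np.abs(data) ** ord_val, axis=dim, keepdims=keepdim)
        return reduced ** (1.0 / ord_val)

    x = np.random.randn(2, 3, 4).astype("float32")
    for p in (1.0, 2.0, 3.0):
        np.testing.assert_allclose(
            vector_norm_decomposed(x, ord_val=p, dim=-1),
            np.linalg.norm(x, ord=p, axis=-1),
            rtol=1e-5,
        )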

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 2 additions & 0 deletions

@@ -231,6 +231,8 @@ def create_convert_map(
             "__or__.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_),
             "__xor__.Tensor": self._binary_op(relax.op.bitwise_xor, operator.xor),
             "__xor__.Scalar": self._binary_op(relax.op.bitwise_xor, operator.xor),
+            # linear algebra
+            "linalg_vector_norm.default": self._linalg_vector_norm,
             # neural network
             "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training,
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
(new file: unit test)

Lines changed: 92 additions & 0 deletions

@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import torch
+from torch.export import export
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.relax.frontend.torch import from_exported_program
+
+
+def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev):
+    """
+    This util ensures that a torch module can successfully be exported to TVM
+    using torch.export and that the resulting IR program gives the same result
+    as PyTorch when run on CUDA.
+    """
+    raw_data_for_tvm = raw_data.copy()  # In case the data is modified
+    torch_data = torch.from_numpy(raw_data)
+    example_args = (torch_data,)
+
+    with torch.no_grad():
+        exported_program = export(torch_module, example_args)
+        mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True)
+
+    tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch)
+
+    relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda()))
+    # TODO try pipeline below?
+    # relax_pipeline = relax.backend.cuda.pipeline.get_default_pipeline(target)
+    ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline)
+    vm = relax.VirtualMachine(ex, dev)
+
+    gpu_data = tvm.nd.array(raw_data_for_tvm, dev)
+    gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]]
+    gpu_out = vm["main"](gpu_data, *gpu_params)
+
+    pytorch_out = torch_module(torch_data).detach().numpy()
+    actual = gpu_out[0].numpy()
+    desired = pytorch_out
+    np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_linalg_vector_norm(target, dev):
+    class VectorNorm0(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x, ord=1, dim=-1)
+
+    class VectorNorm1(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x, ord=2, dim=2)
+
+    class VectorNorm2(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x, ord=1, dim=-1)
+
+    class VectorNorm3(torch.nn.Module):
+        def forward(self, x):
+            return torch.linalg.vector_norm(x, ord=2, dim=2)
+
+    raw_data = np.random.randn(2, 3, 4, 10).astype(np.float32)
+
+    torch_module0 = VectorNorm0().eval()
+    torch_module1 = VectorNorm1().eval()
+    torch_module2 = VectorNorm2().eval()
+    torch_module3 = VectorNorm3().eval()
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
