# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np
import torch
from torch.export import export

import tvm
import tvm.testing
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program


def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev):
    """
    Ensure that a torch module can be exported to TVM via torch.export and
    that the resulting Relax program gives the same result as PyTorch when
    run on CUDA.
    """
    raw_data_for_tvm = raw_data.copy()  # In case the data is modified
    torch_data = torch.from_numpy(raw_data)
    example_args = (torch_data,)

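    # Trace the module with torch.export and translate the resulting
    # ExportedProgram into a Relax IRModule. keep_params_as_input=True keeps
    # the model weights as explicit inputs of the "main" function.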
    with torch.no_grad():
        exported_program = export(torch_module, example_args)
        mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True)

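    # Split the weights out of the module so they can be passed to the VM
    # explicitly, alongside the input data.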
    tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch)

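    # Compile with the default Relax pipeline for the CUDA target and load
    # the result into a Relax virtual machine on the requested device.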
    relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(dev))
    # TODO: try the CUDA-specific pipeline below?
    # relax_pipeline = relax.backend.cuda.pipeline.get_default_pipeline(target)
    ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline)
    vm = relax.VirtualMachine(ex, dev)

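    # Copy the input and the detached weights onto the device and run the
    # compiled "main" function.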
    gpu_data = tvm.nd.array(raw_data_for_tvm, dev)
    gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]]
    gpu_out = vm["main"](gpu_data, *gpu_params)

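    # Compare against eager PyTorch within a small floating-point tolerance.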
    pytorch_out = torch_module(torch_data).detach().numpy()
    actual = gpu_out[0].numpy()
    desired = pytorch_out
    np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)


@tvm.testing.parametrize_targets("cuda")
def test_linalg_vector_norm(target, dev):
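    # Each module applies torch.linalg.vector_norm with a fixed ord/dim
    # combination: the L1 norm over the last axis and the L2 norm over axis 2.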
    class VectorNorm0(torch.nn.Module):
        def forward(self, x):
            return torch.linalg.vector_norm(x, ord=1, dim=-1)

    class VectorNorm1(torch.nn.Module):
        def forward(self, x):
            return torch.linalg.vector_norm(x, ord=2, dim=2)

    class VectorNorm2(torch.nn.Module):
        def forward(self, x):
            return torch.linalg.vector_norm(x, ord=1, dim=-1)

    class VectorNorm3(torch.nn.Module):
        def forward(self, x):
            return torch.linalg.vector_norm(x, ord=2, dim=2)

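    # 4-D input so that dim=2 and dim=-1 select different axes.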
    raw_data = np.random.randn(2, 3, 4, 10).astype(np.float32)

    torch_module0 = VectorNorm0().eval()
    torch_module1 = VectorNorm1().eval()
    torch_module2 = VectorNorm2().eval()
    torch_module3 = VectorNorm3().eval()

    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1, target, dev)
    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev)
    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)


if __name__ == "__main__":
    tvm.testing.main()