# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
.. _ir_module:

IRModule
========
This tutorial presents the core abstraction of Apache TVM Unity, the IRModule.
The IRModule encompasses the **entirety** of an ML model, incorporating the
computational graph, tensor programs, and potential calls to external libraries.

.. contents:: Table of Contents
   :local:
   :depth: 1
"""

import numpy as np
import tvm
from tvm import relax

######################################################################
# Create IRModule
# ---------------
# IRModules can be initialized in various ways. We demonstrate a few of them
# below.

import torch
from torch import fx, nn
from tvm.relax.frontend.torch import from_fx

######################################################################
# Import from existing models
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
# The most common way to initialize an IRModule is to import from an existing
# model. Apache TVM Unity accommodates imports from a range of frameworks,
# such as PyTorch and ONNX. This tutorial demonstrates the import process from
# PyTorch; a brief sketch of the ONNX path follows the PyTorch example.


# Create a dummy model
class TorchModel(nn.Module):
    def __init__(self):
        super(TorchModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


# Give the input shape and data type
input_info = [((1, 784), "float32")]

# Convert the model to IRModule
with torch.no_grad():
    torch_fx_model = fx.symbolic_trace(TorchModel())
    mod_from_torch = from_fx(torch_fx_model, input_info, keep_params_as_input=True)

mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch)
# Print the IRModule
mod_from_torch.show()
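
######################################################################
# For reference, the ONNX path looks similar. The sketch below is only
# illustrative and is left commented out: it assumes the ``onnx`` package is
# installed, that a file named ``model.onnx`` exists (a hypothetical path),
# and that ``from_onnx`` accepts a ``shape_dict`` argument.

# import onnx
# from tvm.relax.frontend.onnx import from_onnx
#
# onnx_model = onnx.load("model.onnx")  # hypothetical model file
# mod_from_onnx = from_onnx(onnx_model, shape_dict={"input": (1, 784)})
# mod_from_onnx.show()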

######################################################################
# Write with Relax NN Module
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# Apache TVM Unity also provides a set of PyTorch-like APIs to help users
# write the IRModule directly.

from tvm.relax.frontend import nn


class RelaxModel(nn.Module):
    def __init__(self):
        super(RelaxModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


mod_from_relax, params_from_relax = RelaxModel().export_tvm(
    {"forward": {"x": nn.spec.Tensor((1, 784), "float32")}}
)
mod_from_relax.show()

######################################################################
# Create via TVMScript
# ~~~~~~~~~~~~~~~~~~~~
# TVMScript is a Python-based DSL for IRModules. We can directly output an
# IRModule in TVMScript syntax, or alternatively, parse TVMScript to obtain
# an IRModule.

from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class TVMScriptModule:
    @R.function
    def main(
        x: R.Tensor((1, 784), dtype="float32"),
        fc1_weight: R.Tensor((256, 784), dtype="float32"),
        fc1_bias: R.Tensor((256,), dtype="float32"),
        fc2_weight: R.Tensor((10, 256), dtype="float32"),
        fc2_bias: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims = R.permute_dims(fc1_weight, axes=None)
            matmul = R.matmul(x, permute_dims, out_dtype="void")
            add = R.add(matmul, fc1_bias)
            relu = R.nn.relu(add)
            permute_dims1 = R.permute_dims(fc2_weight, axes=None)
            matmul1 = R.matmul(relu, permute_dims1, out_dtype="void")
            add1 = R.add(matmul1, fc2_bias)
            gv = add1
            R.output(gv)
        return gv


mod_from_script = TVMScriptModule
mod_from_script.show()
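
######################################################################
# Going the other direction, any IRModule can be printed back out as TVMScript
# text. A minimal sketch: ``IRModule.script()`` returns the TVMScript source as
# a plain Python string (``.show()`` above is essentially a syntax-highlighted
# print of the same text).

script_text = mod_from_script.script()
print(script_text.splitlines()[0])  # print only the first line of the script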

######################################################################
# Attributes of an IRModule
# -------------------------
# An IRModule is a collection of functions, indexed by GlobalVars.

mod = mod_from_torch
print(mod.get_global_vars())

######################################################################
# We can access the functions in the IRModule by indexing with the GlobalVars
# or their names.

# index by global var name
print(mod["main"])
# index by global var, and check that it is the same function
(gv,) = mod.get_global_vars()
assert mod[gv] == mod["main"]
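
######################################################################
# Indexing returns the function object itself, so its parts can be inspected
# directly. A small illustrative sketch, assuming the standard ``params`` and
# ``name_hint`` attributes on Relax functions and variables:

main_func = mod["main"]
print(type(main_func))  # a Relax Function
print([param.name_hint for param in main_func.params])  # parameter names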

######################################################################
# Transformations on IRModules
# ----------------------------
# Transformations are an important component of Apache TVM Unity. A transformation
# takes in an IRModule and outputs another IRModule. Applying a sequence of
# transformations to an IRModule to obtain a new IRModule is the common way to
# optimize a model.
#
# In this getting started tutorial, we only demonstrate how to apply transformations
# to an IRModule. For details of each transformation, please refer to the
# :ref:`Transformation API Reference <api-relax-transformation>`.

######################################################################
# We first apply the **LegalizeOps** transformation to the IRModule. This transformation
# converts the Relax module into a mixed stage, with both Relax and TensorIR functions
# within the same module. Meanwhile, the Relax operators are converted into ``call_tir``.

mod = mod_from_torch
mod = relax.transform.LegalizeOps()(mod)
mod.show()

######################################################################
# After the transformation, there are many more functions inside the module. Let's print
# the global vars again.

print(mod.get_global_vars())
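
######################################################################
# To make the mix of function kinds concrete, we can count how many functions
# of each kind the module now holds. This is a small illustrative sketch using
# the ``tir.PrimFunc`` and ``relax.Function`` classes.

from tvm import tir

num_tir = sum(isinstance(func, tir.PrimFunc) for _, func in mod.functions.items())
num_relax = sum(isinstance(func, relax.Function) for _, func in mod.functions.items())
print(f"TensorIR functions: {num_tir}, Relax functions: {num_relax}")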

######################################################################
# Next, Apache TVM Unity provides a set of default transformation pipelines to
# simplify the transformation process; we can apply the default pipeline to the module.
# The default **zero** pipeline contains very fundamental transformations, including:
#
# - **LegalizeOps**: This transform converts the Relax operators into ``call_tir`` functions
#   with the corresponding TensorIR functions. After this transform, the IRModule will
#   contain both Relax functions and TensorIR functions.
# - **AnnotateTIROpPattern**: This transform annotates the pattern of the TensorIR functions,
#   preparing them for subsequent operator fusion.
# - **FoldConstant**: This pass performs constant folding, optimizing operations
#   involving constants.
# - **FuseOps and FuseTIR**: These two passes work together to fuse operators based on the
#   patterns annotated in the previous step (AnnotateTIROpPattern). These passes transform
#   both Relax functions and TensorIR functions.
#
# .. note::
#
#    Here, we have applied **LegalizeOps** twice in the flow. The second application is
#    redundant but harmless.
#
#    Every pass can be duplicated in the flow, since we ensure each pass can handle all
#    legal IRModule inputs. This design helps users construct their own pipelines, as
#    sketched after the following code cell.

mod = relax.get_pipeline("zero")(mod)
mod.show()
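
######################################################################
# For illustration, a hand-built pipeline with roughly the passes listed above
# might look like the sketch below. This is only a sketch, assuming these pass
# constructors exist under ``relax.transform``; the prebuilt ``zero`` pipeline
# remains the recommended entry point.

custom_pipeline = tvm.transform.Sequential(
    [
        relax.transform.LegalizeOps(),
        relax.transform.AnnotateTIROpPattern(),
        relax.transform.FoldConstant(),
        relax.transform.FuseOps(),
        relax.transform.FuseTIR(),
    ]
)
# Applying it would look like: mod = custom_pipeline(mod)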

######################################################################
# Deploy the IRModule Universally
# -------------------------------
# After the optimization, we can compile the model into a TVM runtime module.
# Notably, Apache TVM Unity provides universal deployment, which means we can
# deploy the same IRModule on different backends, including CPU, GPU, and other
# emerging backends.
#
# Deploy on CPU
# ~~~~~~~~~~~~~
# We can deploy the IRModule on CPU by specifying the target as ``llvm``.

exec = relax.build(mod, target="llvm")
dev = tvm.cpu()
vm = relax.VirtualMachine(exec, dev)

raw_data = np.random.rand(1, 784).astype("float32")
data = tvm.nd.array(raw_data, dev)
cpu_out = vm["main"](data, *params_from_torch["main"]).numpy()
print(cpu_out)

######################################################################
# Deploy on GPU
# ~~~~~~~~~~~~~
# Besides the CPU backend, we can also deploy the IRModule on GPU. GPU programs
# require extra information, such as the thread bindings and shared memory
# allocations, so a further transformation is needed to generate them.
#
# We use ``DLight`` to generate the GPU programs. In this tutorial, we won't go into
# the details of ``DLight``.

from tvm import dlight as dl

with tvm.target.Target("cuda"):
    gpu_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.Fallback(),
    )(mod)

######################################################################
# Now we can compile the IRModule for GPU, in a similar way as we did on the CPU.

exec = relax.build(gpu_mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(exec, dev)
# Need to allocate data and params on GPU device
data = tvm.nd.array(raw_data, dev)
gpu_params = [tvm.nd.array(p, dev) for p in params_from_torch["main"]]
gpu_out = vm["main"](data, *gpu_params).numpy()
print(gpu_out)

# Check the correctness of the results
assert np.allclose(cpu_out, gpu_out, atol=1e-3)

######################################################################
# Deploy on Other Backends
# ~~~~~~~~~~~~~~~~~~~~~~~~
# Apache TVM Unity also supports other backends, such as different kinds of GPUs
# (Metal, ROCm, Vulkan and OpenCL), different kinds of CPUs (x86, ARM), and other
# emerging backends (e.g., WebAssembly). The deployment process is similar to the
# GPU backend, as sketched below.
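
######################################################################
# For example, a Metal build would follow the same steps as the CUDA example.
# The sketch below is left commented out and assumes a Metal-capable macOS
# device; the ``"metal"`` target string and ``tvm.metal()`` device follow TVM's
# usual naming.

# with tvm.target.Target("metal"):
#     metal_mod = dl.ApplyDefaultSchedule(
#         dl.gpu.Matmul(),
#         dl.gpu.Fallback(),
#     )(mod)
# metal_exec = relax.build(metal_mod, target="metal")
# metal_dev = tvm.metal()
# metal_vm = relax.VirtualMachine(metal_exec, metal_dev)
# metal_params = [tvm.nd.array(p, metal_dev) for p in params_from_torch["main"]]
# metal_out = metal_vm["main"](tvm.nd.array(raw_data, metal_dev), *metal_params).numpy()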