
Commit e1da465
Author: Siyuan Feng
[Doc] IRModule (#17298)
1 parent 9e865b4 commit e1da465

File tree

2 files changed: +282 -0 lines changed

Lines changed: 281 additions & 0 deletions
@@ -0,0 +1,281 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
.. _ir_module:

IRModule
========
This tutorial presents the core abstraction of Apache TVM Unity, the IRModule.
An IRModule encompasses the **entirety** of an ML model, incorporating the
computational graph, tensor programs, and potential calls to external libraries.

.. contents:: Table of Contents
    :local:
    :depth: 1
"""

import numpy as np
import tvm
from tvm import relax

######################################################################
# Create IRModule
# ---------------
# IRModules can be initialized in various ways. We demonstrate a few of them
# below.

import torch
from torch import fx, nn
from tvm.relax.frontend.torch import from_fx

######################################################################
# Import from existing models
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
# The most common way to initialize an IRModule is to import from an existing
# model. Apache TVM Unity accommodates imports from a range of frameworks,
# such as PyTorch and ONNX. This tutorial solely demonstrates the import process
# from PyTorch.


# Create a dummy model
class TorchModel(nn.Module):
    def __init__(self):
        super(TorchModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


# Give the input shape and data type
input_info = [((1, 784), "float32")]

# Convert the model to IRModule
with torch.no_grad():
    torch_fx_model = fx.symbolic_trace(TorchModel())
    mod_from_torch = from_fx(torch_fx_model, input_info, keep_params_as_input=True)

mod_from_torch, params_from_torch = relax.frontend.detach_params(mod_from_torch)
# Print the IRModule
mod_from_torch.show()
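
######################################################################
# The ``detach_params`` call above separates the model weights from the
# IRModule. ``params_from_torch`` is a dictionary keyed by function name
# ("main" here), whose value is the list of weight tensors. As a small sanity
# check (not part of the original flow), we can inspect the shapes of the
# detached parameters.

# Each detached parameter is a tvm.nd.NDArray; print their shapes
print([p.shape for p in params_from_torch["main"]])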

######################################################################
# Write with Relax NN Module
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# Apache TVM Unity also provides a set of PyTorch-like APIs to help users
# write an IRModule directly.

from tvm.relax.frontend import nn


class RelaxModel(nn.Module):
    def __init__(self):
        super(RelaxModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


mod_from_relax, params_from_relax = RelaxModel().export_tvm(
    {"forward": {"x": nn.spec.Tensor((1, 784), "float32")}}
)
mod_from_relax.show()
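
######################################################################
# The export spec above fixes the batch size to 1. As a sketch, assuming the
# ``nn.spec`` interface also accepts symbolic dimensions given as strings, the
# same model can be exported with a dynamic batch dimension instead.

# "batch" is a symbolic-dimension name chosen for illustration
dyn_mod, _ = RelaxModel().export_tvm(
    {"forward": {"x": nn.spec.Tensor(("batch", 784), "float32")}}
)
dyn_mod.show()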

######################################################################
# Create via TVMScript
# ~~~~~~~~~~~~~~~~~~~~
# TVMScript is a Python-based DSL for IRModules. We can directly output an
# IRModule in TVMScript syntax, or alternatively, parse TVMScript to obtain
# an IRModule.

from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class TVMScriptModule:
    @R.function
    def main(
        x: R.Tensor((1, 784), dtype="float32"),
        fc1_weight: R.Tensor((256, 784), dtype="float32"),
        fc1_bias: R.Tensor((256,), dtype="float32"),
        fc2_weight: R.Tensor((10, 256), dtype="float32"),
        fc2_bias: R.Tensor((10,), dtype="float32"),
    ) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims = R.permute_dims(fc1_weight, axes=None)
            matmul = R.matmul(x, permute_dims, out_dtype="void")
            add = R.add(matmul, fc1_bias)
            relu = R.nn.relu(add)
            permute_dims1 = R.permute_dims(fc2_weight, axes=None)
            matmul1 = R.matmul(relu, permute_dims1, out_dtype="void")
            add1 = R.add(matmul1, fc2_bias)
            gv = add1
            R.output(gv)
        return gv


mod_from_script = TVMScriptModule
mod_from_script.show()
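
######################################################################
# The round trip also works in the other direction: ``IRModule.script()``
# returns the TVMScript text of a module, which can be saved and parsed back
# later. A minimal sketch, printing the module we imported from PyTorch:

print(mod_from_torch.script())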

######################################################################
# Attributes of an IRModule
# -------------------------
# An IRModule is a collection of functions, indexed by GlobalVars.

mod = mod_from_torch
print(mod.get_global_vars())

######################################################################
# We can access the functions in the IRModule by indexing with the GlobalVars
# or their names.

# index by global var name
print(mod["main"])
# index by global var, and check that it is the same function
(gv,) = mod.get_global_vars()
assert mod[gv] == mod["main"]
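
######################################################################
# Since an IRModule behaves like a mapping from GlobalVars to functions, we
# can also look up a GlobalVar by name and iterate over all functions. A short
# sketch, assuming the ``get_global_var`` and ``functions`` accessors:

# Look up the GlobalVar for "main" and list every function in the module
main_gv = mod.get_global_var("main")
for global_var, func in mod.functions.items():
    print(global_var.name_hint, type(func))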

######################################################################
# Transformations on IRModules
# ----------------------------
# Transformations are a key component of Apache TVM Unity. A transformation
# takes in an IRModule and outputs another IRModule. We can apply a sequence of
# transformations to an IRModule to obtain a new IRModule. That is the common way to
# optimize a model.
#
# In this getting started tutorial, we only demonstrate how to apply transformations
# to an IRModule. For details of each transformation, please refer to the
# :ref:`Transformation API Reference <api-relax-transformation>`.

######################################################################
# We first apply the **LegalizeOps** transformation to the IRModule. This transformation
# converts the Relax module into a mixed stage, with both Relax and TensorIR functions
# within the same module. Meanwhile, the Relax operators are converted into ``call_tir``.

mod = mod_from_torch
mod = relax.transform.LegalizeOps()(mod)
mod.show()

######################################################################
# After the transformation, there are many more functions inside the module. Let's print
# the global vars again.

print(mod.get_global_vars())

######################################################################
# Next, Apache TVM Unity provides a set of default transformation pipelines to
# simplify the transformation process. We can apply the default pipeline to the module.
# The default **zero** pipeline contains very fundamental transformations, including:
#
# - **LegalizeOps**: This transform converts the Relax operators into ``call_tir`` functions
#   with the corresponding TensorIR functions. After this transform, the IRModule will
#   contain both Relax functions and TensorIR functions.
# - **AnnotateTIROpPattern**: This transform annotates the pattern of the TensorIR functions,
#   preparing them for subsequent operator fusion.
# - **FoldConstant**: This pass performs constant folding, optimizing operations
#   involving constants.
# - **FuseOps and FuseTIR**: These two passes work together to fuse operators based on the
#   patterns annotated in the previous step (AnnotateTIROpPattern). These passes transform
#   both Relax functions and TensorIR functions.
#
# .. note::
#
#    Here, we have applied **LegalizeOps** twice in the flow. The second time is useless but
#    harmless.
#
#    Every pass can be duplicated in the flow, since we ensure each pass can handle all legal
#    IRModule inputs. This design helps users construct their own pipelines.

mod = relax.get_pipeline("zero")(mod)
mod.show()
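
######################################################################
# As the note above suggests, users can also assemble their own pipelines.
# The sketch below (not part of the original flow) chains the same passes the
# zero pipeline describes, using ``tvm.ir.transform.Sequential``.

custom_pipeline = tvm.ir.transform.Sequential(
    [
        relax.transform.LegalizeOps(),
        relax.transform.AnnotateTIROpPattern(),
        relax.transform.FoldConstant(),
        relax.transform.FuseOps(),
        relax.transform.FuseTIR(),
    ]
)
# Run the custom pipeline on the untransformed module imported from PyTorch
with tvm.transform.PassContext(opt_level=3):
    custom_mod = custom_pipeline(mod_from_torch)
custom_mod.show()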

######################################################################
# Deploy the IRModule Universally
# -------------------------------
# After the optimization, we can compile the model into a TVM runtime module.
# Notably, Apache TVM Unity provides the ability of universal deployment, which means
# we can deploy the same IRModule on different backends, including CPU, GPU, and other
# emerging backends.
#
# Deploy on CPU
# ~~~~~~~~~~~~~
# We can deploy the IRModule on CPU by specifying the target as ``llvm``.

exec = relax.build(mod, target="llvm")
dev = tvm.cpu()
vm = relax.VirtualMachine(exec, dev)

raw_data = np.random.rand(1, 784).astype("float32")
data = tvm.nd.array(raw_data, dev)
cpu_out = vm["main"](data, *params_from_torch["main"]).numpy()
print(cpu_out)
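
######################################################################
# Beyond a single invocation, we often want to measure how fast the compiled
# module runs. A small sketch, assuming the Relax virtual machine exposes the
# ``time_evaluator`` helper:

# Benchmark "main" on the CPU device; number/repeat control the sampling
timer = vm.time_evaluator("main", dev, number=10, repeat=3)
print(timer(data, *params_from_torch["main"]))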

######################################################################
# Deploy on GPU
# ~~~~~~~~~~~~~
# Besides the CPU backend, we can also deploy the IRModule on GPU. GPU programs
# require extra information, such as the thread bindings and shared memory
# allocations. We need a further transformation to generate the GPU programs.
#
# We use ``DLight`` to generate the GPU programs. In this tutorial, we won't go into
# the details of ``DLight``.
#

from tvm import dlight as dl

with tvm.target.Target("cuda"):
    gpu_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.Fallback(),
    )(mod)

######################################################################
# Now we can compile the IRModule for GPU, in a similar way as we did for CPU.

exec = relax.build(gpu_mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(exec, dev)
# Need to allocate data and params on the GPU device
data = tvm.nd.array(raw_data, dev)
gpu_params = [tvm.nd.array(p, dev) for p in params_from_torch["main"]]
gpu_out = vm["main"](data, *gpu_params).numpy()
print(gpu_out)

# Check the correctness of the results
assert np.allclose(cpu_out, gpu_out, atol=1e-3)


######################################################################
# Deploy on Other Backends
# ~~~~~~~~~~~~~~~~~~~~~~~~
# Apache TVM Unity also supports other backends, such as different kinds of GPUs
# (Metal, ROCm, Vulkan and OpenCL), different kinds of CPUs (x86, ARM), and other
# emerging backends (e.g., WebAssembly). The deployment process is similar to the
# GPU backend.
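
######################################################################
# As a sketch of that process (assuming a Metal-capable machine; not run in
# this tutorial), the same IRModule can be scheduled and built for another
# GPU backend simply by switching the target string and device.

with tvm.target.Target("metal"):
    metal_mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.Fallback(),
    )(mod)
metal_exec = relax.build(metal_mod, target="metal")
metal_dev = tvm.metal()
metal_vm = relax.VirtualMachine(metal_exec, metal_dev)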

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ driving its costs down.
 
    install/index
    get_started/tutorials/quick_start
+   get_started/tutorials/ir_module
    contribute/index
 
 .. toctree::
