Commit cd34486

Authored by Siyuan Feng

[Doc] Customize Optimization (#17320)

1 parent 3262f19 commit cd34486

File tree

2 files changed: +226 −0 lines changed

docs/how_to/index.rst
Lines changed: 1 addition & 0 deletions

@@ -19,5 +19,6 @@
    :maxdepth: 1
 
    tutorials/e2e_opt_model
+   tutorials/customize_opt
    tutorials/cross_compilation_and_rpc
    dev/index
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
.. _customize_opt:

Customize Optimization
======================
One main design goal of Apache TVM is to make it easy to customize the optimization pipeline
for both research and development purposes, and to iterate on engineering optimizations.
This tutorial covers the topics below.

.. contents:: Table of Contents
    :local:
    :depth: 1
"""

######################################################################
# Review Overall Flow
# -------------------
# .. figure:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_overall_flow.svg
#    :align: center
#    :width: 80%
#
# The overall flow consists of the following steps:
#
# - **Construct or Import a Model**: Construct a neural network model or import a pre-trained
#   model from other frameworks (e.g. PyTorch, ONNX), and create the TVM IRModule, which contains
#   all the information needed for compilation, including high-level Relax functions for the
#   computational graph and low-level TensorIR functions for tensor programs.
# - **Perform Composable Optimizations**: Perform a series of optimization transformations,
#   such as graph optimizations, tensor program optimizations, and library dispatching.
# - **Build and Universal Deployment**: Build the optimized model into a deployable module for the
#   universal runtime, and execute it on different devices, such as CPU, GPU, or other accelerators.
#

import os
import tempfile
import numpy as np
import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn

######################################################################
# Composable IRModule Optimization
# --------------------------------
# Apache TVM Unity provides a flexible way to optimize the IRModule. Everything centers
# around the IRModule, so optimizations can be composed with existing pipelines. Note that each
# optimization can focus on **part of the computation graph**, enabling partial lowering or
# partial optimization, as the short sketch below illustrates.
#
# In this tutorial, we will demonstrate how to optimize a model with Apache TVM Unity.

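######################################################################
# As a taste of this composability, the following minimal sketch chains two built-in Relax
# passes into a single pipeline (the choice of passes here is illustrative, not part of the
# tutorial's flow below; the snippet is not executed):
#
# .. code-block:: python
#
#     # Each pass maps an IRModule to a new IRModule, so passes compose freely
#     pipeline = tvm.ir.transform.Sequential(
#         [
#             relax.transform.LegalizeOps(),   # lower Relax ops to TensorIR functions
#             relax.transform.FoldConstant(),  # fold constant computations
#         ]
#     )
#     new_mod = pipeline(mod)
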
######################################################################
68+
# Prepare a Relax Module
69+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
70+
# We first prepare a Relax module. The module can be imported from other frameworks, constructed
71+
# with NN module frontend or TVMScript. Here we use a simple neural network model as an example.
72+
73+
74+
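######################################################################
# A minimal, non-executed sketch of the import path, using ONNX as the source framework
# (it assumes the ``onnx`` package is installed and that a ``model.onnx`` file is at hand;
# both are assumptions for illustration only):
#
# .. code-block:: python
#
#     import onnx
#     from tvm.relax.frontend.onnx import from_onnx
#
#     # Load the ONNX model from disk and translate it into a Relax IRModule
#     onnx_model = onnx.load("model.onnx")
#     imported_mod = from_onnx(onnx_model)
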
class RelaxModel(nn.Module):
    def __init__(self):
        super(RelaxModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


input_shape = (1, 784)
mod, params = RelaxModel().export_tvm({"forward": {"x": nn.spec.Tensor(input_shape, "float32")}})
mod.show()

######################################################################
# Library Dispatch
# ~~~~~~~~~~~~~~~~
# Suppose we would like to quickly try out a library-based optimization for a certain platform
# (e.g., GPU). We can write a dispatching pass for the specific platform and operator. Here we
# demonstrate how to dispatch certain patterns to the cuBLAS library.
#
# .. note::
#   This tutorial only demonstrates a single-operator dispatch to cuBLAS, highlighting
#   the flexibility of the optimization pipeline. In real-world cases, we can import multiple
#   patterns and dispatch them to different kernels.

# Import cublas pattern
import tvm.relax.backend.contrib.cublas as _cublas


# Define a new pass for CUBLAS dispatch
@tvm.transform.module_pass(opt_level=0, name="CublasDispatch")
class CublasDispatch:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        # Check if CUBLAS is enabled
        if not tvm.get_global_func("relax.ext.cublas", True):
            raise Exception("CUBLAS is not enabled.")

        # Get interested patterns
        patterns = [relax.backend.get_pattern("cublas.matmul_transposed_bias_relu")]
        # Note in real-world cases, we usually get all patterns
        # patterns = relax.backend.get_patterns_with_prefix("cublas")

        # Fuse ops by patterns and then run codegen
        mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)
        mod = relax.transform.RunCodegen()(mod)
        return mod


mod = CublasDispatch()(mod)
mod.show()

######################################################################
# After the dispatching pass, we can see that the first ``nn.Linear`` and ``nn.ReLU`` are fused
# and rewritten into a ``call_dps_packed`` function that calls the cuBLAS library. Notably, the
# other parts are not changed, which means we can selectively dispatch the optimization to
# certain computations.

######################################################################
# Auto Tuning
# ~~~~~~~~~~~
# Continuing from the previous example, we can further optimize the model with auto-tuning for
# the **remaining part of the computation**. Here we demonstrate how to use MetaSchedule to
# auto-tune the model.
#
# We can use the ``MetaScheduleTuneTIR`` pass to tune the model, and the
# ``MetaScheduleApplyDatabase`` pass to apply the best configuration to the model. The tuning
# step generates a search space and tunes the model, and the subsequent step applies the best
# configuration found to the model. Before running these passes, we need to lower Relax
# operators into TensorIR functions via ``LegalizeOps``, which is part of the ``zero`` pipeline
# used below.
#
# .. note::
#
#   To save CI time and avoid flakiness, we skip the tuning process in the CI environment.
#

device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
if os.getenv("CI", "") != "true":
    trials = 2000
    with target, tempfile.TemporaryDirectory() as tmp_dir:
        mod = tvm.ir.transform.Sequential(
            [
                relax.get_pipeline("zero"),
                relax.transform.MetaScheduleTuneTIR(work_dir=tmp_dir, max_trials_global=trials),
                relax.transform.MetaScheduleApplyDatabase(work_dir=tmp_dir),
            ]
        )(mod)

    mod.show()

######################################################################
# DLight Rules
# ~~~~~~~~~~~~
# DLight rules are a set of default rules for scheduling and optimizing kernels.
# DLight rules are designed for fast compilation and **fair** performance. In some cases,
# e.g. language models, DLight provides excellent performance, while for generic models
# it achieves a balance between performance and compilation time.

from tvm import dlight as dl

# Apply DLight rules
with target:
    mod = tvm.ir.transform.Sequential(
        [
            relax.get_pipeline("zero"),
            dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            ),
        ]
    )(mod)

mod.show()

######################################################################
# .. note::
#
#   This tutorial focuses on demonstrating the optimization pipeline, rather than pushing
#   performance to the limit. The current optimization may not be the best.


######################################################################
# Deploy the Optimized Model
# --------------------------
# We can build and deploy the optimized model to the TVM runtime.

ex = relax.build(mod, target="cuda")
209+
dev = tvm.device("cuda", 0)
210+
vm = relax.VirtualMachine(ex, dev)
211+
# Need to allocate data and params on GPU device
212+
data = tvm.nd.array(np.random.rand(*input_shape).astype("float32"), dev)
213+
gpu_params = [tvm.nd.array(np.random.rand(*p.shape).astype(p.dtype), dev) for _, p in params]
214+
gpu_out = vm["forward"](data, *gpu_params).numpy()
215+
print(gpu_out)
216+
217+
218+
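######################################################################
# For deployment beyond the current process, the compiled artifact can be saved to disk and
# reloaded later. A minimal, non-executed sketch (the file name ``compiled_model.so`` is an
# illustrative assumption, not part of this tutorial):
#
# .. code-block:: python
#
#     # Export the compiled module as a shared library, then reload it into a fresh VM
#     ex.export_library("compiled_model.so")
#     loaded = tvm.runtime.load_module("compiled_model.so")
#     vm = relax.VirtualMachine(loaded, dev)
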
######################################################################
# Summary
# -------
# This tutorial demonstrates how to customize the optimization pipeline for ML models in Apache
# TVM. We can easily compose optimization passes and customize the optimization for different
# parts of the computation graph. The flexibility of the optimization pipeline enables us to
# quickly iterate on optimizations and improve the performance of the model.
#
