# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
.. _customize_opt:

Customize Optimization
======================
One main design goal of Apache TVM is to make it easy to customize the optimization pipeline
for both research and development purposes, and to iterate on engineering optimizations. In this
tutorial we will demonstrate how to customize the optimization pipeline step by step.

.. contents:: Table of Contents
    :local:
    :depth: 1
"""

######################################################################
# Review Overall Flow
# -------------------
# .. figure:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_overall_flow.svg
#    :align: center
#    :width: 80%
#
# The overall flow consists of the following steps:
#
# - **Construct or Import a Model**: Construct a neural network model or import a pre-trained
#   model from another framework (e.g. PyTorch or ONNX; a minimal import sketch is shown below),
#   and create the TVM IRModule, which contains all the information needed for compilation,
#   including high-level Relax functions for the computational graph and low-level TensorIR
#   functions for the tensor programs.
# - **Perform Composable Optimizations**: Perform a series of optimization transformations,
#   such as graph optimizations, tensor program optimizations, and library dispatching.
# - **Build and Universal Deployment**: Build the optimized model into a deployable module for the
#   universal runtime, and execute it on different devices, such as CPU, GPU, or other accelerators.
#

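######################################################################
# For example, if we already have a pre-trained model in ONNX format, a minimal import sketch
# could look like the following (``model.onnx`` is a hypothetical file path, and the ONNX
# frontend requires the ``onnx`` package to be installed):
#
# .. code-block:: python
#
#     import onnx
#     from tvm.relax.frontend.onnx import from_onnx
#
#     onnx_model = onnx.load("model.onnx")
#     mod = from_onnx(onnx_model)
#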
import os
import tempfile
import numpy as np
import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn

######################################################################
# Composable IRModule Optimization
# --------------------------------
# Apache TVM Unity provides a flexible way to optimize the IRModule. Since all optimizations are
# centered around the IRModule, they can be freely composed with existing pipelines. Note that
# each optimization can focus on **part of the computation graph**, enabling partial lowering or
# partial optimization.
#
# In this tutorial, we will demonstrate how to optimize a model with Apache TVM Unity.

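######################################################################
# Concretely, "composable" means that individual passes and whole pipelines are simply
# ``IRModule -> IRModule`` transformations, so they can be chained freely. A minimal sketch of
# the pattern used throughout this tutorial (``MyCustomPass`` is a hypothetical user-defined
# pass, standing in for the passes we write later):
#
# .. code-block:: python
#
#     pipeline = tvm.ir.transform.Sequential(
#         [
#             relax.get_pipeline("zero"),  # an existing built-in pipeline
#             MyCustomPass(),              # a hypothetical user-defined pass
#         ]
#     )
#     mod = pipeline(mod)
#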
######################################################################
# Prepare a Relax Module
# ~~~~~~~~~~~~~~~~~~~~~~
# We first prepare a Relax module. The module can be imported from other frameworks, constructed
# with the NN module frontend, or written directly in TVMScript. Here we use a simple neural
# network model as an example.


class RelaxModel(nn.Module):
    def __init__(self):
        super(RelaxModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


input_shape = (1, 784)
mod, params = RelaxModel().export_tvm({"forward": {"x": nn.spec.Tensor(input_shape, "float32")}})
mod.show()

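######################################################################
# The printed module is shown in TVMScript. As mentioned above, the same module could also be
# authored directly in TVMScript instead of through the NN frontend. A rough sketch of what the
# first linear layer might look like when written by hand (the printed output above is
# authoritative; the names and shapes here are illustrative only):
#
# .. code-block:: python
#
#     from tvm.script import ir as I
#     from tvm.script import relax as R
#
#     @I.ir_module
#     class HandWrittenModule:
#         @R.function
#         def forward(
#             x: R.Tensor((1, 784), "float32"),
#             w: R.Tensor((256, 784), "float32"),
#             b: R.Tensor((256,), "float32"),
#         ) -> R.Tensor((1, 256), "float32"):
#             with R.dataflow():
#                 xw = R.matmul(x, R.permute_dims(w))
#                 y = R.add(xw, b)
#                 out = R.nn.relu(y)
#                 R.output(out)
#             return out
#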
######################################################################
# Library Dispatch
# ~~~~~~~~~~~~~~~~
# We would like to quickly try out a variant of library optimization for a certain platform
# (e.g., GPU). We can write a dispatching pass for the specific platform and operator. Here we
# demonstrate how to dispatch certain patterns to the CUBLAS library.
#
# .. note::
#    This tutorial only demonstrates dispatching a single operator to CUBLAS, highlighting
#    the flexibility of the optimization pipeline. In real-world cases, we can import multiple
#    patterns and dispatch them to different kernels.


# Import cublas pattern
import tvm.relax.backend.contrib.cublas as _cublas


# Define a new pass for CUBLAS dispatch
@tvm.transform.module_pass(opt_level=0, name="CublasDispatch")
class CublasDispatch:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        # Check if CUBLAS is enabled in this build of TVM
        if not tvm.get_global_func("relax.ext.cublas", True):
            raise Exception("CUBLAS is not enabled.")

        # Get the pattern we are interested in
        patterns = [relax.backend.get_pattern("cublas.matmul_transposed_bias_relu")]
        # Note that in real-world cases, we usually get all registered patterns, e.g.
        # patterns = relax.backend.get_patterns_with_prefix("cublas")

        # Fuse ops by patterns and then run codegen
        mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)
        mod = relax.transform.RunCodegen()(mod)
        return mod


mod = CublasDispatch()(mod)
mod.show()

######################################################################
# After the dispatching pass, we can see that the first ``nn.Linear`` and ``nn.ReLU`` are fused
# and rewritten to a ``call_dps_packed`` function, which calls the CUBLAS library. Notably, the
# rest of the module is left unchanged, which means we can selectively dispatch the optimization
# to certain computations.

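######################################################################
# As mentioned in the note above, a real-world pipeline would usually offload every supported
# pattern instead of a single one. A minimal sketch of that variant (assuming we start again
# from the un-dispatched module):
#
# .. code-block:: python
#
#     patterns = relax.backend.get_patterns_with_prefix("cublas")
#     mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)
#     mod = relax.transform.RunCodegen()(mod)
#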
######################################################################
# Auto Tuning
# ~~~~~~~~~~~
# Continuing from the previous example, we can further optimize the model with auto-tuning for
# the **rest of the computation**. Here we demonstrate how to use MetaSchedule to auto-tune
# the model.
#
# We can use the ``MetaScheduleTuneTIR`` pass to tune the model, and the
# ``MetaScheduleApplyDatabase`` pass to apply the best configurations to the model. The tuning
# step generates the search space and tunes the model, while the following step applies the best
# configurations found in the database. Before running the passes, we need to lower Relax
# operators into TensorIR functions via ``LegalizeOps``.
#
# .. note::
#
#    To save CI time and avoid flakiness, we skip the tuning process in the CI environment.
#

device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
if os.getenv("CI", "") != "true":
    trials = 2000
    with target, tempfile.TemporaryDirectory() as tmp_dir:
        mod = tvm.ir.transform.Sequential(
            [
                relax.get_pipeline("zero"),
                relax.transform.MetaScheduleTuneTIR(work_dir=tmp_dir, max_trials_global=trials),
                relax.transform.MetaScheduleApplyDatabase(work_dir=tmp_dir),
            ]
        )(mod)

    mod.show()

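######################################################################
# In this example the tuning database lives in a temporary directory and is discarded afterwards.
# In practice, we usually keep the work directory so that the tuning results can be reused in
# later builds. A minimal sketch of that workflow, assuming a hypothetical ``./tuning_logs``
# directory preserved from a previous tuning run:
#
# .. code-block:: python
#
#     # Re-apply previously tuned schedules without tuning again
#     with target:
#         mod = relax.transform.MetaScheduleApplyDatabase(work_dir="./tuning_logs")(mod)
#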
######################################################################
# DLight Rules
# ~~~~~~~~~~~~
# DLight rules are a set of default rules for scheduling and optimizing kernels.
# DLight rules are designed for fast compilation and **fair** performance. In some cases,
# e.g. language models, DLight provides excellent performance, while for generic models
# it achieves a balance between performance and compilation time.

from tvm import dlight as dl

# Apply DLight rules
with target:
    mod = tvm.ir.transform.Sequential(
        [
            relax.get_pipeline("zero"),
            dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            ),
        ]
    )(mod)

mod.show()

######################################################################
# .. note::
#
#    This tutorial focuses on demonstrating the optimization pipeline, instead of pushing the
#    performance to the limit. The current optimization may not be the best.


######################################################################
# Deploy the Optimized Model
# --------------------------
# We can build and deploy the optimized model to the TVM runtime.

ex = relax.build(mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(ex, dev)
# Need to allocate data and params on GPU device
data = tvm.nd.array(np.random.rand(*input_shape).astype("float32"), dev)
gpu_params = [tvm.nd.array(np.random.rand(*p.shape).astype(p.dtype), dev) for _, p in params]
gpu_out = vm["forward"](data, *gpu_params).numpy()
print(gpu_out)


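######################################################################
# Beyond running the module in the current process, we can also export the built artifact for
# deployment elsewhere. A minimal sketch, assuming a hypothetical output path ``compiled.so``:
#
# .. code-block:: python
#
#     # Export the compiled module to a shared library ...
#     ex.export_library("compiled.so")
#
#     # ... and load it back, e.g. in another process on the target machine
#     loaded = tvm.runtime.load_module("compiled.so")
#     vm = relax.VirtualMachine(loaded, tvm.device("cuda", 0))
#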
######################################################################
# Summary
# -------
# This tutorial demonstrates how to customize the optimization pipeline for ML models in Apache TVM.
# We can easily compose optimization passes and customize the optimization for different parts
# of the computation graph. The flexibility of the optimization pipeline enables us to quickly
# iterate on optimizations and improve the performance of the model.
#