# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
.. _optimize_model:

End-to-End Optimize Model
=========================
This tutorial demonstrates how to optimize a machine learning model using Apache TVM. We use a
pre-trained ResNet-18 model from PyTorch and optimize it end-to-end with TVM's Relax API.
Please note that the default end-to-end optimization may not be suitable for complex models.
"""

######################################################################
# Preparation
# -----------
# First, we prepare the model and input information. We use a pre-trained ResNet-18 model from
# PyTorch.

import numpy as np
import torch
from torch import fx
from torchvision.models.resnet import ResNet18_Weights, resnet18

torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
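
######################################################################
# The pretrained weights also ship with their own preprocessing recipe. As an
# illustrative aside (not required by the rest of this tutorial, and relying on
# torchvision's weights-enum API), printing it confirms the 224x224 input
# resolution that the IRModule below is built for.

print(ResNet18_Weights.DEFAULT.transforms())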

######################################################################
# Review Overall Flow
# -------------------
# .. figure:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_overall_flow.svg
#    :align: center
#    :width: 80%
#
# The overall flow consists of the following steps:
#
# - **Construct or Import a Model**: Construct a neural network model or import a pre-trained
#   model from other frameworks (e.g. PyTorch, ONNX), and create the TVM IRModule, which contains
#   all the information needed for compilation, including high-level Relax functions for the
#   computational graph and low-level TensorIR functions for the tensor programs.
# - **Perform Composable Optimizations**: Perform a series of optimization transformations,
#   such as graph optimizations, tensor program optimizations, and library dispatching.
# - **Build and Universal Deployment**: Build the optimized model into a deployable module for the
#   universal runtime, and execute it on different devices, such as CPU, GPU, or other accelerators.
#

######################################################################
# Convert the model to IRModule
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Next, we convert the model to an IRModule using the Relax frontend for PyTorch for further
# optimization. Besides the model, we also need to provide the input shape and data type.

import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx

# Give the input shape and data type
input_info = [((1, 3, 224, 224), "float32")]

# Convert the model to IRModule
with torch.no_grad():
    torch_fx_model = fx.symbolic_trace(torch_model)
    mod = from_fx(torch_fx_model, input_info, keep_params_as_input=True)

mod, params = relax.frontend.detach_params(mod)
mod.show()
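
######################################################################
# As a quick, optional sanity check (illustrative, not part of the original
# flow): ``detach_params`` returns a dictionary mapping each Relax function
# name to its list of weight NDArrays, so we can count the parameters and
# inspect a shape.

print("number of parameters:", len(params["main"]))
print("first parameter shape:", params["main"][0].shape)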

######################################################################
# IRModule Optimization
# ---------------------
# Apache TVM Unity provides a flexible way to optimize the IRModule. All the optimizations are
# centered around the IRModule and can be composed with existing pipelines. Note that individual
# transformations can be combined into an optimization pipeline via ``tvm.ir.transform.Sequential``.
#
# In this tutorial, we focus on end-to-end optimization of the model via auto-tuning. We
# leverage MetaSchedule to tune the model and store the tuning logs in a database, then
# apply the database back to the model to get the best performance.
#
# .. note::
#
#    To save CI time, we disable tuning by default and only load the pre-tuned log.
#    You can enable tuning by turning on the ``enable_tune`` flag.
#

TOTAL_TRIALS = 8000  # Change to 20000 for better performance if needed
enable_tune = False  # Enable this flag to tune the model
target = tvm.target.Target("nvidia/geforce-rtx-3090-ti")  # Change to your target device
work_dir = "tuning_logs"

with target:
    mod = tvm.ir.transform.Sequential(
        [
            # Convert BatchNorm into a sequence of simpler ops for fusion
            relax.transform.DecomposeOpsForInference(),
            # Canonicalize the bindings
            relax.transform.CanonicalizeBindings(),
            # Run the default optimization pipeline
            relax.get_pipeline("zero"),
            # Tune the model and store the log to the database
            relax.transform.MetaScheduleTuneIRMod({}, work_dir, TOTAL_TRIALS)
            if enable_tune
            else tvm.transform.Sequential([]),
            # Apply the database
            relax.transform.MetaScheduleApplyDatabase(work_dir),
        ]
    )(mod)

# Only show the main function
mod["main"].show()
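
######################################################################
# As an illustrative check (not required by the flow above), we can enumerate
# the functions in the transformed IRModule: besides the high-level ``main``
# Relax function, the pipeline has generated low-level TensorIR functions,
# which are the kernels that get scheduled and tuned.

for gvar, func in mod.functions.items():
    print(gvar.name_hint, ":", type(func).__name__)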

######################################################################
# Build and Deploy
# ----------------
# Finally, we build the optimized model and deploy it to the target device.

ex = relax.build(mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(ex, dev)
# We need to allocate the input data and parameters on the GPU device
gpu_data = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32"), dev)
gpu_params = [tvm.nd.array(p, dev) for p in params["main"]]
gpu_out = vm["main"](gpu_data, *gpu_params).numpy()

print(gpu_out.shape)
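
######################################################################
# As an optional sanity check (a sketch, not part of the original tutorial), we
# can compare the TVM output against PyTorch on the same input. We switch the
# PyTorch model to ``eval`` mode so that BatchNorm uses running statistics,
# matching the inference-mode decomposition applied earlier; small numerical
# differences are expected.

with torch.no_grad():
    torch_out = torch_model.eval()(torch.from_numpy(gpu_data.numpy())).numpy()
print("max abs difference vs PyTorch:", float(np.max(np.abs(gpu_out - torch_out))))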