Commit e3fb612

author
Siyuan Feng
committed
[Doc] Refactor How-To
This PR refactors the how-to section and adds a new `end-to-end optimization model` tutorial.
1 parent 3138328 commit e3fb612

15 files changed: +262 -86 lines changed

docs/conf.py

Lines changed: 2 additions & 0 deletions
@@ -423,6 +423,7 @@ def jupyter_notebook(script_blocks, gallery_conf, target_dir, real_func):
     tvm_path.joinpath("vta", "tutorials"),
     # New tutorial structure under docs folder
     tvm_path.joinpath("docs", "get_started", "tutorials"),
+    tvm_path.joinpath("docs", "how_to", "tutorials"),
 ]
 
 gallery_dirs = [
@@ -440,6 +441,7 @@ def jupyter_notebook(script_blocks, gallery_conf, target_dir, real_func):
     "topic/vta/tutorials",
     # New tutorial structure under docs folder
     "get_started/tutorials/",
+    "how_to/tutorials/",
 ]
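The two hunks above add one entry to each of two parallel lists. A minimal sketch of why both must change together, assuming the first list is the sphinx-gallery `examples_dirs` setting (its name is not visible in the hunk) and that entries are matched to `gallery_dirs` by position. The `tvm_path` value and the helper `pair_gallery_dirs` are hypothetical and only illustrate the pairing:

```python
from pathlib import Path

# Hypothetical repository root; the real conf.py derives tvm_path itself.
tvm_path = Path("tvm")

# Assumed name for the source list; only its tail is visible in the diff.
examples_dirs = [
    tvm_path.joinpath("docs", "get_started", "tutorials"),
    tvm_path.joinpath("docs", "how_to", "tutorials"),
]
gallery_dirs = [
    "get_started/tutorials/",
    "how_to/tutorials/",
]


def pair_gallery_dirs(sources, outputs):
    """Pair each tutorial source directory with its output directory by position."""
    if len(sources) != len(outputs):
        raise ValueError(
            f"{len(sources)} source dirs but {len(outputs)} gallery dirs"
        )
    return list(zip(sources, outputs))


pairs = pair_gallery_dirs(examples_dirs, gallery_dirs)
for src, out in pairs:
    print(f"{src} -> {out}")
```

Adding an entry to only one of the lists would leave the lengths unequal, which is the situation the length check in this sketch is meant to flag.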

docs/dev/how_to/how_to.rst

Lines changed: 0 additions & 2 deletions
@@ -29,5 +29,3 @@ various areas of the TVM stack.
    relay_add_op
    relay_add_pass
    relay_bring_your_own_codegen
-   pytest_target_parametrization
-   setup_rpc_system

docs/how_to/dev/index.rst

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+   or more contributor license agreements. See the NOTICE file
+   distributed with this work for additional information
+   regarding copyright ownership. The ASF licenses this file
+   to you under the Apache License, Version 2.0 (the
+   "License"); you may not use this file except in compliance
+   with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+   software distributed under the License is distributed on an
+   "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+   KIND, either express or implied. See the License for the
+   specific language governing permissions and limitations
+   under the License.
+
+Develope Apache TVM
+===================
+This section contains a collection of tips about how to work on
+various areas of the TVM stack.
+
+.. toctree::
+   :maxdepth: 1
+
+   pytest_target_parametrization
+   setup_rpc_system
+   ../../errors
File renamed without changes.

docs/dev/how_to/setup_rpc_system.rst renamed to docs/how_to/dev/setup_rpc_system.rst

Lines changed: 3 additions & 3 deletions
@@ -76,7 +76,7 @@ In our community, there is multiple RPC server implementations, e.g., ``apps/and
 
 RPC server need to be run on device machine, and it usually will depend on xPU driver, the enhanced TVM runtime with xPU support, and other libraries, so please setup the dependent components first, e.g., install the KMD driver, ensure the required dynamic libraries can be found from environment variable ``LD_LIBRARY_PATH``.
 
-If the required compilation environment can be setup on your device machine, i.e., you needn't to do the cross compilation, then just follow the instruction of `<https://tvm.apache.org/docs/install/from_source.html>`_ to compile the TVM runtime and directly jump to the step :ref:`luanch-rpc-server`.
+If the required compilation environment can be setup on your device machine, i.e., you needn't to do the cross compilation, then just follow the instruction of `<https://tvm.apache.org/docs/install/from_source.html>`_ to compile the TVM runtime and directly jump to the step :ref:`launch-rpc-server`.
 
 1. Cross Compile TVM Runtime
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -134,9 +134,9 @@ Then copy the compress package ``tvm_runtime.tar.gz`` to your concrete device ma
    $ export PYTHONPATH=`pwd`/python:${PYTHONPATH}
 
 
-.. _luanch-rpc-server:
+.. _launch-rpc-server:
 
-3. Luanch RPC Server
+3. Launch RPC Server
 ^^^^^^^^^^^^^^^^^^^^
 
 The RPC server can be launched on your device machine through the commands like something below, please modify the *RPC_TRACKER_IP*, *RPC_TRACKER_PORT*, *RPC_PROXY_IP*, *RPC_PROXY_PORT*, and *RPC_KEY* according to your concrete environment.
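The rename above has to change the label definition and the `:ref:` that points at it in the same commit, otherwise the cross-reference dangles. A rough sketch of a single-file check for that, using only the standard library. The helper `dangling_refs` and both regular expressions are illustrative assumptions; they do not replicate Sphinx's own resolution (no cross-file lookup, no label normalization):

```python
import re

# Explicit label definitions such as ".. _launch-rpc-server:"
LABEL_RE = re.compile(r"^\.\. _([^:\n]+):\s*$", re.MULTILINE)
# :ref:`label` or :ref:`title <label>`
REF_RE = re.compile(r":ref:`(?:[^`<]*<)?([^`<>]+)>?`")


def dangling_refs(rst_text):
    """Return :ref: targets that have no matching label in the same text."""
    labels = set(LABEL_RE.findall(rst_text))
    refs = set(REF_RE.findall(rst_text))
    return sorted(refs - labels)


# State if only the reference had been corrected, but not the label.
half_fixed = """\
directly jump to the step :ref:`launch-rpc-server`.

.. _luanch-rpc-server:

3. Launch RPC Server
^^^^^^^^^^^^^^^^^^^^
"""

# State after the commit: label and reference agree.
fully_fixed = half_fixed.replace("_luanch", "_launch")

print(dangling_refs(half_fixed))   # ['launch-rpc-server']
print(dangling_refs(fully_fixed))  # []
```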

docs/how_to/index.rst

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,9 @@
1515
specific language governing permissions and limitations
1616
under the License.
1717
18-
How To Guides
19-
=============
20-
21-
These user-focused "how to" guides are designed to help you find answers to
22-
specific questions, like "How do I compile a model?" or "How to I optimize a
23-
schedule with tesor expressions?"
24-
2518
.. toctree::
2619
:maxdepth: 1
2720

28-
compile_models/index
29-
deploy/index
30-
work_with_relay/index
31-
work_with_schedules/index
32-
optimize_operators/index
33-
tune_with_autotvm/index
34-
tune_with_autoscheduler/index
35-
work_with_microtvm/index
36-
extend_tvm/index
37-
profile/index
38-
../errors
39-
../faq
21+
tutorials/e2e_opt_model
22+
tutorials/cross_compilation_and_rpc
23+
dev/index

docs/how_to/legacy_index.rst

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+   or more contributor license agreements. See the NOTICE file
+   distributed with this work for additional information
+   regarding copyright ownership. The ASF licenses this file
+   to you under the Apache License, Version 2.0 (the
+   "License"); you may not use this file except in compliance
+   with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+   software distributed under the License is distributed on an
+   "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+   KIND, either express or implied. See the License for the
+   specific language governing permissions and limitations
+   under the License.
+
+How To Guides
+=============
+
+These user-focused "how to" guides are designed to help you find answers to
+specific questions, like "How do I compile a model?" or "How to I optimize a
+schedule with tesor expressions?"
+
+.. toctree::
+   :maxdepth: 1
+
+   compile_models/index
+   deploy/index
+   work_with_relay/index
+   work_with_schedules/index
+   optimize_operators/index
+   tune_with_autotvm/index
+   tune_with_autoscheduler/index
+   work_with_microtvm/index
+   extend_tvm/index
+   profile/index
+   ../faq

docs/how_to/tutorials/README.txt

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+HOW TO
+------
Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+.. _optimize_model:
+
+End-to-End Optimize Model
+=========================
+This tutorial demonstrates how to optimize a machine learning model using Apache TVM. We will
+use a pre-trained ResNet-18 model from PyTorch and end-to-end optimize it using TVM's Relax API.
+Please note that default end-to-end optimization may not suit complex models.
+"""
+
+######################################################################
+# Preparation
+# -----------
+# First, we prepare the model and input information. We use a pre-trained ResNet-18 model from
+# PyTorch.
+
+import numpy as np
+import torch
+from torch import fx
+from torchvision.models.resnet import ResNet18_Weights, resnet18
+
+torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
+
+######################################################################
+# Review Overall Flow
+# -------------------
+# .. figure:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_overall_flow.svg
+#    :align: center
+#    :width: 80%
+#
+# The overall flow consists of the following steps:
+#
+# - **Construct or Import a Model**: Construct a neural network model or import a pre-trained
+#   model from other frameworks (e.g. PyTorch, ONNX), and create the TVM IRModule, which contains
+#   all the information needed for compilation, including high-level Relax functions for
+#   computational graph, and low-level TensorIR functions for tensor program.
+# - **Perform Composable Optimizations**: Perform a series of optimization transformations,
+#   such as graph optimizations, tensor program optimizations, and library dispatching.
+# - **Build and Universal Deployment**: Build the optimized model to a deployable module to the
+#   universal runtime, and execute it on different devices, such as CPU, GPU, or other accelerators.
+#
+
+
+######################################################################
+# Convert the model to IRModule
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Next step, we convert the model to an IRModule using the Relax frontend for PyTorch for further
+# optimization. Besides the model, we also need to provide the input shape and data type.
+
+import tvm
+from tvm import relax
+from tvm.relax.frontend.torch import from_fx
+
+torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
+
+# Give the input shape and data type
+input_info = [((1, 3, 224, 224), "float32")]
+
+# Convert the model to IRModule
+with torch.no_grad():
+    torch_fx_model = fx.symbolic_trace(torch_model)
+    mod = from_fx(torch_fx_model, input_info, keep_params_as_input=True)
+
+mod, params = relax.frontend.detach_params(mod)
+mod.show()
+
+######################################################################
+# IRModule Optimization
+# ---------------------
+# Apache TVM Unity provides a flexible way to optimize the IRModule. Everything centered
+# around IRModule optimization can be composed with existing pipelines. Note that each
+# transformation can be combined as an optimization pipeline via ``tvm.ir.transform.Sequential``.
+#
+# In this tutorial, we focus on the end-to-end optimization of the model via auto-tuning. We
+# leverage MetaSchedule to tune the model and store the tuning logs to the database. We also
+# apply the database to the model to get the best performance.
+#
+# .. note::
+#
+#   To save CI time, we disable the tuning by default and only load the pre-tuned log.
+#   You can enable it by turning on the ``enable_tune`` flag.
+#
+
+TOTAL_TRIALS = 8000 # Change to 20000 for better performance if needed
+enable_tune = False # Enable this flag to tune the model
+target = tvm.target.Target("nvidia/geforce-rtx-3090-ti") # Change to your target device
+work_dir = "tuning_logs"
+
+with target:
+    mod = tvm.ir.transform.Sequential(
+        [
+            # Convert BatchNorm into a sequence of simpler ops for fusion
+            relax.transform.DecomposeOpsForInference(),
+            # Canonicalize the bindings
+            relax.transform.CanonicalizeBindings(),
+            # Run default optimization pipeline
+            relax.get_pipeline("zero"),
+            # Tune the model and store the log to database
+            relax.transform.MetaScheduleTuneIRMod({}, work_dir, TOTAL_TRIALS)
+            if enable_tune
+            else tvm.transform.Sequential([]),
+            # Apply the database
+            relax.transform.MetaScheduleApplyDatabase(work_dir),
+        ]
+    )(mod)
+
+# Only show the main function
+mod["main"].show()
+
+######################################################################
+# Build and Deploy
+# ----------------
+# Finally, we build the optimized model and deploy it to the target device.
+
+ex = relax.build(mod, target="cuda")
+dev = tvm.device("cuda", 0)
+vm = relax.VirtualMachine(ex, dev)
+# Need to allocate data and params on GPU device
+gpu_data = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32"), dev)
+gpu_params = [tvm.nd.array(p, dev) for p in params["main"]]
+gpu_out = vm["main"](gpu_data, *gpu_params).numpy()
+
+print(gpu_out.shape)
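The optimization pipeline in the tutorial above keeps a fixed list of passes and puts either the tuning pass or an empty `Sequential` in one slot, depending on `enable_tune`. The following plain-Python analogy sketches that composition pattern without requiring TVM. The names `sequential`, `record`, and `build_pipeline` are hypothetical stand-ins, and a list of strings stands in for the IRModule; none of this is TVM's API:

```python
from typing import Callable, List

# A "pass" here is any function from module to module; the module is
# modeled as a list of the pass names that have been applied so far.
Pass = Callable[[List[str]], List[str]]


def sequential(passes: List[Pass]) -> Pass:
    """Compose passes left to right; an empty list yields a no-op pass."""

    def run(mod: List[str]) -> List[str]:
        for p in passes:
            mod = p(mod)
        return mod

    return run


def record(name: str) -> Pass:
    """Build a stand-in pass that only records that it ran."""

    def run(mod: List[str]) -> List[str]:
        return mod + [name]

    return run


def build_pipeline(enable_tune: bool) -> Pass:
    return sequential(
        [
            record("decompose_ops"),
            record("canonicalize_bindings"),
            # Same slot either way: the real pass, or an empty no-op pipeline.
            record("tune") if enable_tune else sequential([]),
            record("apply_database"),
        ]
    )


print(build_pipeline(enable_tune=False)([]))
print(build_pipeline(enable_tune=True)([]))
```

Using an empty pipeline as the placeholder keeps the list shape identical in both modes, so the surrounding passes do not need to be rearranged when tuning is switched on.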
