Skip to content

Commit ecd3c88

Browse files
authored
[Runtime][PipelineExecutor] Tutorial of using pipeline executor. (#11557)
* [Runtime][PipelineExecutor] Tutorial of using pipeline executor. Tutorial of using pipeline executor including the byoc use case. * fix ci issue * document change. * triger build * fix doc issue * fix ci issue * doc issue * fix ci issue * fix ci issue. * fix __file__ not found problem. this is a known issue of sphinx-gallery sphinx-gallery/sphinx-gallery#211 * fix byoc with dnnl issue * enable dnnl and pipeline executor * trigger build * trigger build * fix build issue * trigger build * oneflow cause crash, do test with change * add sphinx skip * plint * remove from_oneflow change test. * remove pipeline executor change for test * plint * enable DNNL and pipeline * disable DNNL * enable DNNL without pipeline * remove dnnl and add cutlass * use cutlass with byoc * change into cutlass * fix doc convention issue * remove duplicate variable * fix plint issue. * address review comments. * address review comments * fix bug. * polish the document * fix plint issue * address review comments. * address review comments * address review comments
1 parent 9863cf0 commit ecd3c88

File tree

4 files changed

+281
-9
lines changed

4 files changed

+281
-9
lines changed
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""
18+
Using Pipeline Executor in Relay
19+
=================================
20+
**Author**: `Hua Jiang <https://github.com/huajsj>`_
21+
22+
This is a short tutorial on how to use "Pipeline Executor" with Relay.
23+
"""
24+
import tvm
25+
from tvm import te
26+
import numpy as np
27+
from tvm.contrib import graph_executor as runtime
28+
from tvm.relay.op.contrib.cutlass import partition_for_cutlass
29+
from tvm import relay
30+
from tvm.relay import testing
31+
import tvm.testing
32+
from tvm.contrib.cutlass import (
33+
has_cutlass,
34+
num_cutlass_partitions,
35+
finalize_modules,
36+
finalize_modules_vm,
37+
)
38+
39+
img_size = 8
40+
#######################################################################
41+
# Create a simple network, this network can be a pre-trained model too.
42+
# ---------------------------------------------------------------------
43+
# Let's create a very simple network for demonstration.
44+
# It consists of convolution, batch normalization, dense, and ReLU activation.
45+
def get_network():
    """Build the small float16 Relay network used by this tutorial.

    The network chains conv2d -> batch_norm -> relu -> batch_flatten -> dense.
    The spatial size of the input comes from the module-level ``img_size``.

    Returns
    -------
    net
        The module produced by ``testing.create_workload``.
    params
        The parameters produced by ``testing.create_workload``.
    data_shape : tuple
        Shape of the "data" input: (batch_size, 3, img_size, img_size).
    """
    batch_size = 1
    out_channels = 16
    data_shape = (batch_size, 3, img_size, img_size)
    dense_weight_shape = (batch_size, 16 * img_size * img_size)

    # Typed inputs: the image data and the dense-layer weight.
    data = relay.var("data", relay.TensorType(data_shape, "float16"))
    dense_weight = relay.var("dweight", relay.TensorType(dense_weight_shape, "float16"))

    # Untyped variables; only some of them are referenced by the network below.
    weight = relay.var("weight")
    second_weight = relay.var("second_weight")
    bn_gamma = relay.var("bn_gamma")
    bn_beta = relay.var("bn_beta")
    bn_mmean = relay.var("bn_mean")
    bn_mvar = relay.var("bn_var")

    conv_out = relay.nn.conv2d(
        data=data,
        weight=weight,
        kernel_size=(3, 3),
        channels=out_channels,
        padding=(1, 1),
    )
    # batch_norm yields a tuple; only its first element is fed forward.
    bn_result = relay.nn.batch_norm(conv_out, bn_gamma, bn_beta, bn_mmean, bn_mvar)
    relu_out = relay.nn.relu(bn_result[0])
    flat_out = relay.nn.batch_flatten(relu_out)
    dense_out = relay.nn.dense(flat_out, dense_weight)

    # Every free variable of the expression becomes a function parameter.
    func = relay.Function(relay.analysis.free_vars(dense_out), dense_out)
    net, params = testing.create_workload(func)
    return net, params, data_shape
69+
70+
71+
net, params, data_shape = get_network()
72+
###########################################
73+
# Splitting the network into two subgraphs.
74+
# -----------------------------------------
75+
# The 'graph_split' function, taken from a unit test, is just an example. Users can create customized logic
76+
# to split the graph.
77+
import inspect
78+
import os
79+
80+
tutorial_dir = os.path.dirname(inspect.getfile(lambda: None))
81+
os.sys.path.append(os.path.join(tutorial_dir, "../../../tests/python/relay"))
82+
from test_pipeline_executor import graph_split
83+
84+
###########################################
85+
# Splitting the network into two subgraphs.
86+
split_config = [{"op_name": "nn.relu", "op_index": 0}]
87+
subgraphs = graph_split(net["main"], split_config, params)
88+
###########################################################
89+
# The generated subgraphs should look something like below.
90+
91+
"""
92+
#subgraphs[0]
93+
94+
def @main(%data: Tensor[(1, 3, img_size, img_size), float16]) {
95+
%0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float16] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, img_size, img_size), float16] */;
96+
%1 = nn.batch_norm(%0, meta[relay.Constant][1] /* ty=Tensor[(16), float16] */, meta[relay.Constant][2] /* ty=Tensor[(16), float16]*/, meta[relay.Constant][3] /* ty=Tensor[(16), float16] */, meta[relay.Constant][4] /* ty=Tensor[(16), float16] */) /* ty=(Tensor[(1,16, img_size, img_size), float16], Tensor[(16), float16], Tensor[(16), float16]) */;
97+
%2 = %1.0;
98+
nn.relu(%2) /* ty=Tensor[(1, 16, img_size, img_size), float16] */
99+
}
100+
101+
#subgraphs[1]
102+
103+
def @main(%data_n_0: Tensor[(1, 16, 8, 8), float16] /* ty=Tensor[(1, 16, 8, 8), float16] */) {
104+
%0 = nn.batch_flatten(%data_n_0) /* ty=Tensor[(1, 1024), float16] */;
105+
nn.dense(%0, meta[relay.Constant][0] /* ty=Tensor[(1, 1024), float16] */, units=None) /* ty=Tensor[(1, 1), float16] */
106+
}
107+
108+
"""
109+
110+
# sphinx_gallery_start_ignore
111+
from tvm import testing
112+
113+
testing.utils.install_request_hook(depth=3)
114+
# sphinx_gallery_end_ignore
115+
116+
#########################################
117+
# Build the subgraph with cutlass target.
118+
# ---------------------------------------
119+
120+
cutlass = tvm.target.Target(
121+
{
122+
"kind": "cutlass",
123+
"sm": int(tvm.target.Target("cuda").arch.split("_")[1]),
124+
"use_3xtf32": True,
125+
"split_k_slices": [1],
126+
"profile_all_alignments": False,
127+
"find_first_valid": True,
128+
"use_multiprocessing": True,
129+
"use_fast_math": False,
130+
"tmp_dir": "./tmp",
131+
},
132+
host=tvm.target.Target("llvm"),
133+
)
134+
135+
136+
def cutlass_build(mod, target, params=None, target_host=None, mod_name="default"):
    """Build ``mod`` for ``target`` combined with the module-level ``cutlass`` target.

    Parameters
    ----------
    mod
        The module to build.
    target
        The primary compile target; it is paired with ``cutlass`` in a target list.
    params
        Optional parameters forwarded to ``relay.build_module.build``.
    target_host
        Optional host target forwarded to ``relay.build_module.build``.
    mod_name : str
        Module name forwarded to ``relay.build_module.build``.

    Returns
    -------
    lib
        The result of ``relay.build_module.build``.
    """
    return relay.build_module.build(
        mod,
        target=[target, cutlass],
        params=params,
        target_host=target_host,
        mod_name=mod_name,
    )
142+
143+
144+
###########################################################
145+
# Run the two subgraphs in pipeline with pipeline executor.
146+
# ---------------------------------------------------------
147+
# Set 'USE_PIPELINE_EXECUTOR' to ON and 'USE_CUTLASS' to ON in cmake.
148+
from tvm.contrib import graph_executor, pipeline_executor, pipeline_executor_build
149+
150+
#########################################
151+
# Create subgraph pipeline configuration.
152+
# Associate a subgraph module with a target.
153+
# Use CUTLASS BYOC to build the second subgraph module.
154+
mod0, mod1 = subgraphs[0], subgraphs[1]
155+
# Use cutlass as the codegen.
156+
mod1 = partition_for_cutlass(mod1)
157+
#################################################
158+
# Get the pipeline executor configuration object.
159+
pipe_config = pipeline_executor_build.PipelineConfig()
160+
###########################################################################
161+
# Set the compile target of the subgraph module.
162+
pipe_config[mod0].target = "llvm"
163+
pipe_config[mod0].dev = tvm.cpu(0)
164+
##############################################################
165+
# Set the compile target of the second subgraph module as cuda.
166+
pipe_config[mod1].target = "cuda"
167+
pipe_config[mod1].dev = tvm.device("cuda", 0)
168+
pipe_config[mod1].build_func = cutlass_build
169+
pipe_config[mod1].export_cc = "nvcc"
170+
# Create the pipeline by connecting the subgraph modules.
171+
# The global input will be forwarded to the input interface of the first module named mod0
172+
pipe_config["input"]["data"].connect(pipe_config[mod0]["input"]["data"])
173+
# The first output of mod0 will be forwarded to the input interface of mod1
174+
pipe_config[mod0]["output"][0].connect(pipe_config[mod1]["input"]["data_n_0"])
175+
# The first output of mod1 will be the first global output.
176+
pipe_config[mod1]["output"][0].connect(pipe_config["output"][0])
177+
######################################
178+
# The pipeline configuration is shown below.
179+
"""
180+
print(pipe_config)
181+
Inputs
182+
|data: mod0:data
183+
184+
output
185+
|output(0) : mod1.output(0)
186+
187+
connections
188+
|mod0.output(0)-> mod1.data_n_0
189+
"""
190+
191+
# sphinx_gallery_start_ignore
192+
from tvm import testing
193+
194+
# testing.utils.install_request_hook(depth=3)
195+
# sphinx_gallery_end_ignore
196+
##############################
197+
# Build the pipeline executor.
198+
# ----------------------------
199+
with tvm.transform.PassContext(opt_level=3):
200+
pipeline_mod_factory = pipeline_executor_build.build(pipe_config)
201+
###############################################
202+
# Export the parameter configuration to a file.
203+
directory_path = tvm.contrib.utils.tempdir().temp_dir
204+
os.makedirs(directory_path, exist_ok=True)
205+
config_file_name = pipeline_mod_factory.export_library(directory_path)
206+
################################################################
207+
# Use the load function to create and initialize PipelineModule.
208+
# --------------------------------------------------------------
209+
pipeline_module = pipeline_executor.PipelineModule.load_library(config_file_name)
210+
211+
############################
212+
# Run the pipeline executor.
213+
# --------------------------
214+
# Allocate input data.
215+
data = np.random.uniform(-1, 1, size=data_shape).astype("float16")
216+
pipeline_module.set_input("data", tvm.nd.array(data))
217+
##########################################################################
218+
# Run the two subgraphs in pipeline mode to get the output asynchronously
219+
# or synchronously. In the following example, it is synchronous.
220+
pipeline_module.run()
221+
outputs = pipeline_module.get_output()
222+
######################################
223+
# Use graph_executor for verification.
224+
# ------------------------------------
225+
# Run these two subgraphs in sequence with graph_executor to get the output.
226+
target = "llvm"
227+
dev0 = tvm.device(target, 0)
228+
lib0 = relay.build_module.build(mod0, target, params=params)
229+
module0 = runtime.GraphModule(lib0["default"](dev0))
230+
cuda = tvm.target.Target("cuda", host=tvm.target.Target("llvm"))
231+
lib1 = relay.build_module.build(mod1, [cuda, cutlass], params=params)
232+
lib1 = finalize_modules(lib1, "compile.so", "./tmp")
233+
234+
dev1 = tvm.device("cuda", 0)
235+
236+
module1 = runtime.GraphModule(lib1["default"](dev1))
237+
238+
module0.set_input("data", data)
239+
module0.run()
240+
out_shape = (1, 16, img_size, img_size)
241+
out = module0.get_output(0, tvm.nd.empty(out_shape, "float16"))
242+
module1.set_input("data_n_0", out)
243+
module1.run()
244+
out_shape = (1, 1)
245+
out = module1.get_output(0, tvm.nd.empty(out_shape, "float16"))
246+
####################
247+
# Verify the result.
248+
tvm.testing.assert_allclose(outputs[0].numpy(), out.numpy())

python/tvm/contrib/pipeline_executor.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""Pipeline executor that executes a series of modules in a pipeline fashion."""
1818
import json
1919
import os
20+
import time
2021
from tvm import runtime
2122
from tvm._ffi import get_global_func
2223
from tvm.contrib import graph_executor
@@ -131,14 +132,26 @@ def get_input(self, key):
131132
"""
132133
return self._get_input(key)
133134

134-
def get_output(self):
135+
def get_output(self, synchronize=True, sleep_interval=0.001):
135136
"""Get the output.
136137
Returns
137138
-------
138139
data : Array[NDArray]
139140
A list of output data.
141+
synchronize : bool
142+
Whether to poll synchronously until a non-empty output is returned.
143+
sleep_interval : float
144+
When polling synchronously, the number of seconds to sleep after each poll so the loop yields.
140145
"""
141-
return self._get_output()
146+
outputs = []
147+
if not synchronize:
148+
outputs = self._get_output()
149+
else:
150+
while not outputs:
151+
outputs = self._get_output()
152+
time.sleep(sleep_interval)
153+
154+
return outputs
142155

143156
@property
144157
def num_executing_pipeline(self):
@@ -302,11 +315,16 @@ def export_library(self, directory_path):
302315
self.pipeline_mods[lib_index]["dev"].device_type,
303316
self.pipeline_mods[lib_index]["dev"].device_id,
304317
)
305-
306318
# Get the graph, lib, and parameters from GraphExecutorFactoryModule.
307319
lib = self.pipeline_mods[lib_index]["lib"]
308320
# Export the lib, graph, and parameters to disk.
309-
lib.export_library(mconfig["lib_name"])
321+
if self.pipeline_mods[lib_index]["export_cc"]:
322+
lib.export_library(
323+
mconfig["lib_name"], cc=self.pipeline_mods[lib_index]["export_cc"]
324+
)
325+
else:
326+
lib.export_library(mconfig["lib_name"])
327+
310328
with open(mconfig["json_name"], "w") as file_handle:
311329
file_handle.write(lib.graph_json)
312330
with open(mconfig["params_name"], "wb") as file_handle:

python/tvm/contrib/pipeline_executor_build.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,12 @@ def build(pipe_configs):
8686
# Use "mod_idx" as the key to create a "module_connection" map which is not only
8787
# for the module index but also for the module connection used to build the pipeline.
8888
module_string_config[mod_idx] = pipe_config
89-
libs[mod_idx] = {"lib": lib, "dev": dev, "fcompile": mod_config["fcompile"]}
89+
libs[mod_idx] = {
90+
"lib": lib,
91+
"dev": dev,
92+
"fcompile": mod_config["fcompile"],
93+
"export_cc": mod_config["export_cc"],
94+
}
9095

9196
# Creating a text form configuration to record the "input_connection" and the
9297
# "module_connection" information. The "input_connection" is used to record the
@@ -132,10 +137,7 @@ def export_library(factory, directory_path):
132137
mconfig["json_name"] = "{}/json{}".format(directory_path, lib_index)
133138
mconfig["params_name"] = "{}/params{}".format(directory_path, lib_index)
134139
lib_config = factory.pipeline_mods[lib_index]
135-
mconfig["dev"] = "{},{}".format(
136-
lib_config["dev"].device_type,
137-
lib_config["dev"].device_id,
138-
)
140+
mconfig["dev"] = "{},{}".format(lib_config["dev"].device_type, lib_config["dev"].device_id)
139141
fcompile = lib_config["fcompile"]
140142
if not fcompile:
141143
fcompile = False
@@ -413,6 +415,7 @@ def __init__(self, mod=None):
413415
self.fcompile = None
414416
self.name = None
415417
self.dev = None
418+
self.export_cc = None
416419
self.cpu_affinity = ""
417420
self.idx = None
418421
self.mod = mod
@@ -601,6 +604,7 @@ def get_config(self):
601604
"target": module.target,
602605
"fcompile": module.fcompile,
603606
"dev": module.dev,
607+
"export_cc": module.export_cc,
604608
}
605609

606610
# Creating a map including pipeline inputs and subgraph inputs.

tests/scripts/task_config_build_gpu.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,5 @@ echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake
4747
echo set\(USE_CCACHE OFF\) >> config.cmake
4848
echo set\(SUMMARIZE ON\) >> config.cmake
4949
echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake
50+
echo set\(USE_PIPELINE_EXECUTOR ON\) >> config.cmake
51+
echo set\(USE_CUTLASS ON\) >> config.cmake

0 commit comments

Comments
 (0)