Commit 63efbad

use cutlass with byoc

1 parent 4c80999 commit 63efbad

gallery/how_to/work_with_relay/using_with_pipeline_executor.py

Lines changed: 33 additions & 22 deletions
@@ -25,11 +25,19 @@
 from tvm import te
 import numpy as np
 from tvm.contrib import graph_executor as runtime
+from tvm.relay.op.contrib.cutlass import partition_for_cutlass
 from tvm import relay
 from tvm.relay import testing
 import tvm.testing
 import time
-
+from tvm.contrib.cutlass import (
+    has_cutlass,
+    num_cutlass_partitions,
+    finalize_modules,
+    finalize_modules_vm,
+)
+
+img_size = 8
 #######################################################################
 # Create a simple network, this network can be a pre-trained model too.
 # ---------------------------------------------------------------------
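The new imports bring in the CUTLASS BYOC helpers used later in the diff. A minimal sketch of an availability guard built from them (this guard is not part of the commit; it assumes has_cutlass(), as imported above, reports whether the CUTLASS codegen was compiled in):

# Sketch, not in this commit: bail out early when the CUTLASS codegen
# or a CUDA device is unavailable, using the helpers imported above.
import sys

if not has_cutlass() or not tvm.cuda(0).exist:
    sys.exit("CUTLASS codegen or CUDA device not available; skipping this example.")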
@@ -38,7 +46,10 @@
 def get_network():
     out_channels = 16
     batch_size = 1
-    data = relay.var("data", relay.TensorType((batch_size, 3, 224, 224), "float32"))
+    data = relay.var("data", relay.TensorType((batch_size, 3, img_size, img_size), "float32"))
+    dense_weight = relay.var(
+        "data", relay.TensorType((batch_size, 16 * img_size * img_size), "float32")
+    )
     weight = relay.var("weight")
     second_weight = relay.var("second_weight")
     bn_gamma = relay.var("bn_gamma")
@@ -50,15 +61,10 @@ def get_network():
     )
     simple_net = relay.nn.batch_norm(simple_net, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0]
     simple_net = relay.nn.relu(simple_net)
-    simple_net = relay.nn.conv2d(
-        data=simple_net,
-        weight=second_weight,
-        kernel_size=(3, 3),
-        channels=out_channels,
-        padding=(1, 1),
-    )
+    simple_net = relay.nn.batch_flatten(simple_net)
+    simple_net = relay.nn.dense(simple_net, dense_weight)
     simple_net = relay.Function(relay.analysis.free_vars(simple_net), simple_net)
-    data_shape = (batch_size, 3, 224, 224)
+    data_shape = (batch_size, 3, img_size, img_size)
     net, params = testing.create_workload(simple_net)
     return net, params, data_shape
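Shape check for the new tail of the network: with img_size = 8, the relu output entering batch_flatten is (1, 16, 8, 8), which flattens to (1, 16 * 8 * 8) = (1, 1024); since nn.dense treats its weight as (units, input_dim), the (1, 1024) dense_weight produces a (1, 1) result. A quick way to see the inferred types (a sketch, not in the commit):

# Sketch: print the patched workload; create_workload runs type
# inference, so the dumped Relay text carries the inferred shapes.
net, params, data_shape = get_network()
print(net)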

@@ -86,19 +92,19 @@ def get_network():
 """
 #subgraphs[0])

-def @main(%data: Tensor[(1, 3, 224, 224), float32]) {
-  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
-  %1 = nn.batch_norm(%0, meta[relay.Constant][1] /* ty=Tensor[(16), float32] */, meta[relay.Constant][2] /* ty=Tensor[(16), float32] */, meta[relay.Constant][3] /* ty=Tensor[(16), float32] */, meta[relay.Constant][4] /* ty=Tensor[(16), float32] */) /* ty=(Tensor[(1, 16, 224, 224), float32], Tensor[(16), float32], Tensor[(16), float32]) */;
+def @main(%data: Tensor[(1, 3, img_size, img_size), float32]) {
+  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, img_size, img_size), float32] */;
+  %1 = nn.batch_norm(%0, meta[relay.Constant][1] /* ty=Tensor[(16), float32] */, meta[relay.Constant][2] /* ty=Tensor[(16), float32] */, meta[relay.Constant][3] /* ty=Tensor[(16), float32] */, meta[relay.Constant][4] /* ty=Tensor[(16), float32] */) /* ty=(Tensor[(1, 16, img_size, img_size), float32], Tensor[(16), float32], Tensor[(16), float32]) */;
   %2 = %1.0;
-  nn.relu(%2) /* ty=Tensor[(1, 16, 224, 224), float32] */
+  nn.relu(%2) /* ty=Tensor[(1, 16, img_size, img_size), float32] */
 }

 peline-tutorial

 #subgraphs[1]

-def @main(%data_n_0: Tensor[(1, 16, 224, 224), float32]) {
-  nn.conv2d(%data_n_0, meta[relay.Constant][0] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */
+def @main(%data_n_0: Tensor[(1, 16, img_size, img_size), float32]) {
+  nn.conv2d(%data_n_0, meta[relay.Constant][0] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, img_size, img_size), float32] */
 }
 """

@@ -123,9 +129,11 @@ def run_pipeline():
     # Using BYOC to set the codegen of the second subgraph module.
     # To use dnnl the 'USE_DNNL_CODEGEN' should set as ON in config.cmake and installing MKL-DNN.
     mod0, mod1 = subgraphs[0], subgraphs[1]
-    mod0 = relay.transform.AnnotateTarget(["dnnl"])(mod0)
-    mod0 = relay.transform.MergeCompilerRegions()(mod0)
-    mod0 = relay.transform.PartitionGraph()(mod0)
+    # mod0 = relay.transform.AnnotateTarget(["dnnl"])(mod0)
+    # mod0 = relay.transform.AnnotateTarget(["cutlass"])(mod0)
+    # mod0 = relay.transform.MergeCompilerRegions()(mod0)
+    # mod0 = relay.transform.PartitionGraph()(mod0)
+    mod1 = partition_for_cutlass(mod1)
     #################################################
     # Get the pipeline executor configuration object.
     pipe_config = pipeline_executor_build.PipelineConfig()
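partition_for_cutlass rewrites mod1 so that operators CUTLASS can handle are grouped into external functions for the CUTLASS codegen. One way to confirm the pass actually offloaded something, using the num_cutlass_partitions helper imported earlier (a sketch, not in the commit):

# Sketch: count the subgraphs offloaded to CUTLASS after partitioning.
mod1 = partition_for_cutlass(mod1)
print("CUTLASS partitions:", num_cutlass_partitions(mod1))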
@@ -138,8 +146,8 @@ def run_pipeline():
     pipe_config[mod1].cpu_affinity = "0"
     ##############################################################
     # Set the compile target of the second subgraph module as LLVM.
-    pipe_config[mod1].target = "llvm"
-    pipe_config[mod1].dev = tvm.cpu(0)
+    pipe_config[mod1].target = "cuda"
+    pipe_config[mod1].dev = tvm.device("cuda", 0)
     #################################################################################
     # Set the cpu afinity for control flow, for example using cpu 1 for control flow.
     pipe_config[mod1].cpu_affinity = "1"
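Note that the comment above the changed lines still says "as LLVM" even though the target is now CUDA; the commit leaves it stale. A hedged variant that keeps the old CPU path as a fallback when no GPU is present might look like this (a sketch, not in the commit):

# Sketch: fall back to the original LLVM/CPU configuration when CUDA
# is unavailable.
if tvm.cuda(0).exist:
    pipe_config[mod1].target = "cuda"
    pipe_config[mod1].dev = tvm.device("cuda", 0)
else:
    pipe_config[mod1].target = "llvm"
    pipe_config[mod1].dev = tvm.cpu(0)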
@@ -208,11 +216,14 @@ def run_pipeline():
     module1 = runtime.GraphModule(lib1["default"](dev))
     module0.set_input("data", data)
     module0.run()
-    out_shape = (1, 16, 224, 224)
+    out_shape = (1, 16, img_size, img_size)
     out = module0.get_output(0, tvm.nd.empty(out_shape))
     module1.set_input("data_n_0", out)
     module1.run()
     out = module1.get_output(0, tvm.nd.empty(out_shape))
     ####################
     # Verify the result.
     tvm.testing.assert_allclose(outputs[0].numpy(), out.numpy())
+
+
+run_pipeline()
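finalize_modules and finalize_modules_vm are imported at the top of the file but never called in the hunks shown; with CUTLASS BYOC a library built by relay.build normally has to be finalized so that the generated CUTLASS kernels are compiled and linked before execution. A sketch under that assumption (the lib_path and tmp_dir values are illustrative):

# Sketch, not in this diff: finalize a CUTLASS-partitioned build
# before creating a GraphModule from it.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod1, target="cuda", params=params)
lib = finalize_modules(lib, lib_path="compile.so", tmp_dir="./tmp")  # illustrative paths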
