 # under the License.
 # pylint: disable=missing-docstring
 import logging
 import tempfile
 from typing import List
-
+from os import path as osp
 import numpy as np
 import pytest
 import tvm
 from tvm import relay
 from tvm.contrib import graph_executor
 from tvm.ir import IRModule
+from tvm.tir.schedule.schedule import Schedule
+from tvm.tir.schedule.trace import Trace
 from tvm.meta_schedule import ReplayTraceConfig
-from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
+from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload, JSONDatabase
+from tvm.meta_schedule.integration import ApplyHistoryBest
 from tvm.meta_schedule.testing.relay_workload import get_network
 from tvm.meta_schedule.tune import tune_relay
 from tvm.meta_schedule.utils import derived_object
 from tvm.target.target import Target
+from tvm.script import tir as T
 
 logging.basicConfig()
 logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
 
+# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
+# fmt: off
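+# The three IRModules below capture the TIR subgraphs expected from lowering the
+# small conv2d network defined in the tests further down; the tests commit them
+# to their databases so that workload lookup during relay.build can succeed.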
+@tvm.script.ir_module
+class tvmgen_default_fused_layout_transform:
+    @T.prim_func
+    def main(
+        placeholder: T.Buffer[(1, 3, 16, 16), "float32"],
+        T_layout_trans: T.Buffer[(1, 1, 16, 16, 3), "float32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        for i0, i1, i2, i3, i4 in T.grid(1, 1, 16, 16, 3):
+            with T.block("T_layout_trans"):
+                ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
+                T.reads(placeholder[ax0, ax1 * 3 + ax4, ax2, ax3])
+                T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
+                T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(
+                    ax0 < 1 and ax1 * 3 + ax4 < 3 and ax2 < 16 and ax3 < 16,
+                    placeholder[ax0, ax1 * 3 + ax4, ax2, ax3],
+                    T.float32(0),
+                    dtype="float32",
+                )
 
-@derived_object
-class DummyDatabase(PyDatabase):
-    def __init__(self):
-        super().__init__()
-        self.records = []
-        self.workload_reg = []
-
-    def has_workload(self, mod: IRModule) -> Workload:
-        for workload in self.workload_reg:
-            if tvm.ir.structural_equal(workload.mod, mod):
-                return True
-        return False
-
-    def commit_tuning_record(self, record: TuningRecord) -> None:
-        self.records.append(record)
-
-    def commit_workload(self, mod: IRModule) -> Workload:
-        for workload in self.workload_reg:
-            if tvm.ir.structural_equal(workload.mod, mod):
-                return workload
-        workload = Workload(mod)
-        self.workload_reg.append(workload)
-        return workload
-
-    def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
-        return list(
-            filter(
-                lambda x: x.workload == workload,
-                sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
-            )
-        )[: int(top_k)]
 
-    def __len__(self) -> int:
-        return len(self.records)
+@tvm.script.ir_module
+class tvmgen_default_fused_nn_contrib_conv2d_NCHWc:
+    @T.prim_func
+    def main(placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"], placeholder_1: T.Buffer[(2, 1, 5, 5, 3, 4), "float32"], conv2d_NCHWc: T.Buffer[(1, 2, 16, 16, 4), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        data_pad = T.alloc_buffer([1, 1, 20, 20, 3], dtype="float32")
+        for i0, i1, i2, i3, i4 in T.grid(1, 1, 20, 20, 3):
+            with T.block("data_pad"):
+                i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
+                T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1])
+                T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1])
+                data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18 and 2 <= i3_1 and i3_1 < 18, placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1], T.float32(0), dtype="float32")
+        for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5):
+            with T.block("conv2d_NCHWc"):
+                n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap("SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7])
+                T.reads(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block], data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3], placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block])
+                T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block])
+                T.block_attr({"workload": ["conv2d_NCHWc.x86", ["TENSOR", [1, 1, 16, 16, 3], "float32"], ["TENSOR", [2, 1, 5, 5, 3, 4], "float32"], [1, 1], [2, 2, 2, 2], [1, 1], "NCHW3c", "NCHW4c", "float32"]})
+                with T.init():
+                    conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0)
+                conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] + data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3, oc_block]
+
+@tvm.script.ir_module
+class tvmgen_default_fused_layout_transform_1:
+    @T.prim_func
+    def main(placeholder: T.Buffer[(1, 2, 16, 16, 4), "float32"], T_layout_trans: T.Buffer[(1, 8, 16, 16), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        for i0, i1, i2, i3 in T.grid(1, 8, 16, 16):
+            with T.block("T_layout_trans"):
+                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4])
+                T.writes(T_layout_trans[ax0, ax1, ax2, ax3])
+                T_layout_trans[ax0, ax1, ax2, ax3] = T.if_then_else(ax0 < 1 and ax1 < 8 and ax2 < 16 and ax3 < 16, placeholder[ax0, ax1 // 4, ax2, ax3, ax1 % 4], T.float32(0), dtype="float32")
 
-    def print_results(self) -> None:
-        print("\n".join([str(r) for r in self.records]))
+# fmt: on
+# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
 
 
 @pytest.mark.skip("Integration test")
@@ -101,8 +136,7 @@ def test_meta_schedule_tune_relay(
     mod, params, (input_name, _, _) = get_network(name=model_name, input_shape=input_shape)
     target = Target(target)
     with tempfile.TemporaryDirectory() as work_dir:
-        database = DummyDatabase()
-        rt_mod: tvm.runtime.Module = tune_relay(
+        rt_mod1: tvm.runtime.Module = tune_relay(
             mod=mod,
             params=params,
             target=target,
@@ -111,11 +145,173 @@ def test_meta_schedule_tune_relay(
                 num_trials_total=32,
             ),
             work_dir=work_dir,
-            database=database,
+            database=JSONDatabase(
+                osp.join(work_dir, "workload.json"), osp.join(work_dir, "records.json")
+            ),
+        )
+        # Compile without meta-scheduler for correctness check
+        with tvm.transform.PassContext(opt_level=0):
+            rt_mod2 = relay.build(mod, target=target, params=params)
+
+        def get_output(data, lib):
+            module = graph_executor.GraphModule(lib["default"](dev))
+            module.set_input(input_name, data)
+            module.run()
+            return module.get_output(0).numpy()
+
+        # Check correctness
+        actual_output = get_output(data, rt_mod1)
+        expected_output = get_output(data, rt_mod2)
+        assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
+
+
+def test_meta_schedule_te2primfunc_argument_order():
+    @derived_object
+    class TestDummyDatabase(PyDatabase):
+        def __init__(self):
+            super().__init__()
+            self.records = []
+            self.workload_reg = []
+
+        def has_workload(self, mod: IRModule) -> bool:
+            for workload in self.workload_reg:
+                if tvm.ir.structural_equal(workload.mod, mod):
+                    return True
+            # All expected workloads were committed up front, so a miss means
+            # incorrect TIR was generated from the TE subgraph.
+            raise ValueError(
+                "The workload searched for is not in given database!"
+                + " Incorrect TIR was generated from TE subgraph."
+            )
+
+        def commit_tuning_record(self, record: TuningRecord) -> None:
+            self.records.append(record)
+
+        def commit_workload(self, mod: IRModule) -> Workload:
+            for workload in self.workload_reg:
+                if tvm.ir.structural_equal(workload.mod, mod):
+                    return workload
+            workload = Workload(mod)
+            self.workload_reg.append(workload)
+            return workload
+
+        def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
+            return list(
+                filter(
+                    lambda x: x.workload == workload,
+                    sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
+                )
+            )[: int(top_k)]
+
+        def __len__(self) -> int:
+            return len(self.records)
+
+        def print_results(self) -> None:
+            print("\n".join([str(r) for r in self.records]))
+
+    data_shape = (1, 3, 16, 16)
+    weight_shape = (8, 3, 5, 5)
+    data = relay.var("data", relay.TensorType(data_shape, "float32"))
+    weight = relay.var("weight", relay.TensorType(weight_shape, "float32"))
+    y = relay.nn.conv2d(
+        data,
+        weight,
+        padding=(2, 2),
+        kernel_size=(5, 5),
+        kernel_layout="OIHW",
+        out_dtype="float32",
+    )
+    f = relay.Function([data, weight], y)
+    mod = tvm.IRModule.from_expr(f)
+    mod = relay.transform.InferType()(mod)
+
+    data_sample = np.random.rand(*data_shape).astype("float32")
+    weight_sample = np.random.rand(*weight_shape).astype("float32")
+    params = {mod["main"].params[1].name_hint: weight_sample}
+
+    input_name = "data"
+    dev = tvm.cpu()
+    target = Target("llvm --num-cores=16")
+    data = tvm.nd.array(data_sample, dev)
+
+    database = TestDummyDatabase()
+    database.commit_workload(tvmgen_default_fused_layout_transform)
+    database.commit_workload(tvmgen_default_fused_layout_transform_1)
+    database.commit_workload(tvmgen_default_fused_nn_contrib_conv2d_NCHWc)
+
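+    # ApplyHistoryBest points task lookups during relay.build at the database
+    # above; combined with the strict has_workload, any mismatch in the TIR
+    # generated from the TE subgraphs raises a ValueError instead of passing
+    # silently.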
+    with ApplyHistoryBest(database):
+        with tvm.transform.PassContext(
+            opt_level=3,
+            config={"relay.backend.use_meta_schedule": True},
+        ):
+            rt_mod1 = relay.build(mod, target=target, params=params)
+
+    # Compile without meta-scheduler for correctness check
+    with tvm.transform.PassContext(opt_level=0):
+        rt_mod2 = relay.build(mod, target=target, params=params)
+
+    def get_output(data, lib):
+        module = graph_executor.GraphModule(lib["default"](dev))
+        module.set_input(input_name, data)
+        module.run()
+        return module.get_output(0).numpy()
+
+    # Check correctness
+    actual_output = get_output(data, rt_mod1)
+    expected_output = get_output(data, rt_mod2)
+    assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
+
+
+def test_meta_schedule_relay_lowering():
+    data_shape = (1, 3, 16, 16)
+    weight_shape = (8, 3, 5, 5)
+    data = relay.var("data", relay.TensorType(data_shape, "float32"))
+    weight = relay.var("weight", relay.TensorType(weight_shape, "float32"))
+    y = relay.nn.conv2d(
+        data,
+        weight,
+        padding=(2, 2),
+        kernel_size=(5, 5),
+        kernel_layout="OIHW",
+        out_dtype="float32",
+    )
+    f = relay.Function([data, weight], y)
+    mod = tvm.IRModule.from_expr(f)
+    mod = relay.transform.InferType()(mod)
+
+    data_sample = np.random.rand(*data_shape).astype("float32")
+    weight_sample = np.random.rand(*weight_shape).astype("float32")
+    params = {mod["main"].params[1].name_hint: weight_sample}
+
+    input_name = "data"
+    dev = tvm.cpu()
+    target = Target("llvm --num-cores=16")
+    data = tvm.nd.array(data_sample, dev)
+
+    with tempfile.TemporaryDirectory() as work_dir:
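+        # JSONDatabase persists committed workloads and tuning records to the
+        # two JSON files created below inside the temporary work_dir.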
+        database = JSONDatabase(
+            osp.join(work_dir, "workload.json"), osp.join(work_dir, "records.json")
         )
+
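+        # A single tuning record with an empty trace and a dummy run time is
+        # enough for ApplyHistoryBest to pick up this workload during lowering.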
+        database.commit_tuning_record(
+            TuningRecord(
+                Trace([], {}),
+                [0.0],
+                database.commit_workload(tvmgen_default_fused_nn_contrib_conv2d_NCHWc),
+                target=target,
+                args_info=[],
+            )
+        )
+
+        with ApplyHistoryBest(database):
+            with tvm.transform.PassContext(
+                opt_level=3,
+                config={"relay.backend.use_meta_schedule": True},
+            ):
+                rt_mod1 = relay.build(mod, target=target, params=params)
+
         # Compile without meta-scheduler for correctness check
         with tvm.transform.PassContext(opt_level=0):
-            rt_mod2 = relay.build(mod, target=Target("llvm"), params=params)
+            rt_mod2 = relay.build(mod, target=target, params=params)
 
         def get_output(data, lib):
             module = graph_executor.GraphModule(lib["default"](dev))
@@ -124,8 +320,8 @@ def get_output(data, lib):
             return module.get_output(0).numpy()
 
         # Check correctness
-        actual_output = get_output(data, rt_mod)
-        expected_output = get_output(tvm.nd.array(data.numpy(), device=tvm.cpu()), rt_mod2)
+        actual_output = get_output(data, rt_mod1)
+        expected_output = get_output(data, rt_mod2)
         assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
 
 
@@ -136,3 +332,5 @@ def get_output(data, lib):
     test_meta_schedule_tune_relay("mobilenet_v2", [1, 3, 224, 224], "nvidia/geforce-rtx-3070")
     test_meta_schedule_tune_relay("bert_base", [1, 64], "llvm --num-cores=16")
     test_meta_schedule_tune_relay("bert_base", [1, 64], "nvidia/geforce-rtx-3070")
+    test_meta_schedule_te2primfunc_argument_order()
+    test_meta_schedule_relay_lowering()