@@ -114,5 +114,74 @@ def test_bf16_storage_legalize():
114114 tvm .ir .assert_structural_equal (after , expected )
115115
116116
117+ def test_bf16_storage_scope ():
118+ def get_before ():
119+ @tvm .script .ir_module
120+ class Before :
121+ @T .prim_func
122+ def main (
123+ Aptr : T .handle ("bfloat16" , storage_scope = "shared" ),
124+ Bptr : T .handle ("bfloat16" , storage_scope = "local" ),
125+ Dptr : T .handle ("bfloat16" ),
126+ ):
127+ T .func_attr ({"global_symbol" : "main" })
128+ A = T .decl_buffer ((100 ,), "bfloat16" , data = Aptr )
129+ B = T .decl_buffer ((100 ,), "bfloat16" , data = Bptr )
130+ D = T .decl_buffer ((100 ,), "bfloat16" , data = Dptr )
131+ C = T .decl_buffer ((100 ,), "bfloat16" )
132+ for i in T .grid (100 ):
133+ C [i ] = A [i ] + B [i ]
134+ D [i ] = T .exp (C [i ])
135+
136+ return Before
137+
138+ def after_compute_legalize ():
139+ @tvm .script .ir_module
140+ class After :
141+ @T .prim_func
142+ def main (
143+ Aptr : T .handle ("bfloat16" , storage_scope = "shared" ),
144+ Bptr : T .handle ("bfloat16" , storage_scope = "local" ),
145+ Dptr : T .handle ("bfloat16" ),
146+ ):
147+ T .func_attr ({"global_symbol" : "main" })
148+ A = T .decl_buffer ((100 ,), "bfloat16" , data = Aptr )
149+ B = T .decl_buffer ((100 ,), "bfloat16" , data = Bptr )
150+ D = T .decl_buffer ((100 ,), "bfloat16" , data = Dptr )
151+ C = T .decl_buffer ((100 ,), "float32" )
152+ for i in T .grid (100 ):
153+ C [i ] = bf16tof32 (A [i ]) + bf16tof32 (B [i ])
154+ D [i ] = f32tobf16 (T .exp (C [i ]))
155+
156+ return After
157+
158+ def after_storage_legalize ():
159+ @tvm .script .ir_module
160+ class After :
161+ @T .prim_func
162+ def main (
163+ Aptr : T .handle ("uint16" , storage_scope = "shared" ),
164+ Bptr : T .handle ("uint16" , storage_scope = "local" ),
165+ Dptr : T .handle ("uint16" ),
166+ ):
167+ T .func_attr ({"global_symbol" : "main" })
168+ A = T .decl_buffer ((100 ,), "uint16" , data = Aptr )
169+ B = T .decl_buffer ((100 ,), "uint16" , data = Bptr )
170+ D = T .decl_buffer ((100 ,), "uint16" , data = Dptr )
171+ C = T .decl_buffer ((100 ,), "float32" )
172+ for i in T .grid (100 ):
173+ C [i ] = u16tof32 (A [i ]) + u16tof32 (B [i ])
174+ D [i ] = f32tou16 (T .exp (C [i ]))
175+
176+ return After
177+
178+ before = get_before ()
179+ after_compute = tvm .tir .transform .BF16ComputeLegalize ()(before )
180+ after_storage = tvm .tir .transform .BF16StorageLegalize ()(after_compute )
181+ tvm .ir .assert_structural_equal (after_compute , after_compute_legalize ())
182+ tvm .ir .assert_structural_equal (after_storage , after_storage_legalize ())
183+
184+
117185if __name__ == "__main__" :
118186 test_bf16_storage_legalize ()
187+ test_bf16_storage_scope ()
0 commit comments