@@ -488,6 +488,19 @@ def cache_read_nested_seq_target(
488488 C [vi , vj ] = A_global [vi , vj ] * T .float32 (2 )
489489
490490
@T.prim_func
def nested_buffer_access(var_A: T.handle, var_B: T.handle, var_C: T.handle):
    # Input workload for the nested-buffer-access cache_read test: the row
    # index used to read A is itself loaded from buffer B.
    #   A: (7, 512) float32, B: (1,) int32, C: (1, 512) float32, all int64-shaped.
    A = T.match_buffer(var_A, (T.int64(7), T.int64(512)), dtype="float32")
    B = T.match_buffer(var_B, (T.int64(1),), dtype="int32")
    C = T.match_buffer(var_C, (T.int64(1), T.int64(512)), dtype="float32")
    for i, j in T.grid(T.int64(1), T.int64(512)):
        with T.block("C"):
            # Both block iterators are spatial and bound 1:1 to the loop variables.
            vi = T.axis.spatial(T.int64(1), i)
            vj = T.axis.spatial(T.int64(512), j)
            # B shows up twice: nested inside A's index expression and as a
            # read region of its own (read region 0 is A, read region 1 is B).
            T.reads(A[B[vi], vj], B[vi])
            T.writes(C[vi, vj])
            C[vi, vj] = A[B[vi], vj]
502+
503+
491504########## Expected function after cache_read ##########
492505
493506
@@ -831,6 +844,26 @@ def cache_inplace_buffer(data_io: T.Buffer(64, "int32")) -> None:
831844 data_io [v0 ] = data_io_global_1 [v0 ]
832845
833846
@T.prim_func
def cache_read_nested_buffer_access(var_A: T.handle, var_B: T.handle, var_C: T.handle):
    # Expected function for the nested-buffer-access cache_read test: B is
    # staged into B_global, and block "C" reads the staged copy everywhere B
    # was used, including inside A's index expression.
    A = T.match_buffer(var_A, (T.int64(7), T.int64(512)), dtype="float32")
    B = T.match_buffer(var_B, (T.int64(1),), dtype="int32")
    C = T.match_buffer(var_C, (T.int64(1), T.int64(512)), dtype="float32")
    B_global = T.alloc_buffer((T.int64(1),), dtype="int32")
    # Copy stage: element-wise copy of B into the cache buffer.
    for i in T.serial(T.int64(1)):
        with T.block("B_global"):
            vi = T.axis.spatial(T.int64(1), i)
            T.reads(B[vi])
            T.writes(B_global[vi])
            B_global[vi] = B[vi]
    # Consumer stage: same computation as the input workload, with every
    # access to B redirected to B_global.
    for i, j in T.grid(T.int64(1), T.int64(512)):
        with T.block("C"):
            vi = T.axis.spatial(T.int64(1), i)
            vj = T.axis.spatial(T.int64(512), j)
            T.reads(A[B_global[vi], vj], B_global[vi])
            T.writes(C[vi, vj])
            C[vi, vj] = A[B_global[vi], vj]
865+
866+
834867########## Expected function after cache_write ##########
835868
836869
@@ -1358,6 +1391,14 @@ def test_cache_read_non_int32_shape(use_block_name):
13581391 verify_trace_roundtrip (sch = sch , mod = elementwise_shape_int64 )
13591392
13601393
def test_cache_read_nested_buffer_access(use_block_name):
    """Check cache_read on a buffer that is also used inside another buffer's index.

    Block "C" of ``nested_buffer_access`` reads ``A[B[i], j]`` and ``B[i]``.
    Caching read buffer 1 in "global" scope must produce
    ``cache_read_nested_buffer_access``, and the recorded trace must
    round-trip on the original workload.
    """
    sch = tir.Schedule(nested_buffer_access, debug_mask="all")
    # Exercise both ways of addressing the block: by name or by handle.
    if use_block_name:
        consumer = "C"
    else:
        consumer = sch.get_block("C")
    sch.cache_read(consumer, 1, "global")
    assert_structural_equal_ignore_global_symbol(
        cache_read_nested_buffer_access, sch.mod["main"]
    )
    verify_trace_roundtrip(sch=sch, mod=nested_buffer_access)
1400+
1401+
13611402def test_cache_read_fail_multi_producer (use_block_name ):
13621403 sch = tir .Schedule (func_multi_producer , debug_mask = "all" )
13631404 block_b = "B" if use_block_name else sch .get_block ("B" )
0 commit comments