2626from ..utils import traverse_inline , get_const_tuple , get_max_power2_factor
2727from .utils import target_has_vnni
2828from .dense import dense_vnni_schedule
29+ from .injective import schedule_injective_from_existing
2930
3031
3132def batch_matmul_vnni_compute (cfg , x , y ):
@@ -50,6 +51,7 @@ def batch_matmul_vnni_compute(cfg, x, y):
5051
5152 _ , a_y , _ = z .op .axis
5253 cfg .define_split ("tile_y" , a_y , num_outputs = 2 )
54+ cfg .define_knob ("layout_trans_compute_root" , [0 , 1 ])
5355
5456 return z
5557
@@ -65,10 +67,14 @@ def batch_matmul_vnni_schedule(cfg, s, C, O, layout_trans):
6567 fused = s [O ].fuse (O .op .axis [0 ], fused_inner )
6668 s [O ].parallel (fused )
6769
68- s [layout_trans ].compute_at (s [O ], fused )
69- _ , _ , _ , ni , ki = s [layout_trans ].op .axis
70- s [layout_trans ].vectorize (ki )
71- s [layout_trans ].unroll (ni )
70+ if cfg ["layout_trans_compute_root" ].val :
71+ s [layout_trans ].compute_root ()
72+ schedule_injective_from_existing (s , layout_trans )
73+ else :
74+ s [layout_trans ].compute_at (s [O ], fused )
75+ _ , _ , _ , ni , ki = s [layout_trans ].op .axis
76+ s [layout_trans ].vectorize (ki )
77+ s [layout_trans ].unroll (ni )
7278
7379 return s
7480
0 commit comments