Skip to content

Commit abaac71

Browse files
committed
introduce a tunable knob to decide if compute_root
1 parent 8343fc0 commit abaac71

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

python/tvm/topi/x86/batch_matmul.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor
2727
from .utils import target_has_vnni
2828
from .dense import dense_vnni_schedule
29+
from .injective import schedule_injective_from_existing
2930

3031

3132
def batch_matmul_vnni_compute(cfg, x, y):
@@ -50,6 +51,7 @@ def batch_matmul_vnni_compute(cfg, x, y):
5051

5152
_, a_y, _ = z.op.axis
5253
cfg.define_split("tile_y", a_y, num_outputs=2)
54+
cfg.define_knob("layout_trans_compute_root", [0, 1])
5355

5456
return z
5557

@@ -65,10 +67,14 @@ def batch_matmul_vnni_schedule(cfg, s, C, O, layout_trans):
6567
fused = s[O].fuse(O.op.axis[0], fused_inner)
6668
s[O].parallel(fused)
6769

68-
s[layout_trans].compute_at(s[O], fused)
69-
_, _, _, ni, ki = s[layout_trans].op.axis
70-
s[layout_trans].vectorize(ki)
71-
s[layout_trans].unroll(ni)
70+
if cfg["layout_trans_compute_root"].val:
71+
s[layout_trans].compute_root()
72+
schedule_injective_from_existing(s, layout_trans)
73+
else:
74+
s[layout_trans].compute_at(s[O], fused)
75+
_, _, _, ni, ki = s[layout_trans].op.axis
76+
s[layout_trans].vectorize(ki)
77+
s[layout_trans].unroll(ni)
7278

7379
return s
7480

0 commit comments

Comments
 (0)