
Commit 77b6f0e

[Tensorize][TOPI] Add AMX Tensorizing for int8 batch matmul (#13745)

* amx int8 tensorized x86 bmm
* remove the unused amx schedule
* fix lint
* fix lint
* remove unused import
* fix Instr. assert in testcase

1 parent c2bc1ec commit 77b6f0e

4 files changed: +104 −35 lines changed

python/tvm/relay/op/strategy/x86.py

Lines changed: 3 additions & 7 deletions
@@ -23,9 +23,7 @@
 from tvm.auto_scheduler import is_auto_scheduler_enabled
 from tvm.meta_schedule import is_meta_schedule_enabled
 from tvm.relay.ty import is_dynamic
-from tvm.target import Target
 from tvm.te import SpecializedCondition
-from tvm.topi.x86.utils import target_has_vnni

 from .. import op as _op
 from .generic import *
@@ -618,24 +616,22 @@ def dense_pack_strategy_cpu(attrs, inputs, out_type, target):
 def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
     """batch_matmul x86 strategy"""
     strategy = _op.OpStrategy()
-    mcpu = Target.current().mcpu

     need_auto_scheduler_layout = is_auto_scheduler_enabled()
     need_meta_schedule_layout = is_meta_schedule_enabled()

     if (
         not attrs.transpose_a
         and attrs.transpose_b
-        and target_has_vnni(mcpu)
         and inputs[0].dtype == "uint8"
         and inputs[1].dtype == "int8"
         and inputs[1].shape[-2] % 16 == 0
         and inputs[1].shape[-1] % 4 == 0
     ):
         strategy.add_implementation(
-            wrap_compute_batch_matmul(topi.x86.batch_matmul_vnni_compute, need_out_dtype=True),
-            wrap_topi_schedule(topi.x86.schedule_batch_matmul_vnni),
-            name="batch_matmul_vnni.x86",
+            wrap_compute_batch_matmul(topi.x86.batch_matmul_int8_compute, need_out_dtype=True),
+            wrap_topi_schedule(topi.x86.schedule_batch_matmul_int8),
+            name="batch_matmul_int8.x86",
             plevel=10,
         )
     elif is_dynamic(out_type) or need_auto_scheduler_layout or need_meta_schedule_layout:
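The condition above routes a quantized batch_matmul to the renamed int8 implementation whenever the operands are uint8 x int8 with transpose_b and the weight tiles cleanly (N % 16 == 0, K % 4 == 0). A minimal sketch of a Relay module that takes this path — the shapes and target string are illustrative assumptions, not part of the diff:

# Sketch: a quantized batch_matmul matching the batch_matmul_int8.x86 strategy.
import tvm
from tvm import relay

x = relay.var("x", shape=(8, 32, 128), dtype="uint8")  # (B, M, K)
y = relay.var("y", shape=(8, 64, 128), dtype="int8")   # (B, N, K): N % 16 == 0, K % 4 == 0
bmm = relay.nn.batch_matmul(x, y, out_dtype="int32")   # transpose_b defaults to True
mod = tvm.IRModule.from_expr(bmm)

# On a VNNI- or AMX-capable CPU the int8 implementation wins via plevel=10.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm -mcpu=sapphirerapids")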

python/tvm/topi/x86/batch_matmul.py

Lines changed: 42 additions & 11 deletions
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name,too-many-locals,unused-variable
+# pylint: disable=unused-argument
 """x86 batch_matmul operators"""
 import tvm
 from tvm import autotvm, te
@@ -24,18 +25,24 @@
 from .. import generic, nn
 from ..transform import layout_transform
 from ..utils import get_const_tuple, get_max_power2_factor, traverse_inline
-from .dense import dense_vnni_schedule
+from .dense import dense_vnni_schedule, dense_amx_int8_schedule
 from .injective import schedule_injective_from_existing
+from .utils import target_has_vnni, target_has_amx


 @autotvm.register_topi_compute("batch_matmul_vnni.x86")
-def batch_matmul_vnni_compute(cfg, x, y, *_):
+def batch_matmul_int8_compute(cfg, x, y, *_):
     """Compute for uint8 x int8 -> int32 batch_matmul"""
     batch, m, k = x.shape
     packed_y_layout = "BNK16n4k"
     packed_y = layout_transform(y, "BNK", packed_y_layout)
     _, n_o, _, n_i, _ = packed_y.shape
     ak = te.reduce_axis((0, k), name="k")
+    mcpu = tvm.target.Target.current().mcpu
+    if target_has_vnni(mcpu):
+        attrs_info = {"schedule_rule": "batch_matmul_vnni"}
+    else:
+        attrs_info = None

     z = te.compute(
         (batch, m, n_o * n_i),
@@ -46,14 +53,10 @@ def batch_matmul_vnni_compute(cfg, x, y, *_):
             ),
             axis=ak,
         ),
-        tag="batch_matmul_vnni",
-        attrs={"schedule_rule": "batch_matmul_vnni"},
+        tag="batch_matmul_int8",
+        attrs=attrs_info,
     )

-    _, a_y, _ = z.op.axis
-    cfg.define_split("tile_y", a_y, num_outputs=2)
-    cfg.define_knob("layout_trans_compute_root", [0, 1])
-
     return z

@@ -67,6 +70,7 @@ def batch_matmul_vnni_schedule(cfg, s, C, O, layout_trans):
     # Parallelize over batch
     fused = s[O].fuse(O.op.axis[0], fused_inner)
     s[O].parallel(fused)
+    cfg.define_knob("layout_trans_compute_root", [0, 1])

     if cfg["layout_trans_compute_root"].val:
         s[layout_trans].compute_root()
@@ -80,6 +84,29 @@ def batch_matmul_vnni_schedule(cfg, s, C, O, layout_trans):
     return s


+def batch_matmul_amx_schedule(cfg, s, C, O, layout_trans):
+    """Schedule batch_matmul compute using AMX tdpbusd instruction"""
+    # C: The output of batched GEMM
+    # O: The output of the fused op
+
+    # Schedule the GEMM part
+    s, fused_inner = dense_amx_int8_schedule(cfg, s, C, O, do_parallel=False)
+    # Parallelize over outer loop
+    fused = s[O].fuse(O.op.axis[0], fused_inner)
+    s[O].parallel(fused)
+    cfg.define_knob("layout_trans_compute_root", [0, 1])
+
+    if cfg["layout_trans_compute_root"].val:
+        s[layout_trans].compute_root()
+        schedule_injective_from_existing(s, layout_trans)
+    else:
+        _, _, _, ni, ki = s[layout_trans].op.axis
+        s[layout_trans].vectorize(ki)
+        s[layout_trans].unroll(ni)
+
+    return s
+
+
 @autotvm.register_topi_compute("batch_matmul.x86")
 def batch_matmul(
     cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
@@ -202,14 +229,18 @@ def _callback(op):


 @autotvm.register_topi_schedule("batch_matmul_vnni.x86")
-def schedule_batch_matmul_vnni(cfg, outs):
+def schedule_batch_matmul_int8(cfg, outs):
     """Schedule for batch_matmul_vnni"""
     s = te.create_schedule([x.op for x in outs])
+    mcpu = tvm.target.Target.current().mcpu

     def _callback(op):
-        if "batch_matmul_vnni" in op.tag:
+        if "batch_matmul_int8" in op.tag:
             layout_trans = op.input_tensors[1]
-            batch_matmul_vnni_schedule(cfg, s, op.output(0), outs[0], layout_trans)
+            if target_has_amx(mcpu):
+                batch_matmul_amx_schedule(cfg, s, op.output(0), outs[0], layout_trans)
+            elif target_has_vnni(mcpu):
+                batch_matmul_vnni_schedule(cfg, s, op.output(0), outs[0], layout_trans)

     traverse_inline(s, outs[0].op, _callback)
     return s
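The renamed entry point keeps the single "batch_matmul_vnni.x86" AutoTVM task but now dispatches on the target's mcpu at schedule time: AMX first, then VNNI. A quick sketch of the same predicates the schedule uses — the mcpu strings below are example values, not an exhaustive list:

from tvm.topi.x86.utils import target_has_amx, target_has_vnni

# Inside the schedule, mcpu comes from tvm.target.Target.current().mcpu.
for mcpu in ["sapphirerapids", "cascadelake", "skylake-avx512"]:
    if target_has_amx(mcpu):
        print(mcpu, "-> batch_matmul_amx_schedule")
    elif target_has_vnni(mcpu):
        print(mcpu, "-> batch_matmul_vnni_schedule")
    else:
        print(mcpu, "-> no int8 tensorization")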

python/tvm/topi/x86/dense.py

Lines changed: 4 additions & 17 deletions
@@ -436,7 +436,7 @@ def split_k(out, rd_axis):
         cfg.define_split("tile_k", rd_axis, num_outputs=5, filter=lambda y: y.size[-1] == 128)
         return cfg["tile_k"].apply(s, out, rd_axis)

-    a_x, a_y = C.op.axis
+    a_x, a_y = C.op.axis[-2:]
     (a_k,) = C.op.reduce_axis
     CF = s.cache_write(C, "amx.tmm")

@@ -447,16 +447,16 @@ def split_k(out, rd_axis):
     s[CF].compute_at(s[C], a_yo)

     (a_k_f,) = CF.op.reduce_axis
-    a_x_f, a_y_f = CF.op.axis
+    a_x_f, a_y_f = CF.op.axis[-2:]

     a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32)

     a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32)
     a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_ki_f = split_k(CF, a_k_f)
     s[CF].reorder(a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_xo_f, a_yo_f, a_ki_f, a_xi_f, a_yi_f)

-    (m, k) = CF.op.input_tensors[0].shape
-    (n, c, n_i, c_i) = CF.op.input_tensors[1].shape
+    (m, k) = CF.op.input_tensors[0].shape[-2:]
+    (n, c, n_i, c_i) = CF.op.input_tensors[1].shape[-4:]
     n = n * n_i

     s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=int(k)))
@@ -479,19 +479,6 @@ def split_k(out, rd_axis):
     return s, fused


-@autotvm.register_topi_schedule("dense_amx_int8.x86")
-def schedule_dense_amx_int8(cfg, outs):
-    """Create a schedule for dense_amx_int8"""
-    s = te.create_schedule([x.op for x in outs])
-
-    def _callback(op):
-        if "dense_amx_int8" in op.tag:
-            dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])
-
-    traverse_inline(s, outs[0].op, _callback)
-    return s
-
-
 def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, lib):
     """Compute matmul/dense using a BLAS library"""
     M, K = get_const_tuple(tensor_a.shape)
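Switching from full tuple unpacking to axis[-2:] / shape[-2:] lets dense_amx_int8_schedule serve both 2-D dense outputs and 3-D batch_matmul outputs: only the trailing (m, n) axes are tiled and tensorized, and any leading batch axis is left for the caller (fused and parallelized in batch_matmul_amx_schedule above). A small TE sketch of the invariant, with illustrative names and shapes:

import tvm
from tvm import te

# A 3-D output as produced by batch_matmul: axes (batch, m, n).
A = te.placeholder((4, 32, 64), name="A", dtype="int32")
C = te.compute((4, 32, 64), lambda b, i, j: A[b, i, j] + 1, name="C")

a_x, a_y = C.op.axis[-2:]  # trailing (m, n) pair; also works when C is 2-D
print(a_x.var.name, a_y.var.name)  # prints: i j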

tests/python/relay/test_op_level10.py

Lines changed: 55 additions & 0 deletions
@@ -520,6 +520,61 @@ def test_batch_matmul_vnni(b, m, n, k):
     np.testing.assert_equal(out, ref)


+@pytest.mark.skip("skip due to AMX feature not available yet")
+@pytest.mark.parametrize(
+    "b,m,n,k",
+    [
+        (16, 32, 32, 128),
+        (16, 32, 32, 127),
+        (16, 32, 31, 128),
+    ],
+)
+def test_batch_matmul_amx(b, m, n, k):
+    amx_init = tvm.get_global_func("runtime.amx_init")
+    amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig")
+    assert amx_init()
+    assert amx_tileconfig(16, 64)  # config tile size to 16 rows by 64 columns.
+
+    x_shape = (b, m, k)
+    y_shape = (b, n, k)
+    z_shape = (b, m, n)
+
+    for lhs_dtype in ["uint8", "int8"]:
+        x = relay.var("x", shape=x_shape, dtype=lhs_dtype)
+        y = relay.var("y", shape=y_shape, dtype="int8")
+        z = relay.var("z", shape=z_shape, dtype="int32")
+        bmm = relay.nn.batch_matmul(x, y, out_dtype="int32")
+        out = bmm + z
+        mod = tvm.IRModule.from_expr(out)
+
+        target = "llvm -mcpu=sapphirerapids"
+        with tvm.transform.PassContext(opt_level=3):
+            lib = relay.build(mod, target=target)
+
+        asm = lib.lib.get_source("asm")
+        assert "tilezero" in asm
+        assert "tileloaddt1" in asm
+        assert "tdpbusd" in asm
+        assert "tilestored" in asm
+
+        dev = tvm.device(target, 0)
+        runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+        x_np = np.random.uniform(1, 10, size=x_shape).astype(lhs_dtype)
+        y_np = np.random.uniform(1, 10, size=y_shape).astype("int8")
+        z_np = np.random.uniform(1, 10, size=z_shape).astype("int32")
+
+        runtime.set_input("x", x_np)
+        runtime.set_input("y", y_np)
+        runtime.set_input("z", z_np)
+        runtime.run()
+
+        out = runtime.get_output(0).numpy()
+        ref = tvm.topi.testing.batch_matmul(x_np, y_np, out_dtype="int32") + z_np
+
+        np.testing.assert_equal(out, ref)
+
+
 @pytest.mark.skip("Requires GFX10 AMDGPU")
 def test_batch_matmul_rocm_sdot4():
     x_shape = (16, 32, 96)
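For reference, the oracle in the new test, tvm.topi.testing.batch_matmul with the default transposed-B convention, computes out[b, i, j] = sum_k x[b, i, k] * y[b, j, k] accumulated in int32. A NumPy sketch of an equivalent reference under that convention (not the test's actual helper):

import numpy as np

def batch_matmul_ref(x, y):
    # x: (B, M, K), y: (B, N, K) -> (B, M, N); widen to int32 before the
    # multiply so the int8 products cannot overflow.
    return np.einsum("bmk,bnk->bmn", x.astype("int32"), y.astype("int32"))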
