Skip to content

Commit a7c0e72

Browse files
committed
check transpose condition
1 parent abaac71 commit a7c0e72

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

python/tvm/topi/x86/batch_matmul.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,9 @@ def batch_matmul(
120 120   mcpu = tvm.target.Target.current().mcpu
121 121
122 122   if (
123     - target_has_vnni(mcpu)
    123 + transpose_a == False
    124 + and transpose_b == True
    125 + and target_has_vnni(mcpu)
124 126   and tensor_a.dtype == "uint8"
125 127   and tensor_b.dtype == "int8"
126 128   and tensor_b.shape[-2] % 16 == 0

tests/python/relay/test_op_level10.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,9 @@ def test_batch_matmul_vnni():
408 408   with tvm.transform.PassContext(opt_level=3):
409 409   lib = relay.build(mod, target=target)
410 410
    411 + asm = lib.lib.get_source("asm")
    412 + assert "vpdpbusd" in asm
    413 +
411 414   dev = tvm.device(target, 0)
412 415   runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
413 416

0 commit comments

Comments
 (0)