Skip to content

Commit aa7328f

Browse files
committed
check dot prod availability in batch matmul schedule
1 parent 0d3b16d commit aa7328f

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

python/tvm/topi/cuda/batch_matmul.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,16 @@ def _schedule_batch_matmul_int8(cfg, s, output):
368368
ko, ki = s[batch_matmul_cache].split(ko, factor=4)
369369
ko, kt = cfg["tile_k"].apply(s, batch_matmul_cache, ko)
370370
# dp4a tensorize
371-
s[batch_matmul_cache].tensorize(ki, _dp4a)
371+
372+
target = tvm.target.Target.current(allow_none=False)
373+
do_tensorize = True
374+
375+
if "vulkan" in target.keys:
376+
do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
377+
378+
if do_tensorize:
379+
dtypes = (input_x.dtype, input_y.dtype)
380+
s[batch_matmul_cache].tensorize(ki, dp4a("shared", "shared", "local", dtypes))
372381

373382
# tile axis
374383
f, m, n = batch_matmul_op.axis

tests/python/topi/python/test_topi_batch_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def check_device(device):
128128
f(a, b, c)
129129
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
130130

131-
for device in ["cuda"]:
131+
for device in ["cuda", "vulkan -from_device=0"]:
132132
check_device(device)
133133

134134

0 commit comments

Comments
 (0)