Skip to content

Commit 41a1f04

Browse files
committed
Fix tests and rebase
Change-Id: Iaddeb046bdecb0352a067174f6e6e4be335e94fd
1 parent b9fcab0 commit 41a1f04

File tree

3 files changed

+12
-10
lines changed

(summary repeated above: 3 files changed, +12 −10 lines)

python/tvm/topi/arm_cpu/conv2d.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from tvm.script import tir as T
2525
import tvm.contrib.nnpack
2626
from tvm.tir.schedule.analysis import has_block
27-
from tvm.topi.arm_cpu.matmul import _get_transpose_interleave_intrin_name
2827

2928
from ..utils import traverse_inline, get_const_tuple
3029
from .. import nn
@@ -776,10 +775,6 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
776775
get_transpose_interleave_intrin_name,
777776
)
778777

779-
transpose_interleave_intrin_name = _get_transpose_interleave_intrin_name(
780-
in_dtype, out_dtype
781-
)
782-
783778
# Interleave the padded im2col matrix utilizing the matrix tile
784779
interleave_t_A_block = sch.cache_read(gemm_block, 0, "global")
785780
sch.transform_layout(interleave_t_A_block, ("write", 0), lambda b, m, k: (b, k, m))
@@ -788,7 +783,9 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
788783
ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
789784
sch.parallel(b)
790785
sch.reorder(b, ko, mo, ki, mi)
791-
sch.tensorize(ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, M_padded, K_padded))
786+
sch.tensorize(
787+
ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, M_padded, K_padded)
788+
)
792789

793790
# Interleave the padded weights matrix utilizing the matrix tile
794791
if in_dtype == "float16":
@@ -798,7 +795,9 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
798795
ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
799796
no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True)
800797
sch.reorder(ko, no, ki, ni)
801-
sch.tensorize(ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, M_padded, K_padded))
798+
sch.tensorize(
799+
ki, get_transpose_interleave_intrin_name(in_dtype, out_dtype, M_padded, K_padded)
800+
)
802801

803802
# Split and reorder the loops of the GeMM for tensorization
804803
b, m, n, k = sch.get_loops(gemm_block)
@@ -821,7 +820,7 @@ def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
821820
)
822821
tvm.tir.TensorIntrin.register(
823822
sme_gemm_interleaved_intrin_name,
824-
*get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded, in_dtype),
823+
*get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(M_padded, K_padded, in_dtype),
825824
override=True,
826825
)
827826
sch.tensorize(mi, sme_gemm_interleaved_intrin_name)

python/tvm/topi/arm_cpu/conv2d_gemm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,15 +143,15 @@ def compute_conv2d_gemm_without_weight_transform(
143143
N_padded = N + pad_N
144144

145145
pad_before = (0, 0, 0)
146-
pad_after = (0, 0, 0) if use_sme else (0, pad_M, pad_K)
146+
pad_after = (0, pad_M, pad_K)
147147

148148
if pad_K != 0:
149149
A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_K")
150150
elif pad_M != 0:
151151
A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded_M")
152152

153153
idxm = tvm.tir.indexmod
154-
k = te.reduce_axis((0, K), "k")
154+
k = te.reduce_axis((0, K if use_explicit_predication else K_padded), "k")
155155

156156
# Determine matrix multiplication compute definition
157157
target = Target.current(allow_none=False)

tests/python/topi/test_topi_conv2d_nhwc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,9 @@ def test_conv2d_nhwc_gemm(device, ref_data, dtype, stride, padding, dilation):
176176
if target.features.has_sme and a_np.shape[0] > 1:
177177
pytest.skip(f"Conv2d with batches > 1 targeting SME not implemented.")
178178

179+
if target.features.has_sme and (a_np.shape[3] * w_np.shape[0] * w_np.shape[1]) <= 1:
180+
pytest.skip(f"Conv2d with unit reduction dimension targeting SME not supported.")
181+
179182
# SME schedule always outputs float32 results, regardless of input dtype.
180183
# Otherwise, output dtype is the same as input dtype.
181184
out_dtype = "float32" if target.features.has_sme else dtype

0 commit comments

Comments
 (0)