Commit b49468d

[SME] Introduce scalable fp32 dense schedule (#16921)
This commit adds a new scalable fp32 dense schedule that calls SME intrinsics in accordance with the SME RFC: apache/tvm-rfcs#107. The schedule does not currently make use of predication, so the output of the matmul compute must be copied in a subsequent compute stage; this extra stage will be removed once support for predication is added.
1 parent cfe1711 commit b49468d
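
For orientation, here is a minimal sketch of how the new schedule could be exercised from Relay. The target string, shapes, and use of relay.nn.dense are illustrative assumptions, not taken from this commit, and codegen requires an LLVM build with SME support:

import tvm
from tvm import relay

# Assumed target: AArch64 with SME enabled via -mattr, so that
# target.features.has_sme is true and the SME strategy below can be chosen.
target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sme")

# fp32 dense with a reduction axis (K=64) larger than 1, matching the
# guard conditions in the strategy changes below.
data = relay.var("data", shape=(32, 64), dtype="float32")
weight = relay.var("weight", shape=(16, 64), dtype="float32")
func = relay.Function([data, weight], relay.nn.dense(data, weight))

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(tvm.IRModule.from_expr(func), target=target)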

File tree

24 files changed: +1127 −122 lines


python/tvm/micro/testing/aot_test_utils.py

Lines changed: 10 additions & 0 deletions
@@ -65,6 +65,16 @@
     },
 )

+AOT_APROFILE_AEM_RUNNER = AOTTestRunner(
+    makefile="aprofile_aem",
+    includes=[],
+    pass_config={
+        "tir.usmp.enable": False,
+        # AOT test infra generates 'fake' tensor inputs which fail asserts
+        "tir.disable_assert": True,
+    },
+)
+

 def parametrize_aot_options(test):
     """Parametrize over valid option combinations"""

python/tvm/relay/op/strategy/arm_cpu.py

Lines changed: 68 additions & 1 deletion
@@ -21,7 +21,9 @@
 # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
 import re

+import tvm
 from tvm import relay, topi, tir
+from tvm.tir.schedule.analysis import has_block

 from ....auto_scheduler import is_auto_scheduler_enabled
 from ....meta_schedule import is_meta_schedule_enabled
@@ -639,7 +641,7 @@ def schedule_bitserial_dense_arm_cpu(attrs, inputs, out_type, target):
 def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
     """dense arm cpu strategy"""
     strategy = _op.OpStrategy()
-    data, _ = inputs
+    data, weight = inputs

     if target.features.has_dsp and data.dtype in ["int8", "int16"]:
         strategy.add_implementation(
@@ -680,6 +682,23 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
             plevel=11,
         )

+    if (
+        target.features.has_sme
+        and data.dtype in ["float32"]
+        and weight.dtype in ["float32"]
+        and out_type.dtype in ["float32"]
+        # The schedule uses tensorization which does not work when the
+        # reduction axis has unit iters. See
+        # https://github.com/apache/tvm/issues/16566
+        and data.shape[1] > 1
+    ):
+        strategy.add_implementation(
+            wrap_compute_dense(topi.arm_cpu.compute_matmul_sme),
+            lambda: None,
+            name="matmul.arm_cpu.sme",
+            plevel=12,
+        )
+
     # Fallback to x86 schedules as there is currently no arm_cpu schedule for dense
     strategy.add_implementation(
         wrap_compute_dense(topi.x86.dense_nopack),
@@ -697,6 +716,40 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
     return strategy


+@matmul_strategy.register("arm_cpu")
+def matmul_strategy_arm_cpu(attrs, inputs, out_type, target):
+    """matmul arm cpu strategy"""
+    strategy = _op.OpStrategy()
+    data, weight = inputs
+
+    if (
+        target.features.has_sme
+        and data.dtype in ["float32"]
+        and weight.dtype in ["float32"]
+        and out_type.dtype in ["float32"]
+        and not (attrs.transpose_a or attrs.transpose_b)
+        and len(data.shape) == 2
+        # The schedule uses tensorization which does not work when the
+        # reduction axis has unit iters. See
+        # https://github.com/apache/tvm/issues/16566
+        and data.shape[1] > 1
+    ):
+        # Ideally we should check that weight is a Relay constant, but strategy functions
+        # don't have access to the data needed to check this.
+        strategy.add_implementation(
+            wrap_compute_matmul(topi.arm_cpu.compute_matmul_sme),
+            lambda: None,
+            name="matmul.arm_cpu.sme",
+        )
+        return strategy
+
+    logger.warning("matmul is not optimized for arm cpu.")
+    strategy.add_implementation(
+        wrap_compute_matmul(topi.nn.matmul), naive_schedule, name="matmul.generic"
+    )
+    return strategy
+
+
 @conv1d_strategy.register("arm_cpu")
 def conv1d_strategy_arm_cpu(attrs, inputs, out_type, target):
     """conv1d strategy"""
@@ -737,3 +790,17 @@ def conv1d_strategy_arm_cpu(attrs, inputs, out_type, target):
             f"Unsupported kernel layout {kernel_layout} for conv1d {layout} for arm cpu."
         )
     return strategy
+
+
+def arm_cpu_tir_strategy(sch: tir.Schedule) -> bool:
+    """
+    Strategy for arm_cpu STIR schedules.
+    """
+    current_target = tvm.target.Target.current()
+
+    if current_target.features.has_sme and has_block(sch, "matmul_sme_gemm"):
+        topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
+        return True
+
+    # Fallback to TE schedule for operators we have not written a special TIR schedule for
+    return False
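
The arm_cpu_tir_strategy hook keys off has_block to detect the SME matmul block in a schedule, and the SME dense implementation registers at plevel=12 so that, when its guards hold, it is preferred over the lower-plevel implementations. Below is a toy, self-contained sketch of the has_block check; the prim_func and block name are invented for illustration:

import tvm
from tvm.script import tir as T
from tvm.tir.schedule.analysis import has_block

# Invented prim_func purely to demonstrate the has_block query used above.
@T.prim_func
def add_one(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
    for i in range(8):
        with T.block("my_block"):
            vi = T.axis.spatial(8, i)
            B[vi] = A[vi] + T.float32(1)

sch = tvm.tir.Schedule(add_one)
assert has_block(sch, "my_block")  # block present: this path would fire
assert not has_block(sch, "matmul_sme_gemm")  # absent: fall back to TE scheduling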

python/tvm/testing/utils.py

Lines changed: 17 additions & 0 deletions
@@ -1023,6 +1023,19 @@ def _corstone300_compile_time_check():
     parent_features="cmsisnn",
 )

+
+def _aprofile_aem_fvp_compile_time_check():
+    if shutil.which("FVP_Base_RevC-2xAEMvA") is None:
+        return "AProfile AEM is not available"
+    return True
+
+
+requires_aprofile_aem_fvp = Feature(
+    "aprofile-aem-fvp",
+    "AProfile AEM FVP",
+    compile_time_check=_aprofile_aem_fvp_compile_time_check,
+)
+
 # Mark a test as requiring Vitis AI to run
 requires_vitis_ai = Feature("vitis_ai", "Vitis AI", cmake_flag="USE_VITIS_AI")

@@ -1205,6 +1218,10 @@ def decorator(*args):
     return decorator


+def skip_if_no_reference_system(func):
+    return skip_if_32bit(reason="Reference system unavailable in i386 container")(func)
+
+
 def requires_package(*packages):
     """Mark a test as requiring python packages to run.
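
A hypothetical test using the new feature marker (the test body is invented); like the other Feature markers in this file, it skips the test when the FVP binary is not on PATH:

import tvm.testing

@tvm.testing.requires_aprofile_aem_fvp
def test_dense_fp32_on_aem_fvp():
    # Would compile an fp32 dense workload with the SME schedule and run
    # it on the FVP_Base_RevC-2xAEMvA model using AOT_APROFILE_AEM_RUNNER.
    ...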

python/tvm/tir/tensor_intrin/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -16,4 +16,3 @@
 # under the License.
 # pylint: disable=unused-import
 """Intrinsics for tensorization."""
-from . import arm_cpu, cuda, rocm, x86, hexagon
