
Commit c4e6f96

Author: Eirene Pandi

[TOPI] Add dense schedule for fp16 and fp32 using gemm (#17091)

Add a new schedule for the dense operator based on the gemm algorithm.

1 parent 0fc047c, commit c4e6f96

12 files changed: +342 additions, -28 deletions

python/tvm/relay/op/strategy/arm_cpu.py

Lines changed: 25 additions & 0 deletions
@@ -736,6 +736,18 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
             plevel=12,
         )
 
+    if (
+        target.features.is_aarch64
+        and data.dtype in ["float16", "float32"]
+        and weight.dtype in ["float16", "float32"]
+        and out_type.dtype in ["float16", "float32"]
+    ):
+        strategy.add_implementation(
+            wrap_compute_dense(topi.arm_cpu.dense_gemm),
+            wrap_topi_schedule(topi.arm_cpu.schedule_dense_gemm),
+            name="dense_gemm.arm_cpu",
+            plevel=11,
+        )
     # Fallback to x86 schedules as there is currently no arm_cpu schedule for dense
     strategy.add_implementation(
         wrap_compute_dense(topi.x86.dense_nopack),
@@ -780,6 +792,19 @@ def matmul_strategy_arm_cpu(attrs, inputs, out_type, target):
             lambda: None,
             name="matmul.arm_cpu.sme",
         )
+    elif (
+        target.features.is_aarch64
+        and data.dtype in ["float16", "float32"]
+        and weight.dtype in ["float16", "float32"]
+        and out_type.dtype in ["float16", "float32"]
+        and not (attrs.transpose_a or attrs.transpose_b)
+        and len(data.shape) == 2
+    ):
+        strategy.add_implementation(
+            wrap_compute_matmul(topi.arm_cpu.dense_gemm),
+            wrap_topi_schedule(topi.arm_cpu.schedule_dense_gemm),
+            name="matmul.arm_cpu.neon",
+        )
         return strategy
 
     logger.warning("matmul is not optimized for arm cpu.")
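With this strategy in place, an fp32 nn.dense workload compiled for an AArch64 LLVM target can be served by the new dense_gemm.arm_cpu implementation (plevel 11) instead of the x86 fallback. The following is a minimal sketch, not part of this patch; the shapes and the target string are illustrative assumptions.

import tvm
from tvm import relay

# [M, K] activations and [N, K] weights, the layout relay.nn.dense expects.
x = relay.var("x", shape=(8, 512), dtype="float32")
w = relay.var("w", shape=(256, 512), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x, w], relay.nn.dense(x, w)))

# Illustrative AArch64 target; any aarch64 llvm triple should reach the new branch.
target = "llvm -mtriple=aarch64-linux-gnu -mattr=+neon"
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)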

python/tvm/testing/utils.py

Lines changed: 5 additions & 0 deletions
@@ -871,6 +871,11 @@ def _multi_gpu_exists():
     "x86", "x86 Architecture", run_time_check=lambda: platform.machine() == "x86_64"
 )
 
+# Mark a test as requiring the aarch64 Architecture to run.
+requires_aarch64 = Feature(
+    "AArch64", "AArch64 Architecture", run_time_check=lambda: platform.machine() == "aarch64"
+)
+
 # Mark a test as requiring the CUDA runtime.
 requires_cuda = Feature(
     "cuda",

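The new marker is applied like the existing requires_* features: as a decorator it skips the test unless platform.machine() reports "aarch64" at run time. A hedged usage sketch (the test name and body are hypothetical, not from the patch):

import tvm.testing


@tvm.testing.requires_aarch64
def test_runs_only_on_aarch64():
    # Skipped on any machine whose platform.machine() is not "aarch64".
    assert True
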
python/tvm/topi/arm_cpu/dense.py

Lines changed: 15 additions & 6 deletions
@@ -16,20 +16,29 @@
 # under the License.
 """Dense schedule for ARM CPU"""
 from tvm import autotvm
-
-from .mprofile.dsp.dense import (
-    dense_dsp_schedule,
-    dense_dsp_compute,
-)
+from .mprofile.dsp.dense import dense_dsp_schedule, dense_dsp_compute
+from .dense_gemm import dense_gemm_compute, dense_gemm_schedule
 
 
 @autotvm.register_topi_compute("dense_dsp.arm_cpu")
 def dense_dsp(cfg, data, weight, bias, out_dtype):
-    """Compute dense_dsp with v7e-m DSP instructions."""
+    """Compute dense with DSP instructions."""
     return dense_dsp_compute(cfg, data, weight, bias=bias, out_dtype=out_dtype)
 
 
 @autotvm.register_topi_schedule("dense_dsp.arm_cpu")
 def schedule_dense_dsp(cfg, outs):
     """Create schedule for dense_dsp"""
     return dense_dsp_schedule(cfg, outs)
+
+
+@autotvm.register_topi_compute("dense_gemm.arm_cpu")
+def dense_gemm(cfg, data, weight, bias, out_dtype, transpose_a=False, transpose_b=True):
+    """Compute dense using GeMM."""
+    return dense_gemm_compute(cfg, data, weight, bias, out_dtype, transpose_a, transpose_b)
+
+
+@autotvm.register_topi_schedule("dense_gemm.arm_cpu")
+def schedule_dense_gemm(cfg, outs):
+    """Create schedule for dense using GeMM."""
+    return dense_gemm_schedule(cfg, outs)

python/tvm/topi/arm_cpu/dense_alter_op.py

Lines changed: 27 additions & 7 deletions
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
 """Dense alter op definitions for the `arm_cpu` device key."""
 
 import tvm
@@ -47,13 +48,11 @@ def _alter_dense(attrs, inputs, tinfos, out_type):
 
     cfg = dispatch_ctx.query(target, workload)
     topi_impl = workload[0]
+
     if topi_impl == "matmul.arm_cpu.sme":
-        # Pre-compute transposed weights and convert to a matmul
-        assert isinstance(
-            inputs[1], relay.Constant
-        ), "matmul_sme.arm_cpu requires weights be a Relay Constant"
 
         weight_dtype = tinfos[1].dtype
+        N, K = tinfos[1].shape
         encoded_weight = inputs[1]
 
         # For dense the weights (rhs) are provided in transposed format,
@@ -65,15 +64,15 @@ def _alter_dense(attrs, inputs, tinfos, out_type):
         # float16->float32 schedule the transformation currently happens at runtime
         # with the ARM_SME_BLOCK2_2SVLx1SVL_FP16_TRANSPOSE_INTERLEAVE intrinsic.
         if weight_dtype == "float32":
-            encoded_weight = relay.const(encoded_weight.data.numpy().transpose(), weight_dtype)
+            encoded_weight = relay.transpose(encoded_weight)
             transpose_b = False
 
-        new_weight = te.placeholder((encoded_weight.data.shape), dtype=weight_dtype)
+        new_weight = te.placeholder(([K, N]), dtype=weight_dtype)
+
         new_workload = autotvm.task.args_to_workload(
             [tinfos[0], new_weight, None, out_type.dtype, False, transpose_b], topi_impl
         )
         dispatch_ctx.update(target, new_workload, cfg)
-
         return _make.matmul(
             inputs[0],
             encoded_weight,
@@ -82,6 +81,27 @@ def _alter_dense(attrs, inputs, tinfos, out_type):
             False,
             transpose_b,
         )
+    elif topi_impl == "dense_gemm.arm_cpu":
+
+        weight_dtype = tinfos[1].dtype
+        N, K = tinfos[1].shape
+
+        encoded_weight = relay.transpose(inputs[1])
+        new_weight = te.placeholder(([K, N]), dtype=weight_dtype)
+
+        new_workload = autotvm.task.args_to_workload(
+            [tinfos[0], new_weight, None, out_type.dtype, False, False], topi_impl
+        )
+        dispatch_ctx.update(target, new_workload, cfg)
+
+        return _make.matmul(
+            inputs[0],
+            encoded_weight,
+            attrs.units,
+            attrs.out_dtype,
+            False,
+            False,
+        )
 
     # x86 schedules are used as a fallback
     return tvm.topi.x86.dense_alter_op._alter_dense_layout(attrs, inputs, tinfos, out_type)
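At the Relay level, the new dense_gemm.arm_cpu branch of _alter_dense amounts to rewriting nn.dense against [N, K] weights into a non-transposed nn.matmul against transpose(weight), so the GeMM compute receives a [K, N] operand. A hedged illustration, not taken from the patch (the shapes are arbitrary):

from tvm import relay

x = relay.var("x", shape=(32, 64), dtype="float32")   # data   [M, K]
w = relay.var("w", shape=(128, 64), dtype="float32")  # weight [N, K], dense layout

# Before the alter-op pass:
before = relay.nn.dense(x, w)
# Roughly what the pass produces for the dense_gemm.arm_cpu implementation:
after = relay.nn.matmul(x, relay.transpose(w), transpose_a=False, transpose_b=False)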

python/tvm/topi/arm_cpu/dense_gemm.py

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, too-many-locals
+# pylint: disable=unused-argument, redefined-builtin
+"""GeMM dense schedule on AArch64"""
+import tvm
+from tvm import te
+from tvm.topi import nn
+from tvm.topi.arm_cpu.arm_utils import get_tiling_A, get_tiling_B_transformed, pad_dim_to_multiple
+from ..utils import get_const_tuple, traverse_inline
+from .. import tag
+
+# Compute function
+def dense_gemm_compute(
+    cfg, data, weight, bias=None, out_dtype=None, transpose_a=False, transpose_b=True
+):
+    """
+    Compute dense using GeMM.
+
+    Parameters
+    ----------
+    cfg : Autotvm tuning space config file,
+        empty in this case, but it's needed as an arg.
+
+    data : tvm.te.Tensor
+        2-D with shape [M, K] or [K, M].
+
+    weight : tvm.te.Tensor
+        2-D with shape [K, N] or [N, K].
+
+    bias : Optional[tvm.te.Tensor]
+        1-D with shape [N]
+
+
+    out_dtype : Optional[str]
+        Specifies the output data type.
+
+    transpose_a : Optional[bool] = False
+        Whether the data tensor is in transposed format.
+
+    transpose_b : Optional[bool] = True
+        Whether the weight tensor is in transposed format.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        1-D with shape [out_dim]
+    """
+
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape)  # batch, in_dim
+    if bool(transpose_b):  # out_dim
+        (N, _) = get_const_tuple(weight.shape)
+    else:
+        (_, N) = get_const_tuple(weight.shape)
+
+    tile_M, tile_K = get_tiling_A(False, out_dtype)
+    tile_N, _ = get_tiling_B_transformed(False, out_dtype, False)
+
+    M_padded, pad_M = pad_dim_to_multiple(M, tile_M)
+    K_padded, pad_K = pad_dim_to_multiple(K, tile_K)
+    N_padded, pad_N = pad_dim_to_multiple(N, tile_N)
+    m_pad_after = (pad_M, pad_K)
+    n_pad_after = (pad_N, pad_K) if transpose_b else (pad_K, pad_N)
+
+    if pad_M != 0 or pad_K != 0:
+        data = nn.pad(data, pad_before=(0, 0), pad_after=m_pad_after, name="data_padded")
+
+    k = te.reduce_axis((0, K_padded), name="k")
+
+    if bool(transpose_b):
+        weight = te.compute(
+            (K_padded, N_padded), lambda x, y: weight[y, x], name="weight_transposed"
+        )
+
+    if pad_N != 0 or pad_K != 0:
+        weight = nn.pad(weight, pad_before=(0, 0), pad_after=n_pad_after, name="weight_padded")
+
+    C = te.compute(
+        (M_padded, N_padded),
+        lambda x, y: te.sum(
+            data[x, k].astype(out_dtype) * weight[k, y].astype(out_dtype),
+            axis=k,
+        ).astype(out_dtype),
+        name="C",
+    )
+
+    if bias is not None:
+        C = te.compute(
+            (M_padded, N_padded),
+            lambda i, j: C[i, j] + bias[j].astype(out_dtype),
+            tag=tag.BROADCAST,
+            name="dense_biased_output",
+        )
+
+    # We need to ensure that infer bound pass does not remove the padding
+    # which is necessary for the tensorizations to work. So we need to
+    # add a dummy reference to the padding area of the result
+    zero = (
+        tvm.tir.const(1, C.dtype) * C[0, N_padded - 1]
+        - tvm.tir.const(1, C.dtype) * C[0, N_padded - 1]
+    )
+
+    out = te.compute(
+        (M, N), lambda x, y: (C[x, y] + zero).astype(out_dtype), name="dense_gemm_output"
+    )
+
+    return out
+
+
+def _dense_gemm_schedule(s, out):
+    C = out.op.input_tensors[0]
+    A = C.op.input_tensors[0]
+    out_type = A.dtype
+    tile_M, tile_K = get_tiling_A(False, out_type)
+    tile_N, _ = get_tiling_B_transformed(False, out_type, False)
+
+    if C.op.name == "dense_biased_output":
+        s[C].compute_inline()
+        C = C.op.input_tensors[0]
+    x, y = s[C].op.axis
+    (k,) = s[C].op.reduce_axis
+
+    k_outer, k_inner = s[C].split(k, factor=tile_K)
+    x_outer, x_inner = s[C].split(x, factor=tile_M)
+    y_outer, y_inner = s[C].split(y, factor=tile_N)
+    y_inner_outer, y_inner_inner = s[C].split(y_inner, nparts=4)
+    s[C].parallel(x_outer)
+    s[C].reorder(
+        x_outer,
+        y_outer,
+        k_outer,
+        k_inner,
+        y_inner_outer,
+        x_inner,
+        y_inner_inner,
+    )
+    s[C].unroll(y_inner_outer)
+    s[C].unroll(x_inner)
+    s[C].vectorize(y_inner_inner)
+
+    return s
+
+
+def dense_gemm_schedule(cfg, outs):
+    """Schedule the dense_gemm strategy"""
+    s = te.create_schedule([x.op for x in outs])
+    out = outs[0]
+    x, y = out.op.axis
+    _, inner = s[out].split(y, 4)
+    s[out].parallel(x)
+    s[out].vectorize(inner)
+
+    def _callback(op):
+        if "dense_gemm_output" in op.name:
+            _dense_gemm_schedule(s, op.output(0))
+
+    traverse_inline(s, out.op, _callback)
+    return s
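The compute and schedule pair can also be driven directly with TE, which is convenient for experimenting outside the Relay strategy. A self-contained sketch under stated assumptions, not part of the patch: cfg is unused by this schedule so None is passed, the Target scope is there because the tiling helpers query the current target, and the shapes and target string are illustrative (chosen as multiples of the tile sizes so no padding is introduced). The module cross-compiles; running it would require an AArch64 device.

import tvm
from tvm import te
from tvm.topi.arm_cpu.dense_gemm import dense_gemm_compute, dense_gemm_schedule

M, K, N = 64, 128, 128
tgt = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+neon")

with tgt:
    data = te.placeholder((M, K), name="data", dtype="float32")
    # transpose_b=True (the default), so the weight is laid out as [N, K].
    weight = te.placeholder((N, K), name="weight", dtype="float32")
    out = dense_gemm_compute(None, data, weight, bias=None, out_dtype="float32")
    s = dense_gemm_schedule(None, [out])
    lib = tvm.build(s, [data, weight, out], target=tgt, name="dense_gemm")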

python/tvm/topi/nn/dense.py

Lines changed: 2 additions & 0 deletions
@@ -70,6 +70,7 @@ def matmul(
     assert (
         len(tensor_a.shape) >= 2 and len(tensor_b.shape) >= 2
     ), "1-dim matmul is not supported yet."
+
     if bias is not None:
         assert len(bias.shape) == 1
     if out_dtype is None:
@@ -229,6 +230,7 @@ def dense(
     output : tvm.te.Tensor
         2-D with shape [batch, out_dim]
     """
+
     return matmul(
         data,
         weight,

tests/python/frontend/keras/test_forward.py

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ def get_keras_output(in_data):
     def get_tvm_output(in_data, target, dev, dtype="float32"):
         shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, in_data)}
         mod, params = relay.frontend.from_keras(keras_model, shape_dict, layout=layout)
-        with tvm.transform.PassContext(opt_level=2):
+        with tvm.transform.PassContext(opt_level=3):
             lib = relay.build(mod, target, params=params)
         m = graph_executor.GraphModule(lib["default"](dev))
         for name, x in zip(keras_model.input_names, in_data):
