Skip to content

Commit b7d35df

Browse files
author
Eirene Pandi
committed
[TOPI] Add dense schedule for fp16 and fp32 using gemm
Add a new schedule for the dense operator based on the gemm algorithm. Change-Id: Iaf4423d21d20b5813c77a0a27c4751f8cbd1d8b8
1 parent ab02979 commit b7d35df

File tree

10 files changed

+340
-6
lines changed

10 files changed

+340
-6
lines changed

cmake/config.cmake

100644100755
File mode changed.

python/tvm/relay/op/strategy/arm_cpu.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,17 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
729729
plevel=12,
730730
)
731731

732+
if (
733+
data.dtype in ["float16", "float32"]
734+
and weight.dtype in ["float16", "float32"]
735+
and out_type.dtype in ["float16", "float32"]
736+
):
737+
strategy.add_implementation(
738+
wrap_compute_dense(topi.arm_cpu.dense_gemm),
739+
wrap_topi_schedule(topi.arm_cpu.schedule_dense_gemm),
740+
name="dense_gemm.arm_cpu",
741+
plevel=11,
742+
)
732743
# Fallback to x86 schedules as there is currently no arm_cpu schedule for dense
733744
strategy.add_implementation(
734745
wrap_compute_dense(topi.x86.dense_nopack),
@@ -773,6 +784,18 @@ def matmul_strategy_arm_cpu(attrs, inputs, out_type, target):
773784
lambda: None,
774785
name="matmul.arm_cpu.sme",
775786
)
787+
elif (
788+
data.dtype in ["float16", "float32"]
789+
and weight.dtype in ["float16", "float32"]
790+
and out_type.dtype in ["float16", "float32"]
791+
and not (attrs.transpose_a or attrs.transpose_b)
792+
and len(data.shape) == 2
793+
):
794+
strategy.add_implementation(
795+
wrap_compute_matmul(topi.arm_cpu.dense_gemm),
796+
wrap_topi_schedule(topi.arm_cpu.schedule_dense_gemm),
797+
name="matmul.arm_cpu.neon",
798+
)
776799
return strategy
777800

778801
logger.warning("matmul is not optimized for arm cpu.")

python/tvm/topi/arm_cpu/dense.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,29 @@
1616
# under the License.
1717
"""Dense schedule for ARM CPU"""
1818
from tvm import autotvm
19-
20-
from .mprofile.dsp.dense import (
21-
dense_dsp_schedule,
22-
dense_dsp_compute,
23-
)
19+
from .mprofile.dsp.dense import dense_dsp_schedule, dense_dsp_compute
20+
from .dense_gemm import dense_gemm_compute, dense_gemm_schedule
2421

2522

2623
@autotvm.register_topi_compute("dense_dsp.arm_cpu")
def dense_dsp(cfg, data, weight, bias, out_dtype):
    """Compute dense with DSP instructions.

    AutoTVM compute entry point registered as ``dense_dsp.arm_cpu``. It does no
    work of its own: every argument is handed to ``dense_dsp_compute`` and the
    result is returned as-is.
    """
    dense_out = dense_dsp_compute(cfg, data, weight, bias=bias, out_dtype=out_dtype)
    return dense_out
3027

3128

3229
@autotvm.register_topi_schedule("dense_dsp.arm_cpu")
def schedule_dense_dsp(cfg, outs):
    """Create schedule for dense_dsp.

    AutoTVM schedule entry point registered as ``dense_dsp.arm_cpu``; delegates
    to ``dense_dsp_schedule`` and returns whatever it produces.
    """
    sched = dense_dsp_schedule(cfg, outs)
    return sched
33+
34+
35+
@autotvm.register_topi_compute("dense_gemm.arm_cpu")
def dense_gemm(cfg, data, weight, bias, out_dtype, transpose_a=False, transpose_b=True):
    """Compute dense using GeMM.

    AutoTVM compute entry point registered as ``dense_gemm.arm_cpu``. All
    arguments, including the two transpose flags, are forwarded unchanged to
    ``dense_gemm_compute``.
    """
    return dense_gemm_compute(
        cfg,
        data,
        weight,
        bias=bias,
        out_dtype=out_dtype,
        transpose_a=transpose_a,
        transpose_b=transpose_b,
    )
39+
40+
41+
@autotvm.register_topi_schedule("dense_gemm.arm_cpu")
def schedule_dense_gemm(cfg, outs):
    """Create schedule for dense using GeMM.

    AutoTVM schedule entry point registered as ``dense_gemm.arm_cpu``; delegates
    to ``dense_gemm_schedule`` and returns the schedule it builds.
    """
    sched = dense_gemm_schedule(cfg, outs)
    return sched

python/tvm/topi/arm_cpu/dense_alter_op.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def _alter_dense(attrs, inputs, tinfos, out_type):
4747

4848
cfg = dispatch_ctx.query(target, workload)
4949
topi_impl = workload[0]
50+
5051
if topi_impl == "matmul.arm_cpu.sme":
5152
# Pre-compute transposed weights and convert to a matmul
5253
assert isinstance(
@@ -82,6 +83,31 @@ def _alter_dense(attrs, inputs, tinfos, out_type):
8283
False,
8384
transpose_b,
8485
)
86+
elif topi_impl == "dense_gemm.arm_cpu":
87+
# Pre-compute transposed weights and convert to a matmul
88+
assert isinstance(
89+
inputs[1], relay.Constant
90+
), "dense_gemm.arm_cpu requires weights be a Relay Constant"
91+
92+
weight_dtype = tinfos[1].dtype
93+
weight_data = inputs[1].data.numpy()
94+
interleaved = weight_data.transpose()
95+
encoded_weight = relay.const(interleaved, weight_dtype)
96+
97+
new_weight = te.placeholder((weight_data.shape), dtype=weight_dtype)
98+
new_workload = autotvm.task.args_to_workload(
99+
[tinfos[0], new_weight, None, out_type.dtype], topi_impl
100+
)
101+
dispatch_ctx.update(target, new_workload, cfg)
102+
103+
return relay.nn.matmul(
104+
inputs[0],
105+
encoded_weight,
106+
units=attrs.units,
107+
out_dtype=attrs.out_dtype,
108+
transpose_a=False,
109+
transpose_b=False,
110+
)
85111

86112
# x86 schedules are used as a fallback
87113
return tvm.topi.x86.dense_alter_op._alter_dense_layout(attrs, inputs, tinfos, out_type)
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name, unused-variable, too-many-locals
18+
"""GEMM dense schedule on AArch64"""
19+
import tvm
20+
from tvm.target import Target
21+
from tvm import te
22+
from tvm.topi import nn
23+
from tvm.topi.arm_cpu.arm_utils import get_tiling_A, get_tiling_B_transformed
24+
from ..utils import get_const_tuple, traverse_inline
25+
from ..nn.utils import get_pad_tuple
26+
from .. import tag
27+
28+
# Compute function
29+
def dense_gemm_compute(
    cfg, data, weight, bias=None, out_dtype=None, transpose_a=False, transpose_b=True
):
    """
    Compute dense using GeMM.

    The operands are zero-padded up to multiples of the tile sizes returned by
    ``get_tiling_A`` / ``get_tiling_B_transformed``, multiplied as a padded
    [M_padded, K_padded] x [K_padded, N_padded] product, and the result is
    cropped back to [M, N].

    Parameters
    ----------
    cfg : object
        AutoTVM config. Accepted to match the registered compute signature;
        it is not read by this function.

    data : tvm.te.Tensor
        2-D with shape [M, K] (batch, in_dim).

    weight : tvm.te.Tensor
        2-D with shape [N, K] when ``transpose_b`` is true, otherwise [K, N].

    bias : Optional[tvm.te.Tensor]
        1-D tensor indexed by output column, added to every row when given.

    out_dtype : Optional[str]
        Output data type. Defaults to ``data.dtype``.

    transpose_a : Optional[bool] = False
        Accepted for signature compatibility only; it is not read and ``data``
        is always interpreted as [M, K].

    transpose_b : Optional[bool] = True
        Whether the weight tensor is in transposed format.

    Returns
    -------
    out : tvm.te.Tensor
        2-D with shape [M, N], named ``dense_gemm_output``.
    """

    if out_dtype is None:
        out_dtype = data.dtype
    M, K = get_const_tuple(data.shape)  # batch, in_dim
    if bool(transpose_b):  # out_dim
        (N, _) = get_const_tuple(weight.shape)
    else:
        (_, N) = get_const_tuple(weight.shape)

    in_dtype = data.dtype

    tile_M, tile_K_A = get_tiling_A(False, in_dtype)
    tile_N, _ = get_tiling_B_transformed(False, out_dtype, False)

    # Amount each dimension must grow to reach the next multiple of its tile size.
    pad_M = 0
    pad_K = 0
    pad_N = 0

    if M % tile_M != 0:
        pad_M = tile_M - (M % tile_M)

    if K % tile_K_A != 0:
        pad_K = tile_K_A - (K % tile_K_A)

    if N % tile_N != 0:
        pad_N = tile_N - (N % tile_N)

    M_padded = M + pad_M
    K_padded = K + pad_K
    N_padded = N + pad_N
    k = te.reduce_axis((0, K_padded), name="k")

    pad_before = (0, 0)
    pad_after = (pad_M, pad_K)

    if pad_K != 0:
        data = nn.pad(data, pad_before=pad_before, pad_after=pad_after, name="A_padded_K")
    elif pad_M != 0:
        data = nn.pad(data, pad_before=pad_before, pad_after=pad_after, name="A_padded_M")

    if bool(transpose_b):
        # Bring the [N, K] weight into [K, N]. The transposed tensor keeps the
        # unpadded extents so it never indexes past the end of the original weight;
        # padding is applied separately below.
        weight_nk = weight
        weight = te.compute((K, N), lambda x, y: weight_nk[y, x], name="weight_transposed")

    if pad_K != 0 or pad_N != 0:
        # weight is [K, N] at this point: K padding goes on axis 0, N padding on axis 1.
        weight = nn.pad(weight, pad_before=(0, 0), pad_after=(pad_K, pad_N), name="weight_padded")

    C = te.compute(
        (M_padded, N_padded),
        lambda x, y: te.sum(
            data[x, k].astype(out_dtype) * weight[k, y].astype(out_dtype),
            axis=k,
        ).astype(out_dtype),
        name="C",
    )

    if bias is not None:
        gemm_result = C
        C = te.compute(
            (M_padded, N_padded),
            lambda i, j: gemm_result[i, j] + bias[j].astype(out_dtype),
            tag=tag.BROADCAST,
            name="dense_biased_output",
        )

    # `zero` is algebraically 0 but reads an element from the last padded column of C.
    # NOTE(review): presumably this keeps the padded region referenced by the output so
    # that bound inference does not shrink C back to the unpadded extents - confirm.
    zero = (
        tvm.tir.const(1, C.dtype) * C[0, N_padded - 1]
        - tvm.tir.const(1, C.dtype) * C[0, N_padded - 1]
    )

    # Crop the padded result back to the requested [M, N] shape.
    out = te.compute(
        (M, N), lambda x, y: (C[x, y] + zero).astype(out_dtype), name="dense_gemm_output"
    )

    return out
113+
114+
115+
def _dense_gemm_schedule_template(s, out):
    """Tile, parallelize and vectorize the reduction stage that feeds ``out``.

    Parameters
    ----------
    s : tvm.te.Schedule
        Schedule to modify in place.

    out : tvm.te.Tensor
        Tensor whose first input is the GeMM result, or the
        ``dense_biased_output`` stage wrapping it.

    Returns
    -------
    s : tvm.te.Schedule
        The same schedule object that was passed in.
    """
    gemm = out.op.input_tensors[0]
    # The dtype that selects the column tile is read from the first input of the
    # tensor directly feeding ``out``, i.e. before any bias stage is skipped.
    first_input = gemm.op.input_tensors[0]
    n_tile, _ = get_tiling_B_transformed(False, first_input.dtype)

    if gemm.op.name == "dense_biased_output":
        # Fold the bias add into its consumer and schedule the reduction below it.
        s[gemm].compute_inline()
        gemm = gemm.op.input_tensors[0]

    stage = s[gemm]
    row, col = stage.op.axis
    (red,) = stage.op.reduce_axis
    red_outer, red_inner = stage.split(red, factor=4)
    row_outer, row_inner = stage.split(row, factor=4)
    col_outer, col_inner = stage.split(col, factor=n_tile)

    stage.parallel(row_outer)
    stage.reorder(row_outer, col_outer, red_outer, red_inner, row_inner, col_inner)
    stage.unroll(row_inner)
    stage.vectorize(col_inner)

    return s
141+
142+
143+
def dense_gemm_schedule(cfg, outs):
    """Schedule the dense_gemm strategy.

    Parameters
    ----------
    cfg : object
        AutoTVM config; accepted for the registered-schedule signature, not read here.

    outs : Sequence[tvm.te.Tensor]
        Output tensors of the compute; ``outs[0]`` is the one that gets scheduled.

    Returns
    -------
    s : tvm.te.Schedule
        The schedule for the given outputs.
    """
    sched = te.create_schedule([tensor.op for tensor in outs])
    final = outs[0]
    final_stage = sched[final]

    # Final output stage: rows run in parallel, columns are vectorized in chunks of 4.
    row, col = final.op.axis
    _, col_inner = final_stage.split(col, 4)
    final_stage.parallel(row)
    final_stage.vectorize(col_inner)

    def _visit(op):
        # Only the op produced by dense_gemm_compute gets the GeMM template.
        if "dense_gemm_output" in op.name:
            _dense_gemm_schedule_template(sched, op.output(0))

    traverse_inline(sched, final.op, _visit)
    return sched

python/tvm/topi/nn/dense.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def matmul(
7070
assert (
7171
len(tensor_a.shape) >= 2 and len(tensor_b.shape) >= 2
7272
), "1-dim matmul is not supported yet."
73+
7374
if bias is not None:
7475
assert len(bias.shape) == 1
7576
if out_dtype is None:
@@ -229,6 +230,7 @@ def dense(
229230
output : tvm.te.Tensor
230231
2-D with shape [batch, out_dim]
231232
"""
233+
232234
return matmul(
233235
data,
234236
weight,

python/tvm/topi/x86/dense.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,46 @@ def _callback(op):
283283
return s
284284

285285

286+
@autotvm.register_topi_compute("dense_simple.x86")
def dense_simple(cfg, data, weight, bias=None, out_dtype=None):
    """Compute dense as a plain reduction over the input dimension.

    Equivalent NumPy reference: ``np.dot(data, weight.T) + bias``.

    NOTE(review): the weight is read as [out_dim, in_dim], which is what the shape
    unpacking below and the NumPy reference imply. Confirm this is the intended
    layout for this compute.

    Parameters
    ----------
    cfg : object
        AutoTVM config; accepted for the registered-compute signature, not read here.

    data : tvm.te.Tensor
        2-D with shape [batch, in_dim].

    weight : tvm.te.Tensor
        2-D with shape [out_dim, in_dim].

    bias : Optional[tvm.te.Tensor]
        1-D tensor indexed by output column, added to every row when given.

    out_dtype : Optional[str]
        Output data type. Defaults to ``data.dtype``.

    Returns
    -------
    output : tvm.te.Tensor
        2-D with shape [batch, out_dim].
    """
    if out_dtype is None:
        out_dtype = data.dtype
    M, K = get_const_tuple(data.shape)  # batch, in_dim
    N, _ = get_const_tuple(weight.shape)  # out_dim
    k = te.reduce_axis((0, K), name="k")
    matmul = te.compute(
        (M, N),
        lambda i, j: te.sum(
            data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype),
            axis=k,  # te.sum needs the reduction axis spelled out
        ),
        tag="dense_simple",
    )
    if bias is None:
        return matmul
    return te.compute(
        (M, N), lambda i, j: matmul[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST
    )
311+
312+
313+
@autotvm.register_topi_schedule("dense_simple.x86")
def schedule_dense_pack(cfg, outs):
    """Create the schedule for dense_simple.

    NOTE(review): this is registered under "dense_simple.x86" but named
    ``schedule_dense_pack``. If the module already defines a function with that
    name, this definition rebinds it - confirm the name is intended.
    """
    sched = te.create_schedule([tensor.op for tensor in outs])
    final = outs[0]

    def _visit(op):
        # Ops tagged by the dense_simple compute get the dedicated template.
        if "dense_simple" in op.tag:
            _schedule_dense_simple_template(cfg, sched, op.output(0), final)

    traverse_inline(sched, final.op, _visit)
    return sched
324+
325+
286326
@autotvm.register_topi_compute("dense_int8.x86")
287327
def dense_int8(cfg, data, weight, bias=None, out_dtype=None):
288328
"""Compute for uint8 x int8 -> int32 dense"""

tests/python/relay/test_dense.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import tvm
2+
from tvm import relay
3+
from tvm.testing import assert_allclose
4+
import numpy as np
5+
from tvm.ir.instrument import pass_instrument
6+
7+
8+
def _test_accuracy(input_values, output_values, build_mod, rtol=1e-5, atol=1e-5):
    """Run ``build_mod`` on the local CPU and compare its output with a reference.

    Parameters
    ----------
    input_values : numpy.ndarray
        Value bound to the graph input named "data".

    output_values : numpy.ndarray
        Expected value of output 0.

    build_mod : object
        Built module; ``build_mod["default"](dev)`` is wrapped in a GraphModule.

    rtol : float
        Relative tolerance for the comparison.

    atol : float
        Absolute tolerance for the comparison.

    Raises
    ------
    AssertionError
        If output 0 differs from ``output_values`` beyond the tolerances.
    """
    # Import the submodule explicitly rather than relying on some other import
    # having made ``tvm.contrib.graph_executor`` available as an attribute.
    from tvm.contrib import graph_executor

    dev = tvm.cpu(0)

    input_buf = tvm.nd.array(input_values, device=dev)
    rt = graph_executor.GraphModule(build_mod["default"](dev))
    rt.set_input("data", input_buf)
    rt.run()
    out = rt.get_output(0)

    # Explicit tolerances: the reduction order of the compiled kernel need not match
    # NumPy's, so float results are only expected to agree approximately.
    tvm.testing.assert_allclose(out.numpy(), output_values, rtol=rtol, atol=atol)
19+
20+
21+
# Define input shape and data type
data_size = (64, 64)
data_shape = data_size  # Input shape
data_type = "float32"  # Data type
weight_shape = data_size

# Create Relay input variable and an all-ones constant weight
d = relay.var("data", shape=data_shape, dtype=data_type)
w1 = np.ones(weight_shape, dtype=data_type)
w = relay.const(w1, data_type)

# Create Relay dense layer
y = relay.nn.dense(d, w)

# Create Relay module with a single function wrapping the dense layer
mod = tvm.IRModule()
mod["main"] = relay.Function([d], y)

# Compile the Relay module for an AArch64 Arm CPU with NEON.
# The compiled module is executed on tvm.cpu(0) by _test_accuracy, so this
# script has to run on a host that can execute code built for this target.
target = "llvm -mtriple=aarch64-linux-gnu -device=arm_cpu -mattr=+v8.2a,+neon"
lib = relay.build(mod, target=target, params=None)

# Random input and the NumPy reference: dense(data, weight) == data @ weight.T
in_np = np.random.uniform(size=data_size).astype(data_type)
out_np = np.matmul(in_np, w1.T)

_test_accuracy(in_np, out_np, lib)

0 commit comments

Comments
 (0)