Skip to content

Commit 5b0933d

Browse files
committed
[Hexagon] Softmax slice op initial version
Resolve merge conflict in utils.py
1 parent 1e0e954 commit 5b0933d

File tree

4 files changed

+248
-0
lines changed

4 files changed

+248
-0
lines changed

python/tvm/topi/hexagon/slice_ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@
1919

2020
from .avg_pool2d import avg_pool2d_compute, avg_pool2d_STIR_schedule
2121
from .add_subtract_multiply import *
22+
from .softmax_slice import *
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Hexagon slice softmax compute and schedule"""
18+
19+
import typing
20+
21+
from tvm import te, tir, topi
22+
from ..utils import get_layout_transform_fn
23+
24+
25+
def softmax_compute(in_tensor):
    """Build the softmax compute for the Hexagon slice op.

    The work is delegated to ``topi.nn.softmax`` with the reduction taken
    over axis 1.

    Assumptions made by this op:

    1. It is written for a sliced softmax operation.
    2. The input is assumed to be in NC layout.
    """
    softmax_out = topi.nn.softmax(in_tensor, axis=1)
    return softmax_out
33+
34+
35+
def _vectorize_reduction(sch, block, factor):
    """Split the second loop of ``block`` by ``factor``, rfactor the inner part and vectorize it."""
    _, reduce_loop = sch.get_loops(block)
    _, reduce_inner = sch.split(reduce_loop, [None, factor])
    # rfactor creates a new block; its innermost (third) loop is the one vectorized.
    rf_block = sch.rfactor(reduce_inner, 0)
    _, _, rf_inner = sch.get_loops(rf_block)
    sch.vectorize(rf_inner)


def _vectorize_elementwise(sch, block, factor):
    """Split the second loop of ``block`` by ``factor`` and vectorize the inner part."""
    _, loop = sch.get_loops(block)
    _, inner = sch.split(loop, [None, factor])
    sch.vectorize(inner)


def softmax_stir_schedule(
    out: te.Tensor, inp: te.Tensor, out_layout: str, in_layout: str
) -> tir.Schedule:
    """
    STIR schedule definition for the compute of softmax

    Parameters
    ----------
    out : te.Tensor
        Output tensor of the softmax compute.
    inp : te.Tensor
        Input tensor of the softmax compute.
    out_layout : str
        Layout string for the output buffer; resolved to an index-map
        function through ``get_layout_transform_fn``.
    in_layout : str
        Layout string for the input buffer; resolved the same way.

    Returns
    -------
    sch : tir.Schedule
        Schedule with both buffer layouts transformed and every softmax
        stage split and vectorized.

    Notes
    -----
    The schedule looks up the blocks named ``T_softmax_maxelem``,
    ``T_softmax_exp``, ``T_softmax_expsum`` and ``T_softmax_norm`` and
    expects each of them to have exactly two loops.
    """
    # Keep the layout strings and the resolved index-map functions in
    # separate names instead of rebinding the parameters.
    in_transform = get_layout_transform_fn(in_layout)
    out_transform = get_layout_transform_fn(out_layout)

    func = te.create_prim_func([inp, out])
    sch = tir.Schedule(func, debug_mask="all")

    max_block = sch.get_block("T_softmax_maxelem")
    exp_block = sch.get_block("T_softmax_exp")
    sum_block = sch.get_block("T_softmax_expsum")
    out_block = sch.get_block("T_softmax_norm")

    sch.transform_layout(max_block, inp.name, in_transform)
    sch.transform_layout(out_block, out.name, out_transform)

    # The stages are scheduled in the same order as before: the two
    # reductions use a 64-wide rfactor, the element-wise stages a 512-wide split.
    _vectorize_reduction(sch, max_block, 64)
    _vectorize_elementwise(sch, exp_block, 512)
    _vectorize_reduction(sch, sum_block, 64)
    _vectorize_elementwise(sch, out_block, 512)

    return sch

python/tvm/topi/hexagon/utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
1718
# pylint: disable=invalid-name
19+
20+
1821
"""Common hexagon specific utilities"""
1922
from tvm import te
2023

@@ -39,6 +42,26 @@ def nhwc_8h2w32c2w_1d(n, h, w, c):
3942
return [n, h // 8, w // 4, c // 32, h % 8, (w % 4) // 2, c % 32, w % 2]
4043

4144

45+
def nhwc_4h4w32c_1d(n, h, w, c):
    """Return index map for nhwc_4h4w32c 1d layout

    A logical ``(n, h, w, c)`` index is mapped to
    ``[n, h // 4, w // 4, c // 32, h % 4, w % 4, c % 32]``: the outer
    indices select a 4 x 4 x 32 (h, w, c) tile and the remainders give the
    position inside that tile.
    """
    return [n, h // 4, w // 4, c // 32, h % 4, w % 4, c % 32]
48+
49+
50+
def nhwc_4h4w32c_2d(n, h, w, c):
    """Return index map for nhwc_4h4w32c 2d layout

    The tile-selecting indices come first, then ``te.AXIS_SEPARATOR``,
    then the position inside the 4 x 4 x 32 (h, w, c) tile.
    """
    tile_index = [n, h // 4, w // 4, c // 32]
    within_tile = [h % 4, w % 4, c % 32]
    return [*tile_index, te.AXIS_SEPARATOR, *within_tile]
53+
54+
55+
def nc_512c_1d(n, c):
    """Return index map for nc_512c 1d layout

    The channel index is split into a 512-wide chunk number and the
    offset inside that chunk.
    """
    chunk = c // 512
    offset = c % 512
    return [n, chunk, offset]
58+
59+
60+
def nc_512c_2d(n, c):
    """Return index map for nc_512c 2d layout

    Same split as the 1d variant, with ``te.AXIS_SEPARATOR`` placed
    between the chunk number and the offset inside the chunk.
    """
    chunk = c // 512
    offset = c % 512
    return [n, chunk, te.AXIS_SEPARATOR, offset]
63+
64+
4265
def get_layout_transform_fn(layout):
4366
"""Return index map function as per the layout string"""
4467
if layout == "nhwc-8h2w32c2w-2d":
@@ -49,4 +72,12 @@ def get_layout_transform_fn(layout):
4972
return n11c_1024c_2d
5073
if layout == "n11c-1024c-1d":
5174
return n11c_1024c_1d
75+
if layout == "nhwc-4h4w32c-2d":
76+
return nhwc_4h4w32c_2d
77+
if layout == "nhwc-4h4w32c-1d":
78+
return nhwc_4h4w32c_1d
79+
if layout == "nc-512c-2d":
80+
return nc_512c_2d
81+
if layout == "nc-512c-1d":
82+
return nc_512c_1d
5283
raise RuntimeError(f"Unexpected layout '{layout}'")
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import pytest
19+
import numpy as np
20+
from tvm import te, topi
21+
22+
import tvm.testing
23+
from tvm.topi import testing
24+
from tvm.contrib.hexagon.build import HexagonLauncher
25+
26+
import tvm.topi.hexagon.slice_ops as sl
27+
from .infrastructure import allocate_hexagon_array
28+
29+
30+
def transform_numpy(arr_np, layout):
    """Reshape a 2-d ``(N, C)`` numpy array into the given layout.

    Only ``"nc-512c-2d"`` is supported, which yields shape
    ``(N, C // 512, 512)``; any other layout raises ``RuntimeError``.
    """
    if layout not in ("nc-512c-2d",):
        raise RuntimeError(f"Unexpected layout '{layout}'")

    batch, channels = arr_np.shape
    return arr_np.reshape([batch, channels // 512, 512])
36+
37+
38+
@tvm.testing.fixture
def input_np(input_shape, dtype):
    """Random input of ``input_shape`` drawn with ``np.random.uniform`` and cast to ``dtype``."""
    samples = np.random.uniform(size=input_shape)
    return samples.astype(dtype)
41+
42+
43+
@tvm.testing.fixture
def transformed_expected_output_np(expected_output_np, output_layout):
    """Reference output rearranged into ``output_layout`` with ``transform_numpy``."""
    laid_out = transform_numpy(expected_output_np, output_layout)
    return laid_out
46+
47+
48+
@tvm.testing.fixture
def transformed_input_np(input_np, input_layout):
    """Input data rearranged into ``input_layout`` with ``transform_numpy``."""
    laid_out = transform_numpy(input_np, input_layout)
    return laid_out
51+
52+
53+
class Basesoftmax2d:
    """Parameter declarations for the 2-d softmax slice-op tests."""

    # Shape, input/output layout and axis separators are declared together
    # as one parameter set.
    (
        input_shape,
        input_layout,
        output_layout,
        axis_sep,
    ) = tvm.testing.parameters(
        ((1, 1024), "nc-512c-2d", "nc-512c-2d", [2]),
    )
    # Element type of the input data.
    dtype = tvm.testing.parameter("float32")
    # Memory scope used when allocating the device arrays.
    working_scope = tvm.testing.parameter("global.vtcm")
60+
61+
62+
class TestSoftmax2d(Basesoftmax2d):
    """Runs the Hexagon softmax slice op and checks it against a numpy reference."""

    @tvm.testing.fixture
    def expected_output_np(self, input_np):
        """Reference softmax of ``input_np``; only 2-d inputs are supported."""
        if len(input_np.shape) == 2:
            ref_np_2d = tvm.topi.testing.softmax_python(input_np)
            return ref_np_2d
        raise RuntimeError(f"Unexpected input shape '{input_np.shape}'")

    @tvm.testing.requires_hexagon
    def test_softmax_f32(
        self,
        dtype,
        input_layout,
        output_layout,
        input_shape,
        input_np,
        transformed_input_np,
        transformed_expected_output_np,
        expected_output_np,
        working_scope,
        axis_sep,
        hexagon_session,
    ):
        """Build the scheduled softmax for Hexagon, run it and compare with the reference.

        The arguments are supplied by the parameters declared on the base
        class and by the fixtures defined in this module; ``hexagon_session``
        is provided elsewhere.
        """
        target_hexagon = tvm.target.hexagon(
            "v69",
            llvm_options="--disable-loop-unrolling-pass",
        )
        A = te.placeholder(input_shape, name="A", dtype=dtype)

        O = sl.softmax_compute(A)

        # Only the nc-512c-2d input layout has a schedule.
        if input_layout != "nc-512c-2d":
            raise RuntimeError(f"Unexpected input layout '{input_layout}'")
        tir_s = sl.softmax_stir_schedule(O, A, output_layout, input_layout)
        sch = tir_s.mod

        with tvm.transform.PassContext(
            opt_level=3,
            config={
                "tir.LoopPartition": {"partition_const_loop": True},
            },
        ):
            func = tvm.build(
                sch,
                [A, O],
                tvm.target.Target(target_hexagon, host=target_hexagon),
                name="softmax_slice",
            )

        input_arr = allocate_hexagon_array(
            hexagon_session.device,
            data=transformed_input_np,
            axis_separators=axis_sep,
            mem_scope=working_scope,
        )

        output_arr = allocate_hexagon_array(
            hexagon_session.device,
            tensor_shape=transformed_expected_output_np.shape,
            dtype=transformed_expected_output_np.dtype,
            axis_separators=axis_sep,
            mem_scope=working_scope,
        )

        mod = hexagon_session.load_module(func)
        mod(input_arr, output_arr)

        # Reshape with the real batch size; the batch dimension used to be
        # hard-coded to 1, which left ``n`` unused.
        n, c = input_np.shape
        output_np = output_arr.numpy().reshape([n, c // 512, 512])

        np.testing.assert_allclose(output_np, transformed_expected_output_np, rtol=1e-4, atol=1e-4)
136+
137+
138+
if __name__ == "__main__":
    # ``sys`` is not imported at the top of this file, so import it here;
    # without it this guard raises NameError when the file is run as a script.
    import sys

    sys.exit(pytest.main(sys.argv))

0 commit comments

Comments
 (0)