Skip to content

Commit 9d98da2

Browse files
authored
[Hexagon] Implement avg_pool2d slice op (#11417)
* Implement avg_pool2d slice op * Address review comments and fix the STIR schedule * Fix formatting issues * Address pylint errors * Additional formatting issues * more pylint fixes * Changed arch version to v68 for now * Changing arch version back to v69 * Move the test to tests/python/contrib/test_hexagon/topi
1 parent a5df283 commit 9d98da2

File tree

5 files changed

+604
-0
lines changed

5 files changed

+604
-0
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
""" Computes and Schedules for Hexagon slice ops. """
19+
20+
# pylint: disable=wildcard-import
21+
22+
from .avg_pool2d import avg_pool2d_compute, avg_pool2d_STIR_schedule
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name, unused-variable, unused-argument, too-many-locals
18+
19+
""" Compute and schedule for avg_pool2d slice op
20+
21+
Please note the following assumptions made by the implementation:
22+
23+
1) The input must be padded in advance to account for 'padding'. In addition,
24+
both input and output must be padded as per the physical buffer layout.
25+
2) The current implementation assumes 'count_include_pad' to be 'True'. It can be
26+
modified to support 'False' case but the element count for the pooling window
27+
must be pre-computed and provided as an input to reduce the run-time overhead.
28+
3) 'padding' is ignored. It must be handled outside of the sliced op.
29+
4) Please note that this implementation will not work if the output includes any
30+
physical layout related padding as it can result into out-of-bound access
31+
for the input.
32+
"""
33+
34+
from tvm import te
35+
from tvm import tir
36+
from ..utils import get_layout_transform_fn
37+
38+
39+
def validate_out_shape(out_shape, in_shape, kernel, stride, dilation):
40+
"""Validate output shape"""
41+
_, oh, ow, _ = out_shape
42+
_, ih, iw, _ = in_shape
43+
kh, kw = kernel
44+
sh, sw = stride
45+
dh, dw = dilation
46+
if ih < (oh - 1) * sh + dh * (kh - 1) + 1:
47+
raise RuntimeError("Output height is too large")
48+
if iw < (ow - 1) * sw + dw * (kw - 1) + 1:
49+
raise RuntimeError("Output width is too large")
50+
51+
52+
def avg_pool2d_compute(A, out_shape, kernel, stride, dilation):
53+
"""avg_pool2d compute"""
54+
kh, kw = kernel
55+
rh = te.reduce_axis((0, kh), name="rh")
56+
rw = te.reduce_axis((0, kw), name="rw")
57+
ob, oh, ow, oc = out_shape
58+
if isinstance(ob, int):
59+
validate_out_shape(out_shape, A.shape, kernel, stride, dilation)
60+
61+
sh, sw = stride
62+
dh, dw = dilation
63+
InvArea = float(1) / (kh * kw)
64+
65+
Sum = te.compute(
66+
out_shape,
67+
lambda b, h, w, c: te.sum(
68+
A[b, h * sh + dh * rh, w * sw + dw * rw, c].astype("float32"), axis=[rh, rw]
69+
),
70+
name="sum",
71+
)
72+
Avg = te.compute(
73+
out_shape, lambda b, h, w, c: (Sum[b, h, w, c] * InvArea).astype(A.dtype), name="avg"
74+
)
75+
return Avg
76+
77+
78+
def STIR_schedule_nhwc_8h2w32c2w(outs, ins, output_layout: str, input_layout: str):
79+
"""Schedule for input and output layout nhwc-8h2w32c2w"""
80+
func = te.create_prim_func([ins, outs])
81+
s = tir.Schedule(func)
82+
Sum = s.get_block("sum")
83+
Avg = s.get_block("avg")
84+
85+
input_transform_fn = get_layout_transform_fn(input_layout)
86+
output_transform_fn = get_layout_transform_fn(output_layout)
87+
s.transform_layout(Sum, ("read", 0), input_transform_fn)
88+
s.transform_layout(Avg, ("write", 0), output_transform_fn)
89+
90+
# Schedule 'Avg'
91+
n, h, w, c = s.get_loops(Avg)
92+
ho, hi = s.split(h, [None, 8])
93+
wo, wi = s.split(w, [None, 4])
94+
wio, wii = s.split(wi, [None, 2])
95+
co, ci = s.split(c, [None, 32])
96+
s.reorder(n, ho, wo, co, hi, wio, ci, wii)
97+
ci_wii = s.fuse(ci, wii)
98+
s.vectorize(ci_wii)
99+
100+
# Schedule 'Sum'
101+
s.compute_at(Sum, wio)
102+
Sum_axis = s.get_loops(Sum)
103+
s.reorder(Sum_axis[-2], Sum_axis[-1], Sum_axis[-4], Sum_axis[-3])
104+
ci_wii = s.fuse(Sum_axis[-4], Sum_axis[-3])
105+
# s.vectorize(ci_wii) # Doesn't work
106+
return s
107+
108+
109+
def STIR_schedule_n11c_1024c(outs, ins, output_layout: str, input_layout: str):
110+
"""Schedule for output layout: n11c-1024c, input layout: nhwc-8h2w32c2w"""
111+
func = te.create_prim_func([ins, outs])
112+
s = tir.Schedule(func)
113+
Sum = s.get_block("sum")
114+
Avg = s.get_block("avg")
115+
116+
input_transform_fn = get_layout_transform_fn(input_layout)
117+
output_transform_fn = get_layout_transform_fn(output_layout)
118+
s.transform_layout(Sum, ("read", 0), input_transform_fn)
119+
s.transform_layout(Avg, ("write", 0), output_transform_fn)
120+
121+
# Schedule 'Avg'
122+
n, h, w, c = s.get_loops(Avg)
123+
co, ci = s.split(c, [None, 1024])
124+
cio, cii = s.split(ci, [None, 64])
125+
s.vectorize(cii)
126+
127+
# Schedule 'Sum'
128+
s.compute_at(Sum, cio)
129+
Sum_axis = s.get_loops(Sum)
130+
s.reorder(Sum_axis[-2], Sum_axis[-1], Sum_axis[-3])
131+
# s.vectorize(Sum_axis[-3]) # Doesn't work
132+
return s
133+
134+
135+
def avg_pool2d_STIR_schedule(outs, ins, output_layout: str, input_layout: str):
136+
"""STIR based schedule"""
137+
if output_layout == "nhwc-8h2w32c2w-2d":
138+
return STIR_schedule_nhwc_8h2w32c2w(outs, ins, output_layout, input_layout)
139+
if output_layout == "n11c-1024c-2d":
140+
return STIR_schedule_n11c_1024c(outs, ins, output_layout, input_layout)
141+
raise RuntimeError(f"Unexpected layout '{output_layout}'")

python/tvm/topi/hexagon/utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name
18+
"""Common hexagon specific utilities"""
19+
from tvm import te
20+
21+
22+
def n11c_1024c_2d(n, h, w, c):
23+
"""Return index map for n11c_1024 2d layout"""
24+
return [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024]
25+
26+
27+
def n11c_1024c_1d(n, h, w, c):
28+
"""Return index map for n11c_1024 1d layout"""
29+
return [n, h, w, c // 1024, c % 1024]
30+
31+
32+
def nhwc_8h2w32c2w_2d(n, h, w, c):
33+
"""Return index map for nhwc_8h2w32c2w 2d layout"""
34+
return [n, h // 8, w // 4, c // 32, te.AXIS_SEPARATOR, h % 8, (w % 4) // 2, c % 32, w % 2]
35+
36+
37+
def nhwc_8h2w32c2w_1d(n, h, w, c):
38+
"""Return index map for nhwc_8h2w32c2w 1d layout"""
39+
return [n, h // 8, w // 4, c // 32, h % 8, (w % 4) // 2, c % 32, w % 2]
40+
41+
42+
def get_layout_transform_fn(layout):
43+
"""Return index map function as per the layout string"""
44+
if layout == "nhwc-8h2w32c2w-2d":
45+
return nhwc_8h2w32c2w_2d
46+
if layout == "nhwc-8h2w32c2w-1d":
47+
return nhwc_8h2w32c2w_1d
48+
if layout == "n11c-1024c-2d":
49+
return n11c_1024c_2d
50+
if layout == "n11c-1024c-1d":
51+
return n11c_1024c_1d
52+
raise RuntimeError(f"Unexpected layout '{layout}'")

tests/python/contrib/test_hexagon/infrastructure.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
# pylint: disable=invalid-name
1718

1819
""" Hexagon testing infrastructure """
1920

@@ -228,3 +229,22 @@ def compute(n, ho, wo, ko, hi, wi, ki):
228229
)
229230

230231
return output_shape, compute
232+
233+
234+
def transform_numpy(arr_np, current_layout: str, new_layout: str):
235+
"""Reshape and transpose numpy array according to the specified layout"""
236+
if current_layout == "nhwc":
237+
if new_layout == "nhwc":
238+
return arr_np
239+
if new_layout in ["nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-1d"]:
240+
n, h, w, c = arr_np.shape
241+
return arr_np.reshape([n, h // 8, 8, w // 4, 2, 2, c // 32, 32]).transpose(
242+
0, 1, 3, 6, 2, 4, 7, 5
243+
)
244+
if new_layout in ["n11c-1024c-2d", "n11c-1024c-1d"]:
245+
n, h, w, c = arr_np.shape
246+
assert h == 1 and w == 1, "The size of h and w must be 1"
247+
return arr_np.reshape([n, 1, 1, c // 1024, 1024])
248+
249+
raise RuntimeError(f"Unexpected new_layout '{new_layout}'")
250+
raise RuntimeError(f"Unexpected current_layout '{current_layout}'")

0 commit comments

Comments
 (0)