Skip to content

Commit 92c44e2

Browse files
committed
[microNPU][5] Convert Proposals to te.Schedules
Change-Id: I6771578f1007b8fea02e2dec7d0c797a6ef6aa5e
1 parent cb7f773 commit 92c44e2

File tree

3 files changed

+275
-0
lines changed

3 files changed

+275
-0
lines changed

python/tvm/contrib/ethosu/cascader/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,5 @@
3636
from .device_config import EthosuDeviceConfig
3737
from .tensor_config import TensorConfigState, MemoryRegion, TensorConfig
3838
from .plan import Plan
39+
from .scheduler import apply_proposal, cascade
3940
from .cascader_options import CascaderOptions
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Scheduler for cascader which converts Proposals into Schedules."""
18+
from typing import Tuple, List, Dict, DefaultDict
19+
from collections import defaultdict
20+
import numpy as np
21+
22+
from tvm import te
23+
from tvm import tir
24+
from .cascader_options import CascaderOptions
25+
from .graph import CascaderGraph, Part, Tensor, TESubgraph
26+
from .tensor_config import MemoryRegion
27+
from .proposal import Proposal
28+
from .proposal_generator import generate_proposals
29+
from .graph import create_cascader_graph
30+
from .device_config import EthosuDeviceConfig
31+
32+
33+
def tile_nd(
    sch: te.Schedule, tensor: te.Tensor, tile: Tuple[int, ...]
) -> Tuple[List[tir.IterVar], List[tir.IterVar]]:
    """Scheduling utility to perform N-dimensional tiling.

    Parameters
    ----------
    sch : te.Schedule
        The schedule to apply the tiling to.
    tensor : te.Tensor
        The tensor to apply the tiling to.
    tile : Tuple[int, ...]
        The N-dimensional tile size.

    Returns
    -------
    outer_indices : List[tir.IterVar]
        The outer iteration variables.
    inner_indices : List[tir.IterVar]
        The inner iteration variables.

    """
    stage = sch[tensor]
    # Split the i-th axis by the i-th tile factor, collecting the resulting
    # (outer, inner) iteration variable pairs.
    splits = [stage.split(tensor.op.axis[i], factor) for i, factor in enumerate(tile)]
    outer_indices = [outer for outer, _ in splits]
    inner_indices = [inner for _, inner in splits]
    # All outer loops first, then all inner loops, giving the tiled nest.
    stage.reorder(*outer_indices, *inner_indices)
    return outer_indices, inner_indices
64+
65+
66+
def stripe_part(
    part: Part, stripe_shape: Tuple[int, ...], sch: te.Schedule
) -> Tuple[te.Stage, tir.IterVar]:
    """Apply a striping schedule to the TE subgraph represented by a Part."""
    subgraph = part.subgraph
    output = subgraph.output_tensor
    outer_indices, _ = tile_nd(sch, output, stripe_shape)
    output_stage = sch[output]
    innermost_outer = outer_indices[-1]
    # Group everything that feeds the output tensor (excluding the subgraph's
    # own inputs) so the whole group is computed per-stripe under the
    # innermost outer loop.
    group = sch.create_group(
        outputs=output.op.input_tensors,
        inputs=subgraph.input_tensors,
        include_inputs=False,
    )
    group.compute_at(output_stage, innermost_outer)
    for outer_axis in outer_indices:
        output_stage.unroll(outer_axis)

    return output_stage, innermost_outer
83+
84+
85+
def cascade_part(
    part: Part, stripe_stage: te.Stage, stripe_axis: tir.IterVar, sch: te.Schedule
) -> None:
    """Schedule a Part into a cascade indicated by a stripe Stage."""
    subgraph = part.subgraph
    # Group the Part's whole TE subgraph and compute it at the stripe axis so
    # it executes once per stripe of the cascade.
    group = sch.create_group(
        outputs=subgraph.output_tensor,
        inputs=subgraph.input_tensors,
        include_inputs=False,
    )
    group.compute_at(stripe_stage, stripe_axis)
94+
95+
96+
def update_readers(part: "Part", readers: "DefaultDict[te.Tensor, List[te.Tensor]]") -> None:
    """Update a map from te.Tensors to the te.Tensors that read them.

    Walks the Part's TE subgraph backwards from its output tensor, recording
    for every tensor which tensors consume it. The traversal stops at the
    subgraph's input tensors. Note that *readers* is mutated in-place.

    Parameters
    ----------
    part : Part
        The Part whose TE subgraph should be traversed.
    readers : DefaultDict[te.Tensor, List[te.Tensor]]
        The reader map to update; maps a tensor to the list of tensors that
        read it.

    """
    visited = set()

    def _visit(tensor):
        # Bug fix: the original condition was 'tensor is not visited' — an
        # identity comparison against the set object itself, which is always
        # True. That made the visited-set useless, so tensors reachable via
        # multiple paths were re-visited and their inputs received duplicate
        # reader entries. Use membership instead.
        if tensor not in visited and tensor not in part.subgraph.input_tensors:
            visited.add(tensor)
            for input_tensor in tensor.op.input_tensors:
                readers[input_tensor].append(tensor)
                _visit(input_tensor)

    _visit(part.subgraph.output_tensor)
108+
109+
110+
def apply_proposal(proposal: Proposal, sch: te.Schedule) -> None:
    """Apply a Proposal to a Schedule, converting all the Plans into TE scheduling instructions.

    Note that the Schedule is mutated in-place.

    Parameters
    ----------
    proposal : Proposal
        The Proposal to apply to the Schedule.
    sch : te.Schedule
        The Schedule to apply the Proposal to.

    """
    for plan in proposal.plans:
        output_tensor_config = plan.output_config
        output_tensor = output_tensor_config.tensor
        output_part = output_tensor.producers[0]
        # Parts marked as in-line require no explicit scheduling of their own.
        if output_part.in_line:
            continue
        # The first stripe config of the output determines the stripe shape
        # used to tile the Plan's output Part.
        stripe_config = output_tensor_config.stripe_configs[0]
        stripe_shape = [int(x) for x in stripe_config.shape]
        stripe_stage, stripe_axis = stripe_part(output_part, stripe_shape, sch)
        copy_te_tensors = []
        readers = defaultdict(list)
        for part in plan.part_group:
            # Every Part except the striped output Part is computed within
            # the cascade created by stripe_part.
            if part != output_part:
                cascade_part(part, stripe_stage, stripe_axis, sch)

            update_readers(part, readers)
            for i, input_tensor in enumerate(part.input_tensors):
                tensor_config = plan.tensor_configs[input_tensor]
                # A home region differing from the copy region means the
                # tensor must be copied into the copy region before use.
                if tensor_config.home_region != tensor_config.copy_region:
                    copy_te_tensors.append(part.subgraph.input_tensors[i])

        # Materialize the copies as cache_read stages computed per-stripe at
        # the cascade's stripe axis.
        for te_tensor in copy_te_tensors:
            copy_stage = sch.cache_read(te_tensor, "global", readers[te_tensor])
            sch[copy_stage].compute_at(stripe_stage, stripe_axis)
147+
148+
149+
def create_home_map(
    graph: CascaderGraph,
    io_region: MemoryRegion,
    constant_region: MemoryRegion,
    working_regions: List[MemoryRegion],
) -> Dict[Tensor, List[MemoryRegion]]:
    """Create a map between Tensors and the MemoryRegions they can be homed in."""

    def _candidate_homes(tensor):
        # Constants live in the constant region, graph boundary tensors in
        # the I/O region, and all other (intermediate) tensors may be homed
        # in any of the working regions.
        if tensor.is_constant:
            return [constant_region]
        if tensor in graph.input_tensors or tensor in graph.output_tensors:
            return [io_region]
        return working_regions

    return {tensor: _candidate_homes(tensor) for tensor in graph.tensor_order}
166+
167+
168+
def choose_proposal(proposals: List[Proposal], cascade_region: MemoryRegion):
    """Choose the best performing Proposal that doesn't overflow the cascade region."""
    # Scan from the back of the list for the first Proposal whose memory
    # usage fits within the cascade region (presumably later proposals
    # perform better — this preserves the original search order). If none
    # fits, fall back to the first Proposal.
    for candidate in reversed(proposals):
        if candidate.memory_usage < cascade_region.size:
            return candidate
    return proposals[0]
177+
178+
179+
def cascade(
    sch: te.Schedule,
    te_graph: TESubgraph,
    const_dict: Dict[int, np.ndarray],
    options: CascaderOptions,
    io_region: MemoryRegion,
    constant_region: MemoryRegion,
    working_regions: List[MemoryRegion],
    device_config: EthosuDeviceConfig,
) -> None:
    """Schedule a Tensor Expression graph using the technique of 'cascading'.

    'Cascading' is a technique whereby operations are split into smaller
    dependent tiles ('stripes') which can then execute in an interleaved
    fashion. This allows for operations to execute together rather than
    sequentially which can reduce intermediate memory requirements and in
    certain cases improve performance.

    For more detail on 'cascading' as well as how it is implemented, refer to
    the RFC here: https://github.com/apache/tvm-rfcs/pull/37.

    Parameters
    ----------
    sch : te.Schedule
        The Schedule to apply the cascading to.
    te_graph : TESubgraph
        The Tensor Expression graph from which the Schedule was created.
    const_dict : Dict[int, np.ndarray]
        A dictionary mapping input index to constant data if that input is
        to be a constant.
    options : CascaderOptions
        Configuration options for the cascading scheduler.
    io_region : MemoryRegion
        The MemoryRegion in which input/output tensors should reside.
    constant_region : MemoryRegion
        The MemoryRegion in which constants should reside.
    working_regions : List[MemoryRegion]
        The MemoryRegions in which intermediate working tensors can reside.
        The cascading scheduler will select which MemoryRegion to use per
        tensor.
    device_config : EthosuDeviceConfig
        Target device configuration.

    Raises
    ------
    ValueError
        If options.cascade_region is not one of the working_regions.

    """
    # Validate explicitly rather than with 'assert' so the check survives
    # running under 'python -O' (asserts are stripped by optimization).
    if options.cascade_region not in working_regions:
        raise ValueError("options.cascade_region must be one of the working_regions")
    # First convert the Tensor Expression graph into a CascaderGraph
    casc_graph = create_cascader_graph(te_graph, const_dict, device_config)
    # Then create a mapping between Tensors and their possible memory homes
    home_map = create_home_map(casc_graph, io_region, constant_region, working_regions)
    # Generate Proposals for Pareto-optimal ways to cascade the CascaderGraph
    proposals = generate_proposals(casc_graph, home_map, options)
    # Select the best Proposal subject to the memory constraints
    proposal_choice = choose_proposal(proposals, options.cascade_region)
    # Apply the selected Proposal to the Tensor Expression Schedule
    apply_proposal(proposal_choice, sch)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import pytest
18+
19+
import tvm.contrib.ethosu.cascader as cs
20+
21+
22+
def test_cascade(SRAM, FLASH, TwoConv2DWithSliceTE, TwoConv2DTE, MobileNetv1StartTE, MobileNetv1TE):
    """Run the cascader end-to-end over a selection of TE graph fixtures."""
    test_graphs = (
        TwoConv2DTE,
        TwoConv2DWithSliceTE,
        MobileNetv1StartTE,
        MobileNetv1TE,
    )
    device_config = cs.EthosuDeviceConfig("ethos-u55-256")
    for schedule, graph, constants in test_graphs:
        cascader_options = cs.CascaderOptions(
            cascade_region=SRAM,
            max_proposals=64,
            stripe_factors=4,
            max_plan_size=10,
            always_copy_size=1024,
        )
        cs.cascade(schedule, graph, constants, cascader_options, SRAM, FLASH, [SRAM], device_config)
39+
40+
41+
if __name__ == "__main__":
    # Allow running this test file directly (outside a pytest invocation).
    pytest.main([__file__])

0 commit comments

Comments
 (0)