Skip to content

Commit 75ec1cf

Browse files
authored
[TVMC] Workspace Pools Parameters (#11427)
* [TVMC] Workspace Pools Parameters Attributes from tvmc are now passable into the created PoolInfo objects inside WorkspaceMemoryPools. This is passed in to relay.build that get attached to IRModule attribute. * [TVMC] Workspace Pools Parameters Address comments, fix linting. Testing improved. Change-Id: Iea79329b6b9ec1cbc51e5c293449bf6dd43b00c5 * [TVMC] Workspace Pools Parameters Update workspace pools test naming Change-Id: Ib698d6248be1e6f44340f27db3641c985bc5c5d8 * [TVMC] Workspace Pools Parameters Add test for parameter overrides. Change-Id: I67d5470dcfbfbc9ab27f34e20a9269d2070193ca * [TVMC] Workspace Pools Parameters Rebasing over #10189 Updates to the way a WorkspaceMemoryPool object is created Change-Id: I1f0e1d240343af311ddb3ed5c564cc1ab329f463 * [TVMC] Workspace Pools Parameters Fix linting, fix CI Change-Id: If75f8709ac4ad925655eca54b3e5c1bb09d025e8 * [TVMC] Workspace Pools Parameters Add mcpu and mattr to target registry for cmsis-nn Change-Id: I15257b8d01624c071c738cab6d12ecb84ed6cb16 * [TVMC] Workspace Pools Parameters Added test for override on single pool when multiple pools are present Updated functionality of parsing multiple attributes Change-Id: I2c0745051b7a923dd7f75040bfb89bbc99376a11
1 parent 6eb3a1f commit 75ec1cf

File tree

8 files changed

+717
-5
lines changed

8 files changed

+717
-5
lines changed

include/tvm/ir/memory_pools.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ struct PoolInfoNode : public Object {
6565

6666
void VisitAttrs(tvm::AttrVisitor* v) {
6767
v->Visit("pool_name", &pool_name);
68+
v->Visit("targets", &targets);
6869
v->Visit("size_hint_bytes", &size_hint_bytes);
6970
v->Visit("clock_frequency_hz", &clock_frequency_hz);
7071
v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle);

python/tvm/driver/tvmc/compiler.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tvm import autotvm, auto_scheduler
2727
from tvm import relay
2828
from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity
29+
from tvm.ir.memory_pools import WorkspaceMemoryPools
2930
from tvm.target import Target
3031
from tvm.relay.backend import Executor, Runtime
3132

@@ -37,6 +38,7 @@
3738
from .pass_list import parse_pass_list_str
3839
from .transform import convert_graph_layout
3940
from .shape_parser import parse_shape_string
41+
from .workspace_pools import generate_workspace_pools_args, workspace_pools_recombobulate
4042

4143
# pylint: disable=invalid-name
4244
logger = logging.getLogger("TVMC")
@@ -142,10 +144,11 @@ def add_compile_parser(subparsers, _, json_params):
142144
default="default",
143145
help="The output module name. Defaults to 'default'.",
144146
)
145-
146147
for one_entry in json_params:
147148
parser.set_defaults(**one_entry)
148149

150+
generate_workspace_pools_args(parser)
151+
149152

150153
def drive_compile(args):
151154
"""Invoke tvmc.compiler module with command line arguments
@@ -161,6 +164,7 @@ def drive_compile(args):
161164
Zero if successfully completed
162165
163166
"""
167+
164168
if not os.path.isfile(args.FILE):
165169
raise TVMCException(
166170
f"Input file '{args.FILE}' doesn't exist, is a broken symbolic link, or a directory."
@@ -170,6 +174,9 @@ def drive_compile(args):
170174

171175
dump_code = [x.strip() for x in args.dump_code.split(",")] if args.dump_code else None
172176

177+
additional_targets = reconstruct_target_args(args)
178+
workspace_pools_target, extra_targets = target_from_cli(args.target, additional_targets)
179+
173180
compile_model(
174181
tvmc_model,
175182
args.target,
@@ -186,8 +193,11 @@ def drive_compile(args):
186193
desired_layout=args.desired_layout,
187194
disabled_pass=args.disabled_pass,
188195
pass_context_configs=args.pass_config,
189-
additional_target_options=reconstruct_target_args(args),
190196
mod_name=args.module_name,
197+
additional_target_options=additional_targets,
198+
workspace_pools=(
199+
workspace_pools_recombobulate(args, [workspace_pools_target], extra_targets)
200+
),
191201
)
192202

193203
return 0
@@ -212,6 +222,7 @@ def compile_model(
212222
additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None,
213223
use_vm: bool = False,
214224
mod_name: Optional[str] = "default",
225+
workspace_pools: Optional[WorkspaceMemoryPools] = None,
215226
):
216227
"""Compile a model from a supported framework into a TVM module.
217228
@@ -263,6 +274,9 @@ def compile_model(
263274
Whether to use the VM to compile the model as opposed to the graph executor
264275
mod_name: str, optional
265276
The module name
277+
workspace_pools: WorkspaceMemoryPools, optional
278+
Specification of WorkspacePoolInfo objects to be used as workspace memory in the
279+
compilation.
266280
267281
Returns
268282
-------
@@ -313,6 +327,7 @@ def compile_model(
313327
params=params,
314328
use_vm=use_vm,
315329
mod_name=mod_name,
330+
workspace_pools=workspace_pools,
316331
)
317332
else:
318333
with autotvm.apply_history_best(tuning_records):
@@ -328,6 +343,7 @@ def compile_model(
328343
params=params,
329344
use_vm=use_vm,
330345
mod_name=mod_name,
346+
workspace_pools=workspace_pools,
331347
)
332348
else:
333349
with tvm.transform.PassContext(
@@ -342,6 +358,7 @@ def compile_model(
342358
params=params,
343359
use_vm=use_vm,
344360
mod_name=mod_name,
361+
workspace_pools=workspace_pools,
345362
)
346363

347364
# Generate output dump files with sources
@@ -380,6 +397,7 @@ def build(
380397
params: Dict[str, tvm.nd.NDArray],
381398
use_vm: bool,
382399
mod_name: str,
400+
workspace_pools: Optional[WorkspaceMemoryPools],
383401
):
384402
"""
385403
Builds the model with the provided executor.
@@ -408,7 +426,13 @@ def build(
408426
return relay.vm.compile(mod, target=tvm_target, params=params)
409427
logger.debug("building with relay build")
410428
return relay.build(
411-
mod, target=tvm_target, executor=executor, runtime=runtime, params=params, mod_name=mod_name
429+
mod,
430+
target=tvm_target,
431+
executor=executor,
432+
runtime=runtime,
433+
params=params,
434+
mod_name=mod_name,
435+
workspace_memory_pools=workspace_pools,
412436
)
413437

414438

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""
18+
Functions for processing dynamic workspace pool TVMC args
19+
"""
20+
21+
22+
import logging
23+
import re
24+
25+
from tvm.driver.tvmc import TVMCException
26+
from tvm.target import Target
27+
from tvm.ir.memory_pools import PoolInfoProperties, WorkspaceMemoryPools, WorkspacePoolInfo
28+
29+
30+
# pylint: disable=invalid-name
31+
logger = logging.getLogger("TVMC")
32+
33+
34+
def generate_workspace_pools_args(parser):
35+
"""Generates arguments for each Workspace Pools's options"""
36+
parser.add_argument(
37+
"--workspace-pools",
38+
help="""The name of the memory pool
39+
Example usage: --workspace-pools=flash""",
40+
)
41+
parser.add_argument(
42+
"--workspace-pools-targets",
43+
help="""The name of the targets specified for the memory pool
44+
Example usage: --workspace-pools-targets=flash:llvm""",
45+
action="append",
46+
)
47+
parser.add_argument(
48+
"--workspace-pools-size-hint-bytes",
49+
nargs="?",
50+
help="""The expected size hint to be used by the allocator.
51+
Example usage: --workspace-pools-size-hint-bytes=flash:8""",
52+
action="append",
53+
)
54+
parser.add_argument(
55+
"--workspace-pools-clock-frequency-hz",
56+
nargs="?",
57+
help="""The clock frequency that the memory pool runs at in Hz.
58+
Example usage: --workspace-pools-clock-frequency-hz=flash:70000000""",
59+
action="append",
60+
)
61+
parser.add_argument(
62+
"--workspace-pools-read-bandwidth-bytes-per-cycle",
63+
nargs="?",
64+
help="""The read bandwidth of the memory pool in bytes/cycle.
65+
Example usage: --workspace-pools-read-bandwidth-bytes-per-cycle=flash:4""",
66+
action="append",
67+
)
68+
parser.add_argument(
69+
"--workspace-pools-write-bandwidth-bytes-per-cycle",
70+
nargs="?",
71+
help="""The write bandwidth of the memory pool in bytes/cycle.
72+
Example usage: --workspace-pools-write-bandwidth-bytes-per-cycle=flash:8""",
73+
action="append",
74+
)
75+
parser.add_argument(
76+
"--workspace-pools-read-latency-cycles",
77+
nargs="?",
78+
help="""The read latency of the memory pool in cycles.
79+
Example usage: --workspace-pools-read-latency-cycles=flash:4""",
80+
action="append",
81+
)
82+
parser.add_argument(
83+
"--workspace-pools-write-latency-cycles",
84+
nargs="?",
85+
help="""The write latency of the memory pool in cycles.
86+
Example usage: --workspace-pools-write-latency-cycles=flash:8""",
87+
action="append",
88+
)
89+
parser.add_argument(
90+
"--workspace-pools-target-burst-bytes",
91+
help="""The burst length of the memory pool in bytes per target.
92+
Example usage: --workspace-pools-target-burst-bytes=flash:accel:1""",
93+
action="append",
94+
)
95+
96+
97+
def _parse_target_burst(attr_str, pool_name):
98+
if pool_name not in attr_str:
99+
return {}
100+
101+
return {target: int(attr_str[pool_name][target]) for target in attr_str[pool_name]}
102+
103+
104+
def _parse_target_string(attr_str, targets, pool_name):
105+
if attr_str is None:
106+
raise TVMCException(f'No target specified for Workspace Pool "{pool_name}"')
107+
108+
target_name = [re.split(",", attr_str)]
109+
matched_targets = [
110+
target
111+
for target in targets
112+
if any(target.kind.name in target_string_match for target_string_match in target_name[0])
113+
]
114+
if not matched_targets:
115+
raise TVMCException(f'Workspace Pool "{pool_name}" using undefined Target "{target_name}"')
116+
return matched_targets
117+
118+
119+
def _split_pools_to_pool_names(attr_str):
120+
return re.split(",", attr_str) if attr_str else []
121+
122+
123+
def _parse_target_attributes_of_pool_name(attr_str, targets):
124+
if not targets or attr_str is None:
125+
return {}
126+
127+
target_attributes = {}
128+
for pool_values in attr_str:
129+
pool_name, target_name, target_value = re.split(":", pool_values)
130+
if pool_name not in target_attributes:
131+
target_attributes[pool_name] = {}
132+
133+
matched_targets = [target for target in targets if target_name == target.kind.name]
134+
if matched_targets:
135+
target_attributes[pool_name][matched_targets[0]] = target_value
136+
else:
137+
raise TVMCException(
138+
"The workspace pool target specification "
139+
"needs to contain a subset of the same TVM "
140+
"targets as when specifying targets to use."
141+
)
142+
return target_attributes
143+
144+
145+
def _parse_attribute_of_pool_name(attr_str):
146+
return dict(pool.split(":", maxsplit=1) for pool in attr_str) if attr_str else {}
147+
148+
149+
def workspace_pools_recombobulate(parsed, targets, extra_target):
150+
"""Reconstructs the Workspace Pools args and returns a WorkspaceMemoryPool object"""
151+
WORKSPACE_POOL_PARAMS = [
152+
"workspace_pools_size_hint_bytes",
153+
"workspace_pools_targets",
154+
"workspace_pools_clock_frequency_hz",
155+
"workspace_pools_read_bandwidth_bytes_per_cycle",
156+
"workspace_pools_write_bandwidth_bytes_per_cycle",
157+
"workspace_pools_read_latency_cycles",
158+
"workspace_pools_write_latency_cycles",
159+
]
160+
WORKSPACE_POOL_TARGET_PARAMS = [
161+
"workspace_pools_target_burst_bytes",
162+
]
163+
164+
# Load extra targets from CLI
165+
additional_targets = []
166+
167+
for t in extra_target:
168+
additional_targets.append(Target(t["raw"], host=targets[0].host or targets[0]))
169+
170+
target = targets + additional_targets
171+
if targets[0].host:
172+
target.append(targets[0].host)
173+
174+
workspace_pools = _split_pools_to_pool_names(parsed.workspace_pools)
175+
if not workspace_pools:
176+
return None
177+
178+
parse_attribute_to_pool_name = {
179+
workspace_pool_param: _parse_attribute_of_pool_name(getattr(parsed, workspace_pool_param))
180+
for workspace_pool_param in WORKSPACE_POOL_PARAMS
181+
}
182+
parse_target_burst_bytes_to_pool = {
183+
workspace_pool_param: _parse_target_attributes_of_pool_name(
184+
getattr(parsed, workspace_pool_param), targets
185+
)
186+
for workspace_pool_param in WORKSPACE_POOL_TARGET_PARAMS
187+
}
188+
189+
return WorkspaceMemoryPools(
190+
[
191+
WorkspacePoolInfo(
192+
pool_name,
193+
targets=_parse_target_string(
194+
parse_attribute_to_pool_name["workspace_pools_targets"].get(pool_name),
195+
target,
196+
pool_name,
197+
),
198+
pool_info_properties=PoolInfoProperties(
199+
size_hint_bytes=int(
200+
parse_attribute_to_pool_name["workspace_pools_size_hint_bytes"].get(
201+
pool_name, -1
202+
)
203+
),
204+
clock_frequency_hz=int(
205+
parse_attribute_to_pool_name["workspace_pools_clock_frequency_hz"].get(
206+
pool_name, -1
207+
)
208+
),
209+
read_bandwidth_bytes_per_cycle=int(
210+
parse_attribute_to_pool_name[
211+
"workspace_pools_read_bandwidth_bytes_per_cycle"
212+
].get(pool_name, -1)
213+
),
214+
write_bandwidth_bytes_per_cycle=int(
215+
parse_attribute_to_pool_name[
216+
"workspace_pools_write_bandwidth_bytes_per_cycle"
217+
].get(pool_name, -1)
218+
),
219+
read_latency_cycles=int(
220+
parse_attribute_to_pool_name["workspace_pools_read_latency_cycles"].get(
221+
pool_name, 0
222+
)
223+
),
224+
write_latency_cycles=int(
225+
parse_attribute_to_pool_name["workspace_pools_write_latency_cycles"].get(
226+
pool_name, 0
227+
)
228+
),
229+
target_burst_bytes=_parse_target_burst(
230+
parse_target_burst_bytes_to_pool["workspace_pools_target_burst_bytes"],
231+
pool_name,
232+
),
233+
),
234+
)
235+
for pool_name in workspace_pools
236+
]
237+
)

python/tvm/ir/memory_pools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ class WorkspaceMemoryPools(Object):
189189

190190
def __init__(
191191
self,
192-
pools: List[PoolInfo],
192+
pools: List[WorkspacePoolInfo],
193193
):
194194
self.__init_handle_by_constructor__(
195195
_ffi_api.WorkspaceMemoryPools, pools # type: ignore # pylint: disable=no-member

src/relay/backend/contrib/cmsisnn/target.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ runtime::Module TIRToRuntime(IRModule mod, Target target);
3232

3333
TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU)
3434
.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
35-
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime);
35+
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime)
36+
.add_attr_option<Array<String>>("mattr")
37+
.add_attr_option<String>("mcpu");
3638

3739
} // namespace cmsisnn
3840
} // namespace contrib

0 commit comments

Comments
 (0)