Skip to content

Commit ff3ca9e

Browse files
committed
[microNPU] enable USMP
This commit enables USMP in the microNPU codegen and tests. Change-Id: Iafd7db8cd678f2b3cca8c06e5ea30e79a570faf9
1 parent 220f665 commit ff3ca9e

File tree

9 files changed

+236
-107
lines changed

9 files changed

+236
-107
lines changed

include/tvm/tir/usmp/utils.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -251,24 +251,24 @@ struct AllocatedPoolInfoNode : public Object {
251251
PoolInfo pool_info;
252252
/*! \brief The allocated size into this pool */
253253
Integer allocated_size;
254-
/*! \brief An optional associated pool Var*/
255-
Optional<Var> pool_var;
254+
/*! \brief An optional associated pool Var index of PrimFunc params*/
255+
Optional<Integer> pool_var_idx;
256256

257257
void VisitAttrs(tvm::AttrVisitor* v) {
258258
v->Visit("pool_info", &pool_info);
259259
v->Visit("allocated_size", &allocated_size);
260-
v->Visit("pool_var", &pool_var);
260+
v->Visit("pool_var_idx", &pool_var_idx);
261261
}
262262

263263
bool SEqualReduce(const AllocatedPoolInfoNode* other, SEqualReducer equal) const {
264264
return equal(pool_info, other->pool_info) && equal(allocated_size, other->allocated_size) &&
265-
equal(pool_var, other->pool_var);
265+
equal(pool_var_idx, other->pool_var_idx);
266266
}
267267

268268
void SHashReduce(SHashReducer hash_reduce) const {
269269
hash_reduce(pool_info);
270270
hash_reduce(allocated_size);
271-
hash_reduce(pool_var);
271+
hash_reduce(pool_var_idx);
272272
}
273273

274274
static constexpr const char* _type_key = "tir.usmp.AllocatedPoolInfo";
@@ -277,7 +277,8 @@ struct AllocatedPoolInfoNode : public Object {
277277

278278
class AllocatedPoolInfo : public ObjectRef {
279279
public:
280-
TVM_DLL AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Var pool_var = Var());
280+
TVM_DLL AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size,
281+
Integer pool_var_idx = Integer());
281282
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AllocatedPoolInfo, ObjectRef, AllocatedPoolInfoNode);
282283
};
283284

python/tvm/micro/model_library_format.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -177,14 +177,13 @@ def _build_function_memory_map(function_metadata):
177177
"""
178178
device_max_workspace = dict()
179179
main_func_metadata = function_metadata[MAIN_FUNC_NAME_STR]
180-
num_targets = len(main_func_metadata.workspace_sizes.items())
180+
main_targets = dict(main_func_metadata.workspace_sizes).keys()
181181
from tvm.driver import tvmc # pylint: disable=import-outside-toplevel
182182

183183
external_codegens = tvmc.composite_target.get_codegen_names()
184184
func_entries = []
185185
target_local_entries = dict()
186-
for i in range(num_targets):
187-
main_target = main_func_metadata.workspace_sizes.items()[i][0]
186+
for main_target in main_targets:
188187
device_max_workspace[main_target] = 0
189188
for func_name, finfo in function_metadata.items():
190189
if func_name == MAIN_FUNC_NAME_STR:
@@ -197,22 +196,18 @@ def _build_function_memory_map(function_metadata):
197196
# 2. BYOC operator implementations do not currently export useful FunctionInfo.
198197
if func_name == MAIN_FUNC_NAME_STR or not finfo.tir_primfuncs:
199198
continue
200-
assert (
201-
len(finfo.constant_sizes.items()) == num_targets
202-
), f"{func_name}: found {finfo.constant_sizes!r} vs {num_targets}"
203-
assert len(finfo.io_sizes.items()) == num_targets
204-
target = finfo.workspace_sizes.items()[i][0]
205-
workspace_size = finfo.workspace_sizes.items()[i][1]
206-
target_entry = {
207-
"device": int(target.kind.device_type),
208-
"workspace_size_bytes": int(workspace_size),
209-
}
210-
target_local_entries[func_name].append(target_entry)
211-
if workspace_size > device_max_workspace.get(target, 0):
212-
device_max_workspace[target] = workspace_size
213-
# TODO(Mousius) - Remove this massive hack when Targets are unified
214-
if target.kind.name in external_codegens:
215-
device_max_workspace[main_target] += int(workspace_size)
199+
if main_target in finfo.workspace_sizes.keys():
200+
workspace_size = finfo.workspace_sizes[main_target]
201+
target_entry = {
202+
"device": int(main_target.kind.device_type),
203+
"workspace_size_bytes": int(workspace_size),
204+
}
205+
target_local_entries[func_name].append(target_entry)
206+
if workspace_size > device_max_workspace.get(main_target, 0):
207+
device_max_workspace[main_target] = workspace_size
208+
# TODO(Mousius) - Remove this massive hack when Targets are unified
209+
if main_target.kind.name in external_codegens:
210+
device_max_workspace[main_target] += int(workspace_size)
216211

217212
for func_name, target_entries_ in target_local_entries.items():
218213
func_entry = {
@@ -222,15 +217,22 @@ def _build_function_memory_map(function_metadata):
222217
func_entries.append(func_entry)
223218

224219
target_main_entries = list()
225-
for i in range(num_targets):
226-
target = main_func_metadata.workspace_sizes.items()[i][0]
227-
main_func_local_workspace = main_func_metadata.workspace_sizes.items()[i][1]
228-
main_func_constants = main_func_metadata.constant_sizes.items()[i][1]
229-
main_func_io = main_func_metadata.io_sizes.items()[i][1]
220+
for main_target in main_targets:
221+
main_func_local_workspace = main_func_metadata.workspace_sizes[main_target]
222+
main_func_constants = (
223+
main_func_metadata.constant_sizes[main_target]
224+
if main_target in main_func_metadata.constant_sizes.keys()
225+
else 0
226+
)
227+
main_func_io = (
228+
main_func_metadata.io_sizes[main_target]
229+
if main_target in main_func_metadata.io_sizes.keys()
230+
else 0
231+
)
230232
target_main_entries.append(
231233
{
232-
"device": int(target.kind.device_type),
233-
"workspace_size_bytes": int(device_max_workspace[target])
234+
"device": int(main_target.kind.device_type),
235+
"workspace_size_bytes": int(device_max_workspace[main_target])
234236
+ int(main_func_local_workspace),
235237
"constants_size_bytes": int(main_func_constants),
236238
"io_size_bytes": int(main_func_io),

0 commit comments

Comments
 (0)