Skip to content

Commit 35ddc1a

Browse files
committed
[Ethos-U] Replace ethos-u.constants with AllocateConst
Previously, constants for ethos-u were tracked using a function attribute `"ethos-u.constants"`. This predates the introduction of `AllocateConst`, and had comments indicating that it should be replaced with `AllocateConst` when possible. To minimize impact to existing passes, this commit preserves the `"ethos-u.constants"` attribute during ethosu-specific lowering passes. The attribute is converted to `AllocateConst` at the end of the `lower_ethosu` pass, just prior to lowering with the usual TIR passes.
1 parent 3c23865 commit 35ddc1a

File tree

19 files changed

+117
-76
lines changed

19 files changed

+117
-76
lines changed

python/tvm/relay/backend/contrib/ethosu/codegen.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717
"""Codegen for Arm(R) Ethos(TM)-U NPU"""
1818
from collections import defaultdict
19-
from typing import List, Callable
19+
from typing import List, Callable, Dict
2020

2121
from ethosu.vela import api as vapi
2222
import tvm
@@ -532,6 +532,34 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
532532
return tir_mod
533533

534534

535+
def collect_consts(mod: tvm.IRModule) -> Dict[tvm.tir.Var, tvm.nd.NDArray]:
536+
"""Collect any AllocateCont
537+
538+
Parameters
539+
----------
540+
mod: tvm.IRModule
541+
542+
The module to inspect.
543+
544+
Returns
545+
-------
546+
const_dict: Dict[tvm.tir.Var, tvm.nd.NDArray]
547+
548+
A map from buffer var to NDArray, from AllocateConst nodes in
549+
the module
550+
"""
551+
constants = {}
552+
553+
def _visit(stmt):
554+
if isinstance(stmt, tvm.tir.AllocateConst):
555+
constants[stmt.buffer_var] = stmt.data
556+
557+
for func in mod.functions.values():
558+
tvm.tir.stmt_functor.post_order_visit(func.body, _visit)
559+
560+
return constants
561+
562+
535563
@tvm._ffi.register_func("relay.ext.ethos-u.primfunc_to_artifact")
536564
def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact:
537565
"""
@@ -551,13 +579,12 @@ def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact
551579
for the microNPU
552580
"""
553581
symbol = str(primfunc.attrs["global_symbol"])
554-
const_dict = primfunc.attrs["ethos-u.constants"]
555582
tir_mod = tvm.IRModule()
556583
tir_mod[symbol] = primfunc
557584

558-
const_dict_np = dict()
559-
for buffer_var in const_dict.keys():
560-
const_dict_np[buffer_var] = const_dict[buffer_var].numpy()
585+
const_dict_np = {
586+
buffer_var: ndarray.numpy() for buffer_var, ndarray in collect_consts(tir_mod).items()
587+
}
561588

562589
cmms, encoded_constants, base_addresses = tir_to_cs_translator.translate(tir_mod, const_dict_np)
563590
return util.CompilationArtifact(symbol, cmms, encoded_constants, base_addresses)

python/tvm/relay/backend/contrib/ethosu/tir/compiler.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,6 @@ def lower_ethosu(sch, args, const_dict, name="main"):
4949
-------
5050
mod : tvm.IRModule
5151
The lowered TIR module.
52-
const_dict : dict of int to numpy.ndarray
53-
The modified constant dictionary.
54-
5552
"""
5653
if not isinstance(args, list):
5754
args = list(args.inputs) + list(args.outputs)
@@ -101,8 +98,8 @@ def lower_ethosu(sch, args, const_dict, name="main"):
10198

10299
mod = tvm.tir.transform.RemoveNoOp()(mod)
103100
mod = ethosu_passes.AnnotateAllocates()(mod)
104-
mod, const_dict = ethosu_passes.CreatePrimFuncWithoutConstants(const_dict)(mod)
105-
return mod, const_dict
101+
mod = ethosu_passes.CreatePrimFuncWithoutConstants(const_dict)(mod)
102+
return mod
106103

107104

108105
def lower_to_te(prim_func):
@@ -200,15 +197,11 @@ def __init__(self, scheduler):
200197
def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
201198
"""Lower NPU functions to TIR."""
202199

203-
tir_mod, const_dict = _lower_to_tir(func, self.scheduler)
204-
205-
for param in const_dict.keys():
206-
const_dict[param] = tvm.nd.array(const_dict[param])
200+
tir_mod = _lower_to_tir(func, self.scheduler)
207201

208202
compiler_name = "ethos-u"
209203
primfunc = tir_mod["main"]
210204
primfunc = primfunc.with_attr("global_symbol", func.attrs["global_symbol"])
211-
primfunc = primfunc.with_attr("ethos-u.constants", const_dict)
212205
primfunc = primfunc.with_attr("target", tvm.target.Target(compiler_name))
213206
return primfunc
214207

@@ -233,14 +226,11 @@ def _lower_to_tir(func, cascader=None):
233226
-------
234227
mod : tvm.IRModule
235228
The lowered TIR module.
236-
consts : dict of int to numpy.ndarray
237-
A dict of the extracted constants keyed by their param index.
238-
239229
"""
240230
func, consts = extract_constants(func)
241231
mod = tvm.IRModule.from_expr(func)
242232
func = relay.transform.InferType()(mod)["main"]
243233
cached_func = lower_to_te(func)
244234
s = schedule(cached_func, consts, cascader)
245-
mod, consts = lower_ethosu(s, cached_func, consts)
246-
return mod, consts
235+
mod = lower_ethosu(s, cached_func, consts)
236+
return mod

python/tvm/relay/backend/contrib/ethosu/tir/passes.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -911,31 +911,35 @@ def _ftransform(f, mod, ctx):
911911
def CreatePrimFuncWithoutConstants(const_dict):
912912
"""
913913
This pass will remove arguments that are constants
914-
from PrimFunc Args. These should be replaced properly
915-
with tir.allocate_const when it becomes available.
914+
from PrimFunc Args, replacing them with tir.allocate_const.
916915
917916
It also modifies the constant dictionary to
918917
rewrite the keys as the actual tir.Vars that are params
919918
rather than the index because this pass removes PrimFunc
920919
arguments that represent constants.
921920
"""
922921

923-
new_const_dict = dict()
924-
925922
def _ftransform(f, mod, ctx):
926923
new_params = list()
927924
new_buffer_map = dict()
928-
for param_idx in const_dict.keys():
929-
# We are using buffer_var to key the constants as
930-
# PrimFunc params of constants will be removed.
931-
new_const_dict[f.buffer_map[f.params[param_idx]].data] = const_dict[param_idx]
925+
926+
new_body = f.body
927+
932928
for i, param in enumerate(f.params):
933-
if i not in const_dict.keys():
929+
if i in const_dict:
930+
const_np = const_dict[i]
931+
const_ndarray = tvm.nd.array(const_np, device=tvm.cpu())
932+
buf = f.buffer_map[param]
933+
new_body = tvm.tir.AllocateConst(
934+
buf.data, buf.dtype, buf.shape, const_ndarray, new_body
935+
)
936+
else:
934937
new_params.append(param)
935938
new_buffer_map[param] = f.buffer_map[param]
939+
936940
return tvm.tir.PrimFunc(
937941
new_params,
938-
f.body,
942+
new_body,
939943
f.ret_type,
940944
new_buffer_map,
941945
f.attrs,
@@ -947,7 +951,7 @@ def _create_primfunc_without_constants(mod):
947951
_ftransform, opt_level=0, name="tir.contrib.ethos-u.CreatePrimFuncWithoutConstants"
948952
)
949953
mod = transform_func(mod)
950-
return mod, new_const_dict
954+
return mod
951955

952956
return _create_primfunc_without_constants
953957

python/tvm/relay/op/contrib/ethosu.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2347,12 +2347,21 @@ def partition_for_ethosu(
23472347
mod["main"] = bind_params_by_name(mod["main"], params)
23482348

23492349
pattern = relay.op.contrib.get_pattern_table("ethos-u")
2350-
mod = relay.transform.InferType()(mod)
2351-
mod = relay.transform.MergeComposite(pattern)(mod)
2352-
mod = relay.transform.AnnotateTarget("ethos-u")(mod)
2353-
mod = relay.transform.MergeCompilerRegions()(mod)
2354-
mod = relay.transform.InferType()(mod)
2355-
mod = relay.transform.PartitionGraph(mod_name)(mod)
2356-
mod = relay.transform.InferType()(mod)
2357-
mod = preprocess.preprocess_ext_io()(mod)
2350+
2351+
seq = tvm.ir.transform.Sequential(
2352+
[
2353+
relay.transform.InferType(),
2354+
relay.transform.MergeComposite(pattern),
2355+
relay.transform.AnnotateTarget("ethos-u"),
2356+
relay.transform.MergeCompilerRegions(),
2357+
relay.transform.InferType(),
2358+
relay.transform.PartitionGraph(mod_name),
2359+
relay.transform.InferType(),
2360+
preprocess.preprocess_ext_io(),
2361+
],
2362+
name="partition_for_ethosu",
2363+
)
2364+
mod = seq(mod)
2365+
2366+
mod.show(name="end_of_partition_for_ethosu")
23582367
return mod

src/relay/backend/contrib/ethosu/codegen.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,9 +305,6 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
305305
Array<CompilationArtifact> compile_artifacts;
306306
for (const auto& kv : mod->functions) {
307307
const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(kv.second);
308-
Optional<Map<Integer, runtime::NDArray>> params =
309-
prim_func->GetAttr<Map<Integer, runtime::NDArray>>("ethos-u.constants");
310-
ICHECK(params) << "microNPU params should be present";
311308
auto primfunc_to_artifact_pf =
312309
tvm::runtime::Registry::Get("relay.ext.ethos-u.primfunc_to_artifact");
313310
ICHECK(primfunc_to_artifact_pf);

src/target/source/codegen_c_host.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -426,9 +426,8 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
426426
bool emit_fwd_func_decl = true;
427427

428428
std::unordered_set<std::string> devices;
429-
if (mod->GetAttr<Map<GlobalVar, String>>("device_contexts") != nullptr) {
430-
Map<GlobalVar, String> device_contexts =
431-
mod->GetAttr<Map<GlobalVar, String>>("device_contexts").value();
429+
if (auto opt = mod->GetAttr<Map<GlobalVar, String>>("device_contexts")) {
430+
auto device_contexts = opt.value();
432431
for (auto const& context : device_contexts) {
433432
devices.insert(context.second.data());
434433
}

src/tir/transforms/split_host_device.cc

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <tvm/tir/op.h>
3232
#include <tvm/tir/stmt_functor.h>
3333
#include <tvm/tir/transform.h>
34+
#include <tvm/tir/usmp/utils.h>
3435

3536
#include <unordered_map>
3637

@@ -43,8 +44,9 @@ namespace tir {
4344

4445
class HostDeviceSplitter : public StmtMutator {
4546
public:
46-
explicit HostDeviceSplitter(IRModule* device_mod, std::function<GlobalVar()> var_supply)
47-
: device_mod_(device_mod), var_supply_(var_supply) {}
47+
explicit HostDeviceSplitter(IRModule* device_mod, std::function<GlobalVar()> var_supply,
48+
Map<String, ObjectRef> extra_attrs)
49+
: device_mod_(device_mod), var_supply_(var_supply), extra_attrs_(extra_attrs) {}
4850

4951
Stmt VisitStmt_(const AttrStmtNode* op) final {
5052
if (op->attr_key == tvm::attr::kTarget) {
@@ -78,7 +80,9 @@ class HostDeviceSplitter : public StmtMutator {
7880
PrimFunc device_func(params, body);
7981
device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
8082
{tir::attr::kNoAlias, Bool(true)},
81-
{tir::attr::kIsGlobalFunc, Bool(true)}});
83+
{tir::attr::kIsGlobalFunc, Bool(true)},
84+
{tir::attr::kIsEntryFunc, Bool(false)}});
85+
device_func = WithAttrs(std::move(device_func), extra_attrs_);
8286

8387
(*device_mod_)->Add(kernel_symbol_global, device_func);
8488
Array<PrimExpr> args = params.Map([](const Var& var) -> PrimExpr { return var; });
@@ -90,11 +94,18 @@ class HostDeviceSplitter : public StmtMutator {
9094
IRModule* device_mod_;
9195
// Generate new GlobalVar for the kernel
9296
std::function<GlobalVar()> var_supply_;
97+
// Extra attrs to be added to extracted kernels.
98+
Map<String, ObjectRef> extra_attrs_;
9399
};
94100

95101
PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod,
96102
std::function<GlobalVar()> var_supply) {
97-
HostDeviceSplitter splitter(device_mod, var_supply);
103+
Map<String, ObjectRef> extra_attrs;
104+
if (auto opt = func->GetAttr<ObjectRef>(tvm::attr::kPoolArgs)) {
105+
extra_attrs.Set(tvm::attr::kPoolArgs, opt.value());
106+
}
107+
108+
HostDeviceSplitter splitter(device_mod, var_supply, extra_attrs);
98109

99110
if (auto body = splitter(func->body); !body.same_as(func->body)) {
100111
func.CopyOnWrite()->body = body;

tests/python/contrib/test_ethosu/test_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_lower_to_tir_arg_count(relay_function, arg_count):
5757
mod = tvm.IRModule()
5858
mod["main"] = relay_function()
5959
mod = relay.transform.InferType()(mod)
60-
tir_mod = _lower_to_tir(mod["main"])[0]
60+
tir_mod = _lower_to_tir(mod["main"])
6161
primfunc = tir_mod["main"]
6262
assert len(primfunc.params) == arg_count
6363

tests/python/contrib/test_ethosu/test_encode_constants.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import tvm
2222
from tvm import relay
2323
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator
24+
from tvm.relay.backend.contrib.ethosu.codegen import collect_consts
2425
from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir
2526
from tvm.relay.backend.contrib.ethosu.tir.scheduler import (
2627
OperatorCompute,
@@ -140,12 +141,12 @@ def _get_func():
140141
}
141142
with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}):
142143
func = _get_func()
143-
mod, consts = _lower_to_tir(func, cascader=_planner)
144+
mod = _lower_to_tir(func, cascader=_planner)
144145
script = mod.script()
145146
test_mod = tvm.script.from_source(script)
146147
tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
147148

148-
test_const_size = [value.size for value in list(consts.values())]
149+
test_const_size = [value.size for value in collect_consts(test_mod).values()]
149150
assert reference_const_sizes.sort() == test_const_size.sort()
150151

151152

@@ -242,12 +243,12 @@ def _get_func():
242243
}
243244
with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}):
244245
func = _get_func()
245-
mod, consts = _lower_to_tir(func, cascader=_cascader)
246+
mod = _lower_to_tir(func, cascader=_cascader)
246247
script = mod.script()
247248
test_mod = tvm.script.from_source(script)
248249
tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
249250

250-
test_const_size = [value.size for value in list(consts.values())]
251+
test_const_size = [value.size for value in collect_consts(test_mod).values()]
251252
assert reference_const_sizes.sort() == test_const_size.sort()
252253

253254

@@ -339,13 +340,13 @@ def _get_func():
339340
}
340341
with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}):
341342
func = _get_func()
342-
mod, consts = _lower_to_tir(func)
343+
mod = _lower_to_tir(func)
343344

344345
script = mod.script()
345346
test_mod = tvm.script.from_source(script)
346347
tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
347348

348-
test_const_size = [value.size for value in list(consts.values())]
349+
test_const_size = [value.size for value in collect_consts(test_mod).values()]
349350
assert reference_const_sizes.sort() == test_const_size.sort()
350351

351352

@@ -474,13 +475,13 @@ def _get_func():
474475
}
475476
with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}):
476477
func = _get_func()
477-
mod, consts = _lower_to_tir(func, cascader=_planner)
478+
mod = _lower_to_tir(func, cascader=_planner)
478479

479480
script = mod.script()
480481
test_mod = tvm.script.from_source(script)
481482
tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
482483

483-
test_const_size = [value.size for value in list(consts.values())]
484+
test_const_size = [value.size for value in collect_consts(test_mod).values()]
484485
assert reference_const_sizes.sort() == test_const_size.sort()
485486

486487

tests/python/contrib/test_ethosu/test_remove_concatenates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _get_func():
7373
return func
7474

7575
func = _get_func()
76-
mod, _ = _lower_to_tir(func)
76+
mod = _lower_to_tir(func)
7777
script = mod.script()
7878
test_mod = tvm.script.from_source(script)
7979

0 commit comments

Comments
 (0)