Skip to content

Commit f1a16c8

Browse files
committed
[Ethos-U] Replace ethos-u.constants with AllocateConst
Previously, constants for ethos-u were tracked using a function attribute `"ethos-u.constants"`. This predates the introduction of `AllocateConst`, and had comments indicating that it should be replaced with `AllocateConst` when possible. To minimize impact to existing passes, this commit preserves the `"ethos-u.constants"` attribute during ethosu-specific lowering passes. The attribute is converted to `AllocateConst` at the end of the `lower_ethosu` pass, just prior to lowering with the usual TIR passes.
1 parent dfd525b commit f1a16c8

20 files changed

+361
-153
lines changed

python/tvm/relay/backend/contrib/ethosu/codegen.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717
"""Codegen for Arm(R) Ethos(TM)-U NPU"""
1818
from collections import defaultdict
19-
from typing import List, Callable
19+
from typing import List, Callable, Dict
2020

2121
from ethosu.vela import api as vapi
2222
import tvm
@@ -720,6 +720,34 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
720720
return tir_mod
721721

722722

723+
def collect_consts(mod: tvm.IRModule) -> Dict[tvm.tir.Var, tvm.nd.NDArray]:
724+
"""Collect any AllocateConst
725+
726+
Parameters
727+
----------
728+
mod: tvm.IRModule
729+
730+
The module to inspect.
731+
732+
Returns
733+
-------
734+
const_dict: Dict[tvm.tir.Var, tvm.nd.NDArray]
735+
736+
A map from buffer var to NDArray, from AllocateConst nodes in
737+
the module
738+
"""
739+
constants = {}
740+
741+
def _visit(stmt):
742+
if isinstance(stmt, tvm.tir.AllocateConst):
743+
constants[stmt.buffer_var] = stmt.data
744+
745+
for func in mod.functions.values():
746+
tvm.tir.stmt_functor.post_order_visit(func.body, _visit)
747+
748+
return constants
749+
750+
723751
@tvm._ffi.register_func("relay.ext.ethos-u.primfunc_to_artifact")
724752
def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact:
725753
"""
@@ -739,13 +767,12 @@ def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact
739767
for the microNPU
740768
"""
741769
symbol = str(primfunc.attrs["global_symbol"])
742-
const_dict = primfunc.attrs["ethos-u.constants"]
743770
tir_mod = tvm.IRModule()
744771
tir_mod[symbol] = primfunc
745772

746-
const_dict_np = dict()
747-
for buffer_var in const_dict.keys():
748-
const_dict_np[buffer_var] = const_dict[buffer_var].numpy()
773+
const_dict_np = {
774+
buffer_var: ndarray.numpy() for buffer_var, ndarray in collect_consts(tir_mod).items()
775+
}
749776

750777
cmms, encoded_constants, base_addresses = tir_to_cs_translator.translate(tir_mod, const_dict_np)
751778
return util.CompilationArtifact(symbol, cmms, encoded_constants, base_addresses)

python/tvm/relay/backend/contrib/ethosu/tir/compiler.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,6 @@ def lower_ethosu(sch, args, const_dict, name="main"):
4949
-------
5050
mod : tvm.IRModule
5151
The lowered TIR module.
52-
const_dict : dict of int to numpy.ndarray
53-
The modified constant dictionary.
54-
5552
"""
5653
if not isinstance(args, list):
5754
args = list(args.inputs) + list(args.outputs)
@@ -101,8 +98,8 @@ def lower_ethosu(sch, args, const_dict, name="main"):
10198

10299
mod = tvm.tir.transform.RemoveNoOp()(mod)
103100
mod = ethosu_passes.AnnotateAllocates()(mod)
104-
mod, const_dict = ethosu_passes.CreatePrimFuncWithoutConstants(const_dict)(mod)
105-
return mod, const_dict
101+
mod = ethosu_passes.CreatePrimFuncWithoutConstants(const_dict)(mod)
102+
return mod
106103

107104

108105
def lower_to_te(prim_func):
@@ -200,15 +197,11 @@ def __init__(self, scheduler):
200197
def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
201198
"""Lower NPU functions to TIR."""
202199

203-
tir_mod, const_dict = _lower_to_tir(func, self.scheduler)
204-
205-
for param in const_dict.keys():
206-
const_dict[param] = tvm.nd.array(const_dict[param])
200+
tir_mod = _lower_to_tir(func, self.scheduler)
207201

208202
compiler_name = "ethos-u"
209203
primfunc = tir_mod["main"]
210204
primfunc = primfunc.with_attr("global_symbol", func.attrs["global_symbol"])
211-
primfunc = primfunc.with_attr("ethos-u.constants", const_dict)
212205
primfunc = primfunc.with_attr("target", tvm.target.Target(compiler_name))
213206
return primfunc
214207

@@ -233,14 +226,11 @@ def _lower_to_tir(func, cascader=None):
233226
-------
234227
mod : tvm.IRModule
235228
The lowered TIR module.
236-
consts : dict of int to numpy.ndarray
237-
A dict of the extracted constants keyed by their param index.
238-
239229
"""
240230
func, consts = extract_constants(func)
241231
mod = tvm.IRModule.from_expr(func)
242232
func = relay.transform.InferType()(mod)["main"]
243233
cached_func = lower_to_te(func)
244234
s = schedule(cached_func, consts, cascader)
245-
mod, consts = lower_ethosu(s, cached_func, consts)
246-
return mod, consts
235+
mod = lower_ethosu(s, cached_func, consts)
236+
return mod

python/tvm/relay/backend/contrib/ethosu/tir/passes.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -911,31 +911,35 @@ def _ftransform(f, mod, ctx):
911911
def CreatePrimFuncWithoutConstants(const_dict):
912912
"""
913913
This pass will remove arguments that are constants
914-
from PrimFunc Args. These should be replaced properly
915-
with tir.allocate_const when it becomes available.
914+
from PrimFunc Args, replacing them with tir.allocate_const.
916915
917916
It also modifies the constant dictionary to
918917
rewrite the keys as the actual tir.Vars that are params
919918
rather than the index because this pass removes PrimFunc
920919
arguments that represent constants.
921920
"""
922921

923-
new_const_dict = dict()
924-
925922
def _ftransform(f, mod, ctx):
926923
new_params = list()
927924
new_buffer_map = dict()
928-
for param_idx in const_dict.keys():
929-
# We are using buffer_var to key the constants as
930-
# PrimFunc params of constants will be removed.
931-
new_const_dict[f.buffer_map[f.params[param_idx]].data] = const_dict[param_idx]
925+
926+
new_body = f.body
927+
932928
for i, param in enumerate(f.params):
933-
if i not in const_dict.keys():
929+
if i in const_dict:
930+
const_np = const_dict[i]
931+
const_ndarray = tvm.nd.array(const_np, device=tvm.cpu())
932+
buf = f.buffer_map[param]
933+
new_body = tvm.tir.AllocateConst(
934+
buf.data, buf.dtype, buf.shape, const_ndarray, new_body
935+
)
936+
else:
934937
new_params.append(param)
935938
new_buffer_map[param] = f.buffer_map[param]
939+
936940
return tvm.tir.PrimFunc(
937941
new_params,
938-
f.body,
942+
new_body,
939943
f.ret_type,
940944
new_buffer_map,
941945
f.attrs,
@@ -947,7 +951,7 @@ def _create_primfunc_without_constants(mod):
947951
_ftransform, opt_level=0, name="tir.contrib.ethos-u.CreatePrimFuncWithoutConstants"
948952
)
949953
mod = transform_func(mod)
950-
return mod, new_const_dict
954+
return mod
951955

952956
return _create_primfunc_without_constants
953957

python/tvm/relay/op/contrib/ethosu.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2347,14 +2347,26 @@ def partition_for_ethosu(
23472347
mod["main"] = bind_params_by_name(mod["main"], params)
23482348

23492349
pattern = relay.op.contrib.get_pattern_table("ethos-u")
2350-
mod = relay.transform.InferType()(mod)
2351-
mod = codegen.replicate_pads(mod)
2352-
mod = relay.transform.InferType()(mod)
2353-
mod = relay.transform.MergeComposite(pattern)(mod)
2354-
mod = relay.transform.AnnotateTarget("ethos-u")(mod)
2355-
mod = relay.transform.MergeCompilerRegions()(mod)
2356-
mod = relay.transform.InferType()(mod)
2357-
mod = relay.transform.PartitionGraph(mod_name)(mod)
2358-
mod = relay.transform.InferType()(mod)
2359-
mod = preprocess.preprocess_ext_io()(mod)
2350+
2351+
seq = tvm.ir.transform.Sequential(
2352+
[
2353+
relay.transform.InferType(),
2354+
tvm.ir.transform.module_pass(
2355+
lambda mod, context: codegen.replicate_pads(mod),
2356+
opt_level=0,
2357+
name="ethosu.replicate_pads",
2358+
),
2359+
relay.transform.InferType(),
2360+
relay.transform.MergeComposite(pattern),
2361+
relay.transform.AnnotateTarget("ethos-u"),
2362+
relay.transform.MergeCompilerRegions(),
2363+
relay.transform.InferType(),
2364+
relay.transform.PartitionGraph(mod_name),
2365+
relay.transform.InferType(),
2366+
preprocess.preprocess_ext_io(),
2367+
],
2368+
name="partition_for_ethosu",
2369+
)
2370+
mod = seq(mod)
2371+
23602372
return mod

src/relay/backend/contrib/ethosu/codegen.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,9 +307,6 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
307307
Array<CompilationArtifact> compile_artifacts;
308308
for (const auto& kv : mod->functions) {
309309
const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(kv.second);
310-
Optional<Map<Integer, runtime::NDArray>> params =
311-
prim_func->GetAttr<Map<Integer, runtime::NDArray>>("ethos-u.constants");
312-
ICHECK(params) << "microNPU params should be present";
313310
auto primfunc_to_artifact_pf =
314311
tvm::runtime::Registry::Get("relay.ext.ethos-u.primfunc_to_artifact");
315312
ICHECK(primfunc_to_artifact_pf);

src/tir/ir/stmt.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> ext
287287
}
288288
ICHECK(body.defined());
289289
ICHECK(data_or_idx.defined());
290+
ICHECK(annotations.defined());
290291

291292
ObjectPtr<AllocateConstNode> node = make_object<AllocateConstNode>();
292293
node->buffer_var = std::move(buffer_var);
@@ -323,9 +324,10 @@ int64_t AllocateConstNode::ConstantAllocationSize(const Array<PrimExpr>& extents
323324
}
324325
TVM_REGISTER_GLOBAL("tir.AllocateConst")
325326
.set_body_typed([](Var buffer_var, DataType dtype, Array<PrimExpr> extents,
326-
ObjectRef data_or_idx, Stmt body, Map<String, ObjectRef> annotations,
327-
Span span) {
328-
return AllocateConst(buffer_var, dtype, extents, data_or_idx, body, annotations, span);
327+
ObjectRef data_or_idx, Stmt body,
328+
Optional<Map<String, ObjectRef>> annotations, Span span) {
329+
return AllocateConst(buffer_var, dtype, extents, data_or_idx, body,
330+
annotations.value_or(Map<String, ObjectRef>()), span);
329331
});
330332

331333
TVM_REGISTER_NODE_TYPE(AllocateConstNode);

tests/python/contrib/test_ethosu/cascader/test_integration.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ def _compile_model(relay_function):
7474
mod = tvm.IRModule()
7575
mod["main"] = relay_function
7676
mod = relay.transform.InferType()(mod)
77-
tir_mod = _lower_to_tir(mod["main"], _ethos_u55_cascader())[0]
77+
func = mod["main"]
78+
cascader = _ethos_u55_cascader()
79+
tir_mod = _lower_to_tir(func, cascader)
7880
return tir_mod["main"]
7981

8082

@@ -109,7 +111,7 @@ def test_single_conv_compute_cycles_hint():
109111
for single convolution.
110112
"""
111113
primfunc = _compile_model(_create_single_conv2d())
112-
ops = primfunc.body.body.seq
114+
ops = primfunc.body.body.body.seq
113115
compute_cycles_hints = [2944, 320]
114116
for op, compute_cycle_hint in zip(ops, compute_cycles_hints):
115117
assert op.attr_key == "pragma_compute_cycles_hint"
@@ -135,7 +137,7 @@ def test_scalar_add_compute_cycles_hint():
135137
for add with scalar values.
136138
"""
137139
primfunc = _compile_model(_create_scalar_add())
138-
ops = primfunc.body.body.seq
140+
ops = primfunc.body.body.body.seq
139141

140142
compute_cycles_hints = [16, 24]
141143
for op, compute_cycle_hint in zip(ops, compute_cycles_hints):

tests/python/contrib/test_ethosu/infra.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,3 +789,62 @@ def make_ethosu_unary_elementwise(
789789
ofm_layout=ofm_layout,
790790
)
791791
return ethosu_unary_elementwise
792+
793+
794+
def copy_allocate_const_data(test_mod: tvm.IRModule, reference_mod: tvm.IRModule) -> tvm.IRModule:
795+
"""For testing purposes, copy the NDArray into refernece
796+
797+
NDArray does not implement SEqual, so StructuralEqual defaults to
798+
pointer equality. Since the reference module and the test module
799+
were generated separately, they won't have the same NDArray.
800+
Therefore, copy it over before StructuralEqual.
801+
"""
802+
803+
def collect_ndarray(func):
804+
output = []
805+
806+
def fvisit(node):
807+
if isinstance(node, tvm.tir.AllocateConst):
808+
output.append(node.data)
809+
810+
tvm.tir.stmt_functor.post_order_visit(func.body, fvisit)
811+
812+
return output
813+
814+
def inject_ndarray(func, data_arrays):
815+
def fvisit(node):
816+
if data_arrays and isinstance(node, tvm.tir.AllocateConst):
817+
data = data_arrays.pop(0)
818+
return tvm.tir.AllocateConst(
819+
buffer_var=node.buffer_var,
820+
dtype=node.dtype,
821+
extents=node.extents,
822+
data_or_idx=data,
823+
body=node.body,
824+
annotations=node.annotations,
825+
span=node.span,
826+
)
827+
else:
828+
return node
829+
830+
body = tvm.tir.stmt_functor.ir_transform(func.body, lambda node: None, fvisit)
831+
if body.same_as(func.body):
832+
return func
833+
else:
834+
return tvm.tir.PrimFunc(
835+
func.params, body, func.ret_type, func.buffer_map, func.attrs, func.span
836+
)
837+
838+
data_arrays = {
839+
gvar.name_hint: collect_ndarray(func)
840+
for gvar, func in test_mod.functions.items()
841+
if isinstance(func, tvm.tir.PrimFunc)
842+
}
843+
844+
new_module = {}
845+
for gvar, func in reference_mod.functions.items():
846+
if isinstance(func, tvm.tir.PrimFunc):
847+
if gvar.name_hint in data_arrays:
848+
func = inject_ndarray(func, data_arrays[gvar.name_hint])
849+
new_module[gvar] = func
850+
return tvm.IRModule(new_module)

tests/python/contrib/test_ethosu/test_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_lower_to_tir_arg_count(relay_function, arg_count):
5757
mod = tvm.IRModule()
5858
mod["main"] = relay_function()
5959
mod = relay.transform.InferType()(mod)
60-
tir_mod = _lower_to_tir(mod["main"])[0]
60+
tir_mod = _lower_to_tir(mod["main"])
6161
primfunc = tir_mod["main"]
6262
assert len(primfunc.params) == arg_count
6363

0 commit comments

Comments
 (0)