Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""Codegen for Arm(R) Ethos(TM)-U NPU"""
from collections import defaultdict
from typing import List, Callable
from typing import List, Callable, Dict

from ethosu.vela import api as vapi
import tvm
Expand Down Expand Up @@ -720,6 +720,34 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
return tir_mod


def collect_consts(mod: tvm.IRModule) -> Dict[tvm.tir.Var, tvm.nd.NDArray]:
"""Collect any AllocateConst

Parameters
----------
mod: tvm.IRModule

The module to inspect.

Returns
-------
const_dict: Dict[tvm.tir.Var, tvm.nd.NDArray]

A map from buffer var to NDArray, from AllocateConst nodes in
the module
"""
constants = {}

def _visit(stmt):
if isinstance(stmt, tvm.tir.AllocateConst):
constants[stmt.buffer_var] = stmt.data

for func in mod.functions.values():
tvm.tir.stmt_functor.post_order_visit(func.body, _visit)

return constants


@tvm._ffi.register_func("relay.ext.ethos-u.primfunc_to_artifact")
def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact:
"""
Expand All @@ -739,13 +767,12 @@ def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact
for the microNPU
"""
symbol = str(primfunc.attrs["global_symbol"])
const_dict = primfunc.attrs["ethos-u.constants"]
tir_mod = tvm.IRModule()
tir_mod[symbol] = primfunc

const_dict_np = dict()
for buffer_var in const_dict.keys():
const_dict_np[buffer_var] = const_dict[buffer_var].numpy()
const_dict_np = {
buffer_var: ndarray.numpy() for buffer_var, ndarray in collect_consts(tir_mod).items()
}

cmms, encoded_constants, base_addresses = tir_to_cs_translator.translate(tir_mod, const_dict_np)
return util.CompilationArtifact(symbol, cmms, encoded_constants, base_addresses)
20 changes: 5 additions & 15 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@ def lower_ethosu(sch, args, const_dict, name="main"):
-------
mod : tvm.IRModule
The lowered TIR module.
const_dict : dict of int to numpy.ndarray
The modified constant dictionary.

"""
if not isinstance(args, list):
args = list(args.inputs) + list(args.outputs)
Expand Down Expand Up @@ -101,8 +98,8 @@ def lower_ethosu(sch, args, const_dict, name="main"):

mod = tvm.tir.transform.RemoveNoOp()(mod)
mod = ethosu_passes.AnnotateAllocates()(mod)
mod, const_dict = ethosu_passes.CreatePrimFuncWithoutConstants(const_dict)(mod)
return mod, const_dict
mod = ethosu_passes.CreatePrimFuncWithoutConstants(const_dict)(mod)
return mod


def lower_to_te(prim_func):
Expand Down Expand Up @@ -200,15 +197,11 @@ def __init__(self, scheduler):
def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
"""Lower NPU functions to TIR."""

tir_mod, const_dict = _lower_to_tir(func, self.scheduler)

for param in const_dict.keys():
const_dict[param] = tvm.nd.array(const_dict[param])
tir_mod = _lower_to_tir(func, self.scheduler)

compiler_name = "ethos-u"
primfunc = tir_mod["main"]
primfunc = primfunc.with_attr("global_symbol", func.attrs["global_symbol"])
primfunc = primfunc.with_attr("ethos-u.constants", const_dict)
primfunc = primfunc.with_attr("target", tvm.target.Target(compiler_name))
return primfunc

Expand All @@ -233,14 +226,11 @@ def _lower_to_tir(func, cascader=None):
-------
mod : tvm.IRModule
The lowered TIR module.
consts : dict of int to numpy.ndarray
A dict of the extracted constants keyed by their param index.

"""
func, consts = extract_constants(func)
mod = tvm.IRModule.from_expr(func)
func = relay.transform.InferType()(mod)["main"]
cached_func = lower_to_te(func)
s = schedule(cached_func, consts, cascader)
mod, consts = lower_ethosu(s, cached_func, consts)
return mod, consts
mod = lower_ethosu(s, cached_func, consts)
return mod
26 changes: 15 additions & 11 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,31 +911,35 @@ def _ftransform(f, mod, ctx):
def CreatePrimFuncWithoutConstants(const_dict):
"""
This pass will remove arguments that are constants
from PrimFunc Args. These should be replaced properly
with tir.allocate_const when it becomes available.
from PrimFunc Args, replacing them with tir.allocate_const.

It also modifies the constant dictionary to
rewrite the keys as the actual tir.Vars that are params
rather than the index because this pass removes PrimFunc
arguments that represent constants.
"""

new_const_dict = dict()

def _ftransform(f, mod, ctx):
new_params = list()
new_buffer_map = dict()
for param_idx in const_dict.keys():
# We are using buffer_var to key the constants as
# PrimFunc params of constants will be removed.
new_const_dict[f.buffer_map[f.params[param_idx]].data] = const_dict[param_idx]

new_body = f.body

for i, param in enumerate(f.params):
if i not in const_dict.keys():
if i in const_dict:
const_np = const_dict[i]
const_ndarray = tvm.nd.array(const_np, device=tvm.cpu())
buf = f.buffer_map[param]
new_body = tvm.tir.AllocateConst(
buf.data, buf.dtype, buf.shape, const_ndarray, new_body
)
else:
new_params.append(param)
new_buffer_map[param] = f.buffer_map[param]

return tvm.tir.PrimFunc(
new_params,
f.body,
new_body,
f.ret_type,
new_buffer_map,
f.attrs,
Expand All @@ -947,7 +951,7 @@ def _create_primfunc_without_constants(mod):
_ftransform, opt_level=0, name="tir.contrib.ethos-u.CreatePrimFuncWithoutConstants"
)
mod = transform_func(mod)
return mod, new_const_dict
return mod

return _create_primfunc_without_constants

Expand Down
32 changes: 22 additions & 10 deletions python/tvm/relay/op/contrib/ethosu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2347,14 +2347,26 @@ def partition_for_ethosu(
mod["main"] = bind_params_by_name(mod["main"], params)

pattern = relay.op.contrib.get_pattern_table("ethos-u")
mod = relay.transform.InferType()(mod)
mod = codegen.replicate_pads(mod)
mod = relay.transform.InferType()(mod)
mod = relay.transform.MergeComposite(pattern)(mod)
mod = relay.transform.AnnotateTarget("ethos-u")(mod)
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.InferType()(mod)
mod = relay.transform.PartitionGraph(mod_name)(mod)
mod = relay.transform.InferType()(mod)
mod = preprocess.preprocess_ext_io()(mod)

seq = tvm.ir.transform.Sequential(
[
relay.transform.InferType(),
tvm.ir.transform.module_pass(
lambda mod, context: codegen.replicate_pads(mod),
opt_level=0,
name="ethosu.replicate_pads",
),
relay.transform.InferType(),
relay.transform.MergeComposite(pattern),
relay.transform.AnnotateTarget("ethos-u"),
relay.transform.MergeCompilerRegions(),
relay.transform.InferType(),
relay.transform.PartitionGraph(mod_name),
relay.transform.InferType(),
preprocess.preprocess_ext_io(),
],
name="partition_for_ethosu",
)
mod = seq(mod)

return mod
3 changes: 0 additions & 3 deletions src/relay/backend/contrib/ethosu/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,6 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
Array<CompilationArtifact> compile_artifacts;
for (const auto& kv : mod->functions) {
const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(kv.second);
Optional<Map<Integer, runtime::NDArray>> params =
prim_func->GetAttr<Map<Integer, runtime::NDArray>>("ethos-u.constants");
ICHECK(params) << "microNPU params should be present";
auto primfunc_to_artifact_pf =
tvm::runtime::Registry::Get("relay.ext.ethos-u.primfunc_to_artifact");
ICHECK(primfunc_to_artifact_pf);
Expand Down
8 changes: 5 additions & 3 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> ext
}
ICHECK(body.defined());
ICHECK(data_or_idx.defined());
ICHECK(annotations.defined());

ObjectPtr<AllocateConstNode> node = make_object<AllocateConstNode>();
node->buffer_var = std::move(buffer_var);
Expand Down Expand Up @@ -323,9 +324,10 @@ int64_t AllocateConstNode::ConstantAllocationSize(const Array<PrimExpr>& extents
}
TVM_REGISTER_GLOBAL("tir.AllocateConst")
.set_body_typed([](Var buffer_var, DataType dtype, Array<PrimExpr> extents,
ObjectRef data_or_idx, Stmt body, Map<String, ObjectRef> annotations,
Span span) {
return AllocateConst(buffer_var, dtype, extents, data_or_idx, body, annotations, span);
ObjectRef data_or_idx, Stmt body,
Optional<Map<String, ObjectRef>> annotations, Span span) {
return AllocateConst(buffer_var, dtype, extents, data_or_idx, body,
annotations.value_or(Map<String, ObjectRef>()), span);
});

TVM_REGISTER_NODE_TYPE(AllocateConstNode);
Expand Down
13 changes: 9 additions & 4 deletions tests/python/contrib/test_ethosu/cascader/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def _compile_model(relay_function):
mod = tvm.IRModule()
mod["main"] = relay_function
mod = relay.transform.InferType()(mod)
tir_mod = _lower_to_tir(mod["main"], _ethos_u55_cascader())[0]
func = mod["main"]
cascader = _ethos_u55_cascader()
tir_mod = _lower_to_tir(func, cascader)
return tir_mod["main"]


Expand Down Expand Up @@ -109,7 +111,7 @@ def test_single_conv_compute_cycles_hint():
for single convolution.
"""
primfunc = _compile_model(_create_single_conv2d())
ops = primfunc.body.body.seq
ops = primfunc.body.body.body.seq
compute_cycles_hints = [2944, 320]
for op, compute_cycle_hint in zip(ops, compute_cycles_hints):
assert op.attr_key == "pragma_compute_cycles_hint"
Expand All @@ -122,7 +124,10 @@ def test_double_conv_compute_cycles_hint():
for double convolution.
"""
primfunc = _compile_model(_create_double_conv2d())
ops = primfunc.body.body.body.body.seq

ops = primfunc.body
while not isinstance(ops, tvm.tir.SeqStmt):
ops = ops.body
compute_cycles_hints = [2944, 1408, 320, 240]
for op, compute_cycle_hint in zip(ops, compute_cycles_hints):
assert op.attr_key == "pragma_compute_cycles_hint"
Expand All @@ -135,7 +140,7 @@ def test_scalar_add_compute_cycles_hint():
for add with scalar values.
"""
primfunc = _compile_model(_create_scalar_add())
ops = primfunc.body.body.seq
ops = primfunc.body.body.body.seq

compute_cycles_hints = [16, 24]
for op, compute_cycle_hint in zip(ops, compute_cycles_hints):
Expand Down
59 changes: 59 additions & 0 deletions tests/python/contrib/test_ethosu/infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,3 +789,62 @@ def make_ethosu_unary_elementwise(
ofm_layout=ofm_layout,
)
return ethosu_unary_elementwise


def copy_allocate_const_data(test_mod: tvm.IRModule, reference_mod: tvm.IRModule) -> tvm.IRModule:
"""For testing purposes, copy the NDArray into refernece

NDArray does not implement SEqual, so StructuralEqual defaults to
pointer equality. Since the reference module and the test module
were generated separately, they won't have the same NDArray.
Therefore, copy it over before StructuralEqual.
"""

def collect_ndarray(func):
output = []

def fvisit(node):
if isinstance(node, tvm.tir.AllocateConst):
output.append(node.data)

tvm.tir.stmt_functor.post_order_visit(func.body, fvisit)

return output

def inject_ndarray(func, data_arrays):
def fvisit(node):
if data_arrays and isinstance(node, tvm.tir.AllocateConst):
data = data_arrays.pop(0)
return tvm.tir.AllocateConst(
buffer_var=node.buffer_var,
dtype=node.dtype,
extents=node.extents,
data_or_idx=data,
body=node.body,
annotations=node.annotations,
span=node.span,
)
else:
return node

body = tvm.tir.stmt_functor.ir_transform(func.body, lambda node: None, fvisit)
if body.same_as(func.body):
return func
else:
return tvm.tir.PrimFunc(
func.params, body, func.ret_type, func.buffer_map, func.attrs, func.span
)

data_arrays = {
gvar.name_hint: collect_ndarray(func)
for gvar, func in test_mod.functions.items()
if isinstance(func, tvm.tir.PrimFunc)
}

new_module = {}
for gvar, func in reference_mod.functions.items():
if isinstance(func, tvm.tir.PrimFunc):
if gvar.name_hint in data_arrays:
func = inject_ndarray(func, data_arrays[gvar.name_hint])
new_module[gvar] = func
return tvm.IRModule(new_module)
9 changes: 5 additions & 4 deletions tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,10 +1081,11 @@ def depthwise_conv2d(x):
'__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t tvmgen_default_ethos_u_main_0_cms_data_data'
in source
)
assert (
'__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t tvmgen_default_ethos_u_main_0_weights'
in source
)
# The weights are now encoded by TVM in the AllocateConst node.
# assert (
# '__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t tvmgen_default_ethos_u_main_0_weights'
# in source
# )


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
Expand Down
2 changes: 1 addition & 1 deletion tests/python/contrib/test_ethosu/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_lower_to_tir_arg_count(relay_function, arg_count):
mod = tvm.IRModule()
mod["main"] = relay_function()
mod = relay.transform.InferType()(mod)
tir_mod = _lower_to_tir(mod["main"])[0]
tir_mod = _lower_to_tir(mod["main"])
primfunc = tir_mod["main"]
assert len(primfunc.params) == arg_count

Expand Down
Loading