Skip to content

Commit 31be726

Browse files
[microNPU][ETHOSU] Fix minimum buffer size (#15104)
Fix minimum buffer size for DMA operations according to alignment.
1 parent e280e01 commit 31be726

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -926,17 +926,20 @@ def _create_npu_dma_op(serial_copy):
926926
"""This is a helper function to capture the list of arguments
927927
to create a NpuDmaOperation object"""
928928
data_type_bytes = np.iinfo(np.dtype(serial_copy.read_address.dtype)).bits // 8
929+
length = int(serial_copy.length.value) * data_type_bytes
930+
# The buffer size in bytes must be at least 16 bytes
931+
length = max(length, 16)
929932
src = vapi.NpuAddressRange(
930933
# region will be updated later
931934
region=0,
932935
address=serial_copy.read_address,
933-
length=int(serial_copy.length.value) * data_type_bytes,
936+
length=length,
934937
)
935938
dest = vapi.NpuAddressRange(
936939
# region will be updated later
937940
region=0,
938941
address=serial_copy.write_address,
939-
length=int(serial_copy.length.value) * data_type_bytes,
942+
length=length,
940943
)
941944
return vapi.NpuDmaOperation(src, dest)
942945

@@ -1076,7 +1079,6 @@ def _create_npu_op_binary_elementwise(serial_binary_elementwise: spec.SerialBina
10761079
def translate_ethosu_unary_elementwise(
10771080
tir_extern_call: tvm.tir.Call,
10781081
) -> vapi.NpuElementWiseOperation:
1079-
10801082
"""This function will translate a tir extern_call
10811083
as produced by Relay to TIR compilation.
10821084
Parameters

tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,26 @@ def main(placeholder_3: T.Buffer((8192,), "int8"), ethosu_conv2d_1: T.Buffer((20
8484
# fmt: on
8585

8686

87+
# fmt: off
88+
"""A tir test case with copy operation having a buffer size less than the minimum for a DMA operation"""
89+
@tvm.script.ir_module
90+
class CopyLessMinimal:
91+
@T.prim_func
92+
def main(ethos_u_0_i0: T.Buffer((1, 4), "int8"), ethosu_write: T.Buffer((1, 4), "int8")):
93+
T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
94+
p1_global = T.allocate([4], "int8", "global", annotations={"disable_lower_builtin": T.bool(True)})
95+
ethosu_write_1 = T.allocate([4], "int8", "global", annotations={"disable_lower_builtin": T.bool(True)})
96+
p1 = T.Buffer((4,), "int8")
97+
p1_global_1 = T.Buffer((4,), "int8", data=p1_global)
98+
T.call_extern("handle", "ethosu_copy", p1[0], 4, p1_global_1[0])
99+
ethos_u_0_i0_1 = T.Buffer((4,), "int8", data=ethos_u_0_i0.data)
100+
ethosu_write_2 = T.Buffer((4,), "int8", data=ethosu_write_1, align=4)
101+
T.call_extern("handle", "ethosu_binary_elementwise", "int8", 1, 1, 4, 1, 0, 1, ethos_u_0_i0_1[0], 0, 0, 0, T.float32(0.0039170472882688046), -128, "NHWC", 1, 1, 1, "int8", 1, 1, 4, 1, 0, 1, p1_global_1[0], 0, 0, 0, T.float32(0.0028046639636158943), -128, "NHWC", 1, 1, 1, "int8", 1, 1, 4, 1, 0, 1, ethosu_write_2[0], 0, 0, 0, T.float32(0.0067217112518846989), -128, "NHWC", 1, 1, 1, "ADD", 0, "NONE", 0, 0, "TFL", 0, 0, 0, 0, 0, 0)
102+
ethosu_write_3 = T.Buffer((4,), "int8", data=ethosu_write.data)
103+
T.call_extern("handle", "ethosu_identity", "int8", 1, 4, 1, 1, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 1, 1, "int8", 1, 4, 1, 1, 0, 4, ethosu_write_3[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1, 1, 1, "AVG", 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0)
104+
# fmt: on
105+
106+
87107
# fmt: off
88108
"""A TIR test module of weight streaming"""
89109
@tvm.script.ir_module
@@ -658,6 +678,21 @@ def populate_ethosu_copy_calls(stmt):
658678
},
659679
],
660680
},
681+
{
682+
# Mod contains a copy operation with a buffer size of 4 bytes and it should be replaced by 16
683+
"tir_module": CopyLessMinimal,
684+
"param_dict": {
685+
1: np.random.randint(np.iinfo("int8").min, np.iinfo("int8").max, [1, 4], "int8"),
686+
},
687+
# Reference outputs
688+
"ref": [
689+
{
690+
"src": "p1",
691+
"dest": "p1_global_1",
692+
"length": 16,
693+
},
694+
],
695+
},
661696
]
662697

663698
for test_case in test_cases:

0 commit comments

Comments
 (0)