Skip to content

Commit 4e60749

Browse files
committed
[microNPU] Refactor base address determination to codegen
* Renaming runtime_allocate to be scratch again. * Docstring adjustments. Change-Id: Ife8baf97f3dc9348718bd03e62549169a466fc34
1 parent 26ea7dc commit 4e60749

File tree

3 files changed

+37
-36
lines changed

3 files changed

+37
-36
lines changed

python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ class BufferType(Enum):
3636

3737
constant = auto()
3838
input_or_output = auto()
39-
runtime_allocate = auto()
39+
scratch = auto()
4040
input = auto()
4141
output = auto()
4242
shram = auto()
4343

4444

4545
_REGION_MAP = {
4646
BufferType.constant: 0,
47-
BufferType.runtime_allocate: 1,
47+
BufferType.scratch: 1,
4848
BufferType.input: 3,
4949
BufferType.output: 4,
5050
BufferType.shram: int((1 << 8) | (3 << 0)),
@@ -103,23 +103,23 @@ def translate(tir_module, params):
103103
A hex string of the bytes that includes concat'd
104104
encoded weights, encoded biases and scales.
105105
base_addresses : List[util.BaseAddress]
106-
base addresses
106+
base addresses to be used by the driver
107107
"""
108108

109109
buffer_info = extract_buffer_info(tir_module, params)
110110
call_extern_list = extract_call_extern_list(tir_module)
111111
_npu_ops = list()
112112
for call_extern in call_extern_list:
113113
_npu_ops.append(translate_ethosu_tir_call_extern(call_extern))
114-
_npu_ops, constant_data, runtime_allocation_size = assign_addresses(buffer_info, _npu_ops)
114+
_npu_ops, constant_data, scratch_size = assign_addresses(buffer_info, _npu_ops)
115115
base_addresses = extract_param_base_addresses(tir_module, buffer_info)
116-
if runtime_allocation_size > 0:
116+
if scratch_size > 0:
117117
base_addresses.append(
118118
util.BaseAddress(
119-
"runtime_allocation",
119+
"scratch",
120120
None,
121-
_REGION_MAP[BufferType.runtime_allocate],
122-
runtime_allocation_size,
121+
_REGION_MAP[BufferType.scratch],
122+
scratch_size,
123123
True,
124124
)
125125
)
@@ -248,7 +248,7 @@ def populate_allocate_buffer_info(stmt):
248248
if storage_scope == "local":
249249
buffer_type = BufferType.shram
250250
else:
251-
buffer_type = BufferType.runtime_allocate
251+
buffer_type = BufferType.scratch
252252
buffer_info[allocate.buffer_var] = BufferInfo(
253253
None,
254254
allocate.extents,
@@ -280,7 +280,7 @@ def assign_addresses(buffer_info, npu_ops):
280280
A list of Vela NpuOps with addresses within scratch and constant buffers
281281
constant_tensor : NDArray
282282
A unified constant data array of uint8 as the constant buffer
283-
runtime_allocation_size : int
283+
scratch_size : int
284284
The size of the scratch tensor.
285285
"""
286286

@@ -327,7 +327,7 @@ def classify_io(buffer):
327327

328328
raise ValueError(f"Unused IO : {buffer} in tir module.")
329329

330-
runtime_allocation_size = 0
330+
scratch_size = 0
331331
constant_hex_data = []
332332
total_constant_len = 0
333333
buffer_addresses = dict()
@@ -352,7 +352,9 @@ def classify_io(buffer):
352352
assert buffer_type in (BufferType.input, BufferType.output)
353353
address = 0
354354
buffer_addresses[_buffer] = (address, buffer_type)
355-
buffer_info[_buffer] = BufferInfo(None, info.dtype, info.dtype, buffer_type)
355+
buffer_info[_buffer] = BufferInfo(
356+
values=None, shape=info.dtype, dtype=info.dtype, btype=buffer_type
357+
)
356358
elif info.btype == BufferType.shram:
357359
accl_config = util.get_accelerator_config()
358360
arch_config = get_accelerator_arch_config(accl_config)
@@ -363,9 +365,9 @@ def classify_io(buffer):
363365
size_in_bytes = int(dtype_bytes * np.prod(list(info.shape)))
364366
# Every memory address the NPU accesses has to be 16 byte aligned
365367
size_in_bytes = util.round_up(size_in_bytes, 16)
366-
assert info.btype == BufferType.runtime_allocate
367-
address = runtime_allocation_size
368-
runtime_allocation_size += size_in_bytes
368+
assert info.btype == BufferType.scratch
369+
address = scratch_size
370+
scratch_size += size_in_bytes
369371
buffer_addresses[_buffer] = (address, info.btype)
370372

371373
for npu_op in npu_ops:
@@ -382,7 +384,7 @@ def classify_io(buffer):
382384
return (
383385
npu_ops,
384386
constant_data,
385-
runtime_allocation_size,
387+
scratch_size,
386388
)
387389

388390

src/relay/backend/contrib/ethosu/utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ namespace ethosu {
3636

3737
/*!
3838
* \brief Base addresses are input pointers to
39-
* the driver that get accessed by produced
39+
* the driver that get accessed by the command stream
40+
* using offsets to read/write data.
4041
*/
4142
struct BaseAddressNode : public Object {
4243
/*! \brief The identifier, usually it is the param name of the PrimFunc that gets lowered */

tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -230,12 +230,12 @@ def test_buffer_info_extraction():
230230
"ethosu_conv2d_2": (
231231
[1024],
232232
"uint8",
233-
tir_to_cs_translator.BufferType.runtime_allocate,
233+
tir_to_cs_translator.BufferType.scratch,
234234
),
235235
"ethosu_conv2d_3": (
236236
[2048],
237237
"uint8",
238-
tir_to_cs_translator.BufferType.runtime_allocate,
238+
tir_to_cs_translator.BufferType.scratch,
239239
),
240240
},
241241
},
@@ -776,15 +776,15 @@ def _check_buffer(address, region, length, buffer_var):
776776
original tir buffers.
777777
- If its constant, this will check
778778
the slice in the constant tensor has the values.
779-
- If its runtime_allocation, this will check
780-
the slice is within runtime_allocation and does not have conflicts
781-
with other runtime_allocation tensors.
779+
- If its scratch, this will check
780+
the slice is within scratch and does not have conflicts
781+
with other scratch tensors.
782782
- If its input/output, this will check the
783783
address is zero
784784
"""
785785
inverse_region_map = {
786786
0: tir_to_cs_translator.BufferType.constant,
787-
1: tir_to_cs_translator.BufferType.runtime_allocate,
787+
1: tir_to_cs_translator.BufferType.scratch,
788788
3: tir_to_cs_translator.BufferType.input,
789789
4: tir_to_cs_translator.BufferType.output,
790790
}
@@ -804,21 +804,19 @@ def _check_buffer(address, region, length, buffer_var):
804804
constant_tensor_read_mask[address : address + length] = np.ones(
805805
length, dtype=buffer_dtype
806806
)
807-
elif buffer_type == tir_to_cs_translator.BufferType.runtime_allocate:
807+
elif buffer_type == tir_to_cs_translator.BufferType.scratch:
808808
shape = list(buffer_info[buffer_var].shape)
809809
assert length == np.prod(shape)
810-
assert address < runtime_allocation_size
810+
assert address < scratch_size
811811

812812
size_in_bytes = int(np.prod(shape)) * dtype_bytes
813813
# Every buffer is adjusted to align to 16 bytes
814814
size_in_bytes = util.round_up(size_in_bytes, 16)
815-
assert address + size_in_bytes <= runtime_allocation_size
816-
# The runtime_allocation area should not be used by anyother buffer
817-
assert not runtime_allocation_mask[address : address + size_in_bytes].any()
818-
# The runtime_allocation area is marked as used
819-
runtime_allocation_mask[address : address + size_in_bytes] = np.ones(
820-
size_in_bytes, dtype="uint8"
821-
)
815+
assert address + size_in_bytes <= scratch_size
816+
# The scratch area should not be used by any other buffer
817+
assert not scratch_mask[address : address + size_in_bytes].any()
818+
# The scratch area is marked as used
819+
scratch_mask[address : address + size_in_bytes] = np.ones(size_in_bytes, dtype="uint8")
822820
elif buffer_type == tir_to_cs_translator.BufferType.input:
823821
assert address == 0
824822
else:
@@ -898,13 +896,13 @@ def check_buffer(address, region, length, buffer_var):
898896
(
899897
_npu_ops,
900898
constant_hex_string,
901-
runtime_allocation_size,
899+
scratch_size,
902900
) = tir_to_cs_translator.assign_addresses(buffer_info, _npu_ops)
903-
runtime_allocation_mask = np.zeros(runtime_allocation_size, dtype="uint8")
901+
scratch_mask = np.zeros(scratch_size, dtype="uint8")
904902
constant_tensor_read_mask = np.zeros(len(constant_hex_string) // 2, dtype="uint8")
905903
verify(_npu_ops)
906-
# This will be only 1 if all allocated runtime_allocation is used.
907-
assert np.prod(runtime_allocation_mask) == 1
904+
# This will be only 1 if all allocated scratch is used.
905+
assert np.prod(scratch_mask) == 1
908906
# This will be only 1 if all constant tensors is read at least once.
909907
assert np.prod(constant_tensor_read_mask) == 1
910908

0 commit comments

Comments
 (0)