Skip to content

Commit e36f4bf

Browse files
committed
[microNPU] Refactor base address determination to codegen
This commit introduces a BaseAddress ObjectRef so that base addresses are determined in the codegen for microNPU. This is required once multiple memory pools become available, because base addresses can then no longer be statically determined in the source module. Change-Id: I6cfa578af0318bbe07d3bb9415df7bdd839611d3
1 parent 0d2340c commit e36f4bf

File tree

8 files changed

+274
-137
lines changed

8 files changed

+274
-137
lines changed

python/tvm/relay/backend/contrib/ethosu/codegen.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,6 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc:
289289
This returns the scheduled PrimFunc
290290
"""
291291
assert len(ext_func.params) == 1
292-
input_size = util.calculate_size_bytes(ext_func.params[0])
293-
output_size = util.calculate_size_bytes(ext_func.body)
294292
mod = tvm.IRModule()
295293
mod["main"] = ext_func
296294
mod = LegalizeEthosU()(mod)
@@ -309,8 +307,6 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc:
309307
primfunc = tir_mod["main"]
310308
primfunc = primfunc.with_attr("global_symbol", ext_func.attrs["global_symbol"])
311309
primfunc = primfunc.with_attr("ethos-u.constants", const_dict)
312-
primfunc = primfunc.with_attr("ethos-u.input_size", input_size)
313-
primfunc = primfunc.with_attr("ethos-u.output_size", output_size)
314310
return primfunc
315311

316312

@@ -334,18 +330,14 @@ def primfunc_to_artifact(primfunc: tvm.tir.PrimFunc) -> util.CompilationArtifact
334330
"""
335331
symbol = str(primfunc.attrs["global_symbol"])
336332
const_dict = primfunc.attrs["ethos-u.constants"]
337-
input_size = primfunc.attrs["ethos-u.input_size"]
338-
output_size = primfunc.attrs["ethos-u.output_size"]
339333
tir_mod = tvm.IRModule()
340334
tir_mod[symbol] = primfunc
341335

342336
const_dict_with_int_keys = dict()
343337
for idx in const_dict.keys():
344338
const_dict_with_int_keys[int(idx)] = const_dict[idx].numpy()
345339

346-
cmms, encoded_constants, scratch_size = tir_to_cs_translator.translate(
340+
cmms, encoded_constants, base_addresses = tir_to_cs_translator.translate(
347341
tir_mod, const_dict_with_int_keys
348342
)
349-
return util.CompilationArtifact(
350-
cmms, encoded_constants, scratch_size, input_size, output_size, symbol
351-
)
343+
return util.CompilationArtifact(symbol, cmms, encoded_constants, base_addresses)

python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py

Lines changed: 66 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
the Relay to TIR compilation process, to Vela API calls to
1919
generate command stream.
2020
"""
21-
from typing import Dict, NamedTuple, Tuple, Union
21+
from typing import Dict, NamedTuple, Tuple, Union, List
2222
from enum import auto
2323
from enum import Enum
2424
import numpy as np # type: ignore
@@ -36,15 +36,15 @@ class BufferType(Enum):
3636

3737
constant = auto()
3838
input_or_output = auto()
39-
scratch = auto()
39+
runtime_allocate = auto()
4040
input = auto()
4141
output = auto()
4242
shram = auto()
4343

4444

4545
_REGION_MAP = {
4646
BufferType.constant: 0,
47-
BufferType.scratch: 1,
47+
BufferType.runtime_allocate: 1,
4848
BufferType.input: 3,
4949
BufferType.output: 4,
5050
BufferType.shram: int((1 << 8) | (3 << 0)),
@@ -101,20 +101,70 @@ def translate(tir_module, params):
101101
encoded_constants : str
102102
A hex string of the bytes that includes concat'd
103103
encoded weights, encoded biases and scales.
104-
scratch_size : int
105-
The size of the scratch buffer needed.
104+
base_addresses : List[util.BaseAddress]
105+
The base addresses (PrimFunc params and the runtime allocation) and their regions in the driver
106106
"""
107107

108108
buffer_info = extract_buffer_info(tir_module, params)
109109
call_extern_list = extract_call_extern_list(tir_module)
110110
_npu_ops = list()
111111
for call_extern in call_extern_list:
112112
_npu_ops.append(translate_ethosu_tir_call_extern(call_extern))
113-
_npu_ops, constant_data, scratch_size = assign_addresses(buffer_info, _npu_ops)
113+
_npu_ops, constant_data, runtime_allocation_size = assign_addresses(buffer_info, _npu_ops)
114+
base_addresses = extract_param_base_addresses(tir_module, buffer_info)
115+
if runtime_allocation_size > 0:
116+
base_addresses.append(
117+
util.BaseAddress(
118+
"runtime_allocation",
119+
None,
120+
_REGION_MAP[BufferType.runtime_allocate],
121+
runtime_allocation_size,
122+
True,
123+
)
124+
)
114125
target_accel_config = vela_api.get_accelerator_config()
115126
cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_config)
116127
payload = vapi.npu_create_driver_payload(cmds, target_accel_config)
117-
return payload.hex(), constant_data, scratch_size
128+
return payload.hex(), constant_data, base_addresses
129+
130+
131+
def extract_param_base_addresses(mod, buffer_info) -> List[util.BaseAddress]:
132+
"""This function extracts pool param to regions in the driver
133+
134+
Parameters
135+
----------
136+
mod : tvm.IRModule
137+
The TIR Module for NPU
138+
buffer_info : Dict[tvm.tir.Var, BufferInfo]
139+
Information regarding buffer vars used in the PrimFunc
140+
141+
Returns
142+
-------
143+
List[util.BaseAddress]
144+
One entry per non-constant param of the PrimFunc, holding its region for the driver and its size in bytes
145+
"""
146+
# There should only be a single function
147+
assert len(mod.functions.items()) == 1
148+
primfunc = mod.functions.items()[0][1]
149+
150+
base_addresses = list()
151+
idx = 0
152+
for param in primfunc.params:
153+
# constants are pooled together and handled specially
154+
# this will change after tir.allocate_const.
155+
# For now, we are skipping generating buffer addresses here
156+
if buffer_info[param].btype == BufferType.constant:
157+
continue
158+
buffer = primfunc.buffer_map[param]
159+
dtype = buffer.dtype
160+
element_size_bytes = np.iinfo(dtype).bits // 8
161+
size_bytes = element_size_bytes * np.prod(list(buffer.shape))
162+
base_addresses.append(
163+
util.BaseAddress(param.name, idx, _REGION_MAP[buffer_info[param].btype], size_bytes)
164+
)
165+
idx += 1
166+
167+
return base_addresses
118168

119169

120170
def extract_call_extern_list(mod):
@@ -170,6 +220,7 @@ def extract_buffer_info(
170220
# There should only be a single function
171221
assert len(mod.functions.items()) == 1
172222
primfunc = mod.functions.items()[0][1]
223+
173224
for idx, const_data in param_dict.items():
174225
param = primfunc.params[idx]
175226
buffer_info[param] = BufferInfo(
@@ -196,7 +247,7 @@ def populate_allocate_buffer_info(stmt):
196247
if storage_scope == "local":
197248
buffer_type = BufferType.shram
198249
else:
199-
buffer_type = BufferType.scratch
250+
buffer_type = BufferType.runtime_allocate
200251
buffer_info[allocate.buffer_var] = BufferInfo(
201252
None,
202253
allocate.extents,
@@ -228,7 +279,7 @@ def assign_addresses(buffer_info, npu_ops):
228279
A list of Vela NpuOps with addresses within scratch and constant buffers
229280
constant_tensor : NDArray
230281
A unified constant data array of uint8 as the constant buffer
231-
scratch_size : int
282+
runtime_allocation_size : int
232283
The size of the scratch tensor.
233284
"""
234285

@@ -275,7 +326,7 @@ def classify_io(buffer):
275326

276327
raise ValueError(f"Unused IO : {buffer} in tir module.")
277328

278-
scratch_size = 0
329+
runtime_allocation_size = 0
279330
constant_hex_data = []
280331
total_constant_len = 0
281332
buffer_addresses = dict()
@@ -300,6 +351,7 @@ def classify_io(buffer):
300351
assert buffer_type in (BufferType.input, BufferType.output)
301352
address = 0
302353
buffer_addresses[_buffer] = (address, buffer_type)
354+
buffer_info[_buffer] = BufferInfo(None, info.dtype, info.dtype, buffer_type)
303355
elif info.btype == BufferType.shram:
304356
accl_config = util.get_accelerator_config()
305357
arch_config = get_accelerator_arch_config(accl_config)
@@ -310,9 +362,9 @@ def classify_io(buffer):
310362
size_in_bytes = int(dtype_bytes * np.prod(list(info.shape)))
311363
# Every memory address the NPU accesses has to be 16-byte aligned
312364
size_in_bytes = util.round_up(size_in_bytes, 16)
313-
assert info.btype == BufferType.scratch
314-
address = scratch_size
315-
scratch_size += size_in_bytes
365+
assert info.btype == BufferType.runtime_allocate
366+
address = runtime_allocation_size
367+
runtime_allocation_size += size_in_bytes
316368
buffer_addresses[_buffer] = (address, info.btype)
317369

318370
for npu_op in npu_ops:
@@ -329,7 +381,7 @@ def classify_io(buffer):
329381
return (
330382
npu_ops,
331383
constant_data,
332-
scratch_size,
384+
runtime_allocation_size,
333385
)
334386

335387

python/tvm/relay/backend/contrib/ethosu/util.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from inspect import signature
2525
from enum import Enum
26-
from typing import Union, Tuple
26+
from typing import Union, Tuple, List
2727
import numpy as np # type: ignore
2828

2929
import tvm # type: ignore
@@ -239,6 +239,31 @@ def calculate_size_bytes(expr):
239239
return element_size * elements
240240

241241

242+
@register_object("relay.ext.ethos-u.BaseAddress")
243+
class BaseAddress(Object):
244+
"""
245+
This is a structure to hold base addresses for pointers
246+
provided for the driver.
247+
"""
248+
249+
def __init__(
250+
self,
251+
name: str,
252+
primfunc_param_idx: int,
253+
region: int,
254+
size: int,
255+
is_runtime_allocation: bool = False,
256+
):
257+
self.__init_handle_by_constructor__(
258+
_ffi_api.BaseAddress, # type: ignore # pylint: disable=no-member
259+
name,
260+
primfunc_param_idx,
261+
region,
262+
size,
263+
is_runtime_allocation,
264+
)
265+
266+
242267
@register_object("relay.ext.ethos-u.CompilationArtifact")
243268
class CompilationArtifact(Object):
244269
"""
@@ -248,19 +273,15 @@ class CompilationArtifact(Object):
248273

249274
def __init__(
250275
self,
276+
function_name: str,
251277
command_stream: str,
252278
encoded_constants: str,
253-
scratch_size: int,
254-
input_size: int,
255-
output_size: int,
256-
function_name: str,
279+
base_addresses: List[BaseAddress],
257280
):
258281
self.__init_handle_by_constructor__(
259282
_ffi_api.CompilationArtifact, # type: ignore # pylint: disable=no-member
283+
function_name,
260284
command_stream,
261285
encoded_constants,
262-
scratch_size,
263-
input_size,
264-
output_size,
265-
function_name,
286+
base_addresses,
266287
)

0 commit comments

Comments
 (0)