1818the Relay to TIR compilation process, to Vela API calls to
1919generate command stream.
2020"""
21- from typing import Dict , NamedTuple , Tuple , Union
21+ from typing import Dict , NamedTuple , Tuple , Union , List
2222from enum import auto
2323from enum import Enum
2424import numpy as np # type: ignore
@@ -36,15 +36,15 @@ class BufferType(Enum):
3636
3737 constant = auto ()
3838 input_or_output = auto ()
39- scratch = auto ()
39+ runtime_allocate = auto ()
4040 input = auto ()
4141 output = auto ()
4242 shram = auto ()
4343
4444
4545_REGION_MAP = {
4646 BufferType .constant : 0 ,
47- BufferType .scratch : 1 ,
47+ BufferType .runtime_allocate : 1 ,
4848 BufferType .input : 3 ,
4949 BufferType .output : 4 ,
5050 BufferType .shram : int ((1 << 8 ) | (3 << 0 )),
@@ -101,20 +101,70 @@ def translate(tir_module, params):
101101 encoded_constants : str
102102 An hex string of the bytes that includes concat'd
103103 encoded weights, encoded biases and scales.
104- scratch_size : int
105- The size of the scratch buffer needed.
104+ base_addresses : List[util.BaseAddress]
105+ The base addresses of the non-constant params of the PrimFunc (plus the runtime allocation, when its size is non-zero) and their regions in the driver
106106 """
107107
108108 buffer_info = extract_buffer_info (tir_module , params )
109109 call_extern_list = extract_call_extern_list (tir_module )
110110 _npu_ops = list ()
111111 for call_extern in call_extern_list :
112112 _npu_ops .append (translate_ethosu_tir_call_extern (call_extern ))
113- _npu_ops , constant_data , scratch_size = assign_addresses (buffer_info , _npu_ops )
113+ _npu_ops , constant_data , runtime_allocation_size = assign_addresses (buffer_info , _npu_ops )
114+ base_addresses = extract_param_base_addresses (tir_module , buffer_info )
115+ if runtime_allocation_size > 0 :
116+ base_addresses .append (
117+ util .BaseAddress (
118+ "runtime_allocation" ,
119+ None ,
120+ _REGION_MAP [BufferType .runtime_allocate ],
121+ runtime_allocation_size ,
122+ True ,
123+ )
124+ )
114125 target_accel_config = vela_api .get_accelerator_config ()
115126 cmds = vapi .npu_generate_register_command_stream (_npu_ops , target_accel_config )
116127 payload = vapi .npu_create_driver_payload (cmds , target_accel_config )
117- return payload .hex (), constant_data , scratch_size
128+ return payload .hex (), constant_data , base_addresses
129+
130+
def extract_param_base_addresses(mod, buffer_info) -> List[util.BaseAddress]:
    """Build the driver base addresses for the non-constant PrimFunc params.

    Parameters
    ----------
    mod : tvm.IRModule
        The TIR Module for NPU
    buffer_info : Dict[tvm.tir.Var, BufferInfo]
        Information regarding buffer vars used in the PrimFunc

    Returns
    -------
    List[util.BaseAddress]
        One entry per non-constant param of the PrimFunc, in parameter order.
        Each entry carries the param name, its index among the non-constant
        params, the driver region looked up in ``_REGION_MAP`` for the
        param's buffer type, and the size of the bound buffer in bytes.
    """
    # There should only be a single function
    functions = mod.functions.items()
    assert len(functions) == 1
    primfunc = functions[0][1]

    base_addresses = []
    # idx only advances for params that actually produce a base address,
    # so skipped constant params do not leave gaps in the numbering.
    idx = 0
    for param in primfunc.params:
        btype = buffer_info[param].btype
        # constants are pooled together and handled specially
        # this will change after tir.allocate_const.
        # For now, we are skipping generating buffer addresses here
        if btype == BufferType.constant:
            continue
        buffer = primfunc.buffer_map[param]
        # itemsize equals iinfo(dtype).bits // 8 for integer dtypes, but does
        # not raise for non-integer dtypes such as floats.
        element_size_bytes = np.dtype(buffer.dtype).itemsize
        size_bytes = element_size_bytes * np.prod(list(buffer.shape))
        base_addresses.append(
            util.BaseAddress(param.name, idx, _REGION_MAP[btype], size_bytes)
        )
        idx += 1

    return base_addresses
118168
119169
120170def extract_call_extern_list (mod ):
@@ -170,6 +220,7 @@ def extract_buffer_info(
170220 # There should only be a single function
171221 assert len (mod .functions .items ()) == 1
172222 primfunc = mod .functions .items ()[0 ][1 ]
223+
173224 for idx , const_data in param_dict .items ():
174225 param = primfunc .params [idx ]
175226 buffer_info [param ] = BufferInfo (
@@ -196,7 +247,7 @@ def populate_allocate_buffer_info(stmt):
196247 if storage_scope == "local" :
197248 buffer_type = BufferType .shram
198249 else :
199- buffer_type = BufferType .scratch
250+ buffer_type = BufferType .runtime_allocate
200251 buffer_info [allocate .buffer_var ] = BufferInfo (
201252 None ,
202253 allocate .extents ,
@@ -228,7 +279,7 @@ def assign_addresses(buffer_info, npu_ops):
228279 A list of Vela NpuOps with addresses within scratch and constant buffers
229280 constant_tensor : NDArray
230281 A unified constant data array of uint8 as the constant buffer
231- scratch_size : int
282+ runtime_allocation_size : int
232283 The size of the scratch tensor.
233284 """
234285
@@ -275,7 +326,7 @@ def classify_io(buffer):
275326
276327 raise ValueError (f"Unused IO : { buffer } in tir module." )
277328
278- scratch_size = 0
329+ runtime_allocation_size = 0
279330 constant_hex_data = []
280331 total_constant_len = 0
281332 buffer_addresses = dict ()
@@ -300,6 +351,7 @@ def classify_io(buffer):
300351 assert buffer_type in (BufferType .input , BufferType .output )
301352 address = 0
302353 buffer_addresses [_buffer ] = (address , buffer_type )
354+ buffer_info [_buffer ] = BufferInfo (None , info .dtype , info .dtype , buffer_type )
303355 elif info .btype == BufferType .shram :
304356 accl_config = util .get_accelerator_config ()
305357 arch_config = get_accelerator_arch_config (accl_config )
@@ -310,9 +362,9 @@ def classify_io(buffer):
310362 size_in_bytes = int (dtype_bytes * np .prod (list (info .shape )))
311363 # Every memory address the NPU access have to be 16 byte aligned
312364 size_in_bytes = util .round_up (size_in_bytes , 16 )
313- assert info .btype == BufferType .scratch
314- address = scratch_size
315- scratch_size += size_in_bytes
365+ assert info .btype == BufferType .runtime_allocate
366+ address = runtime_allocation_size
367+ runtime_allocation_size += size_in_bytes
316368 buffer_addresses [_buffer ] = (address , info .btype )
317369
318370 for npu_op in npu_ops :
@@ -329,7 +381,7 @@ def classify_io(buffer):
329381 return (
330382 npu_ops ,
331383 constant_data ,
332- scratch_size ,
384+ runtime_allocation_size ,
333385 )
334386
335387
0 commit comments