@@ -36,15 +36,15 @@ class BufferType(Enum):
3636
3737 constant = auto ()
3838 input_or_output = auto ()
39- runtime_allocate = auto ()
39+ scratch = auto ()
4040 input = auto ()
4141 output = auto ()
4242 shram = auto ()
4343
4444
4545_REGION_MAP = {
4646 BufferType .constant : 0 ,
47- BufferType .runtime_allocate : 1 ,
47+ BufferType .scratch : 1 ,
4848 BufferType .input : 3 ,
4949 BufferType .output : 4 ,
5050 BufferType .shram : int ((1 << 8 ) | (3 << 0 )),
@@ -103,23 +103,23 @@ def translate(tir_module, params):
103103 An hex string of the bytes that includes concat'd
104104 encoded weights, encoded biases and scales.
105105 base_addresses : List[util.BaseAddress]
106- base addresses
106+ base addresses to be used by the driver
107107 """
108108
109109 buffer_info = extract_buffer_info (tir_module , params )
110110 call_extern_list = extract_call_extern_list (tir_module )
111111 _npu_ops = list ()
112112 for call_extern in call_extern_list :
113113 _npu_ops .append (translate_ethosu_tir_call_extern (call_extern ))
114- _npu_ops , constant_data , runtime_allocation_size = assign_addresses (buffer_info , _npu_ops )
114+ _npu_ops , constant_data , scratch_size = assign_addresses (buffer_info , _npu_ops )
115115 base_addresses = extract_param_base_addresses (tir_module , buffer_info )
116- if runtime_allocation_size > 0 :
116+ if scratch_size > 0 :
117117 base_addresses .append (
118118 util .BaseAddress (
119- "runtime_allocation " ,
119+ "scratch " ,
120120 None ,
121- _REGION_MAP [BufferType .runtime_allocate ],
122- runtime_allocation_size ,
121+ _REGION_MAP [BufferType .scratch ],
122+ scratch_size ,
123123 True ,
124124 )
125125 )
@@ -248,7 +248,7 @@ def populate_allocate_buffer_info(stmt):
248248 if storage_scope == "local" :
249249 buffer_type = BufferType .shram
250250 else :
251- buffer_type = BufferType .runtime_allocate
251+ buffer_type = BufferType .scratch
252252 buffer_info [allocate .buffer_var ] = BufferInfo (
253253 None ,
254254 allocate .extents ,
@@ -280,7 +280,7 @@ def assign_addresses(buffer_info, npu_ops):
280280 A list of Vela NpuOps with addesses within scratch and constant buffers
281281 constant_tensor : NDArray
282282 A unified constant data array of uint8 as the constant buffer
283- runtime_allocation_size : int
283+ scratch_size : int
284284 The size of the scratch tensor.
285285 """
286286
@@ -327,7 +327,7 @@ def classify_io(buffer):
327327
328328 raise ValueError (f"Unused IO : { buffer } in tir module." )
329329
330- runtime_allocation_size = 0
330+ scratch_size = 0
331331 constant_hex_data = []
332332 total_constant_len = 0
333333 buffer_addresses = dict ()
@@ -352,7 +352,9 @@ def classify_io(buffer):
352352 assert buffer_type in (BufferType .input , BufferType .output )
353353 address = 0
354354 buffer_addresses [_buffer ] = (address , buffer_type )
355- buffer_info [_buffer ] = BufferInfo (None , info .dtype , info .dtype , buffer_type )
355+ buffer_info [_buffer ] = BufferInfo (
356+ values = None , shape = info .dtype , dtype = info .dtype , btype = buffer_type
357+ )
356358 elif info .btype == BufferType .shram :
357359 accl_config = util .get_accelerator_config ()
358360 arch_config = get_accelerator_arch_config (accl_config )
@@ -363,9 +365,9 @@ def classify_io(buffer):
363365 size_in_bytes = int (dtype_bytes * np .prod (list (info .shape )))
364366 # Every memory address the NPU access have to be 16 byte aligned
365367 size_in_bytes = util .round_up (size_in_bytes , 16 )
366- assert info .btype == BufferType .runtime_allocate
367- address = runtime_allocation_size
368- runtime_allocation_size += size_in_bytes
368+ assert info .btype == BufferType .scratch
369+ address = scratch_size
370+ scratch_size += size_in_bytes
369371 buffer_addresses [_buffer ] = (address , info .btype )
370372
371373 for npu_op in npu_ops :
@@ -382,7 +384,7 @@ def classify_io(buffer):
382384 return (
383385 npu_ops ,
384386 constant_data ,
385- runtime_allocation_size ,
387+ scratch_size ,
386388 )
387389
388390
0 commit comments