@@ -288,7 +288,7 @@ def _get_input_block(
288288 input_shape : _Shape ,
289289 dtype : str ,
290290 op_type : str ,
291- is_partkernel : bool ,
291+ partkernel : bool ,
292292 stride_h : int ,
293293 stride_w : int ,
294294 dilated_kernel_h : int ,
@@ -310,7 +310,7 @@ def _get_input_block(
310310
311311 if op_type == "ethosu_conv2d" :
312312 if dtype == "int8" :
313- if is_partkernel :
313+ if partkernel :
314314 depth = self ._align (min (32 , input_shape .depth ), 8 )
315315 else :
316316 depth = self ._align (min (16 , input_shape .depth ), 8 )
@@ -336,7 +336,7 @@ def get_kernel_steps(
336336 dilated_kernel_h : int ,
337337 dilated_kernel_w : int ,
338338 ifm_dtype : str ,
339- is_partkernel : bool = False ,
339+ partkernel : bool = False ,
340340 ) -> List [int ]:
341341 """Calculate the total number of subkernels and their sizes
342342
@@ -351,7 +351,7 @@ def get_kernel_steps(
351351 Width of dilated kernel
352352 ifm_dtype: str
353353 Datatype of the Input Feature Map tensor (IFM)
354- is_partkernel : bool
354+ partkernel : bool
355355 Flag showing whether part-kernel first traversal is used
356356
357357 Returns
@@ -368,7 +368,7 @@ def get_kernel_steps(
368368 kernel_steps = []
369369 for y , x in subkernels :
370370 subkernel_elements = x * y
371- if op_type == "ethosu_conv2d" and is_partkernel :
371+ if op_type == "ethosu_conv2d" and partkernel :
372372 # Part-kernel-first traversal conv2d
373373 divisor = 4 if ifm_dtype == "int8" else 2
374374 kernel_steps .append (int (_round_up_div (subkernel_elements , divisor )))
@@ -509,29 +509,31 @@ def get_elementwise_block_config(
509509 banks_available -= 2
510510
511511 # Split the block in half until it fits into SHRAM
512+ max_height , max_width , max_depth = self ._max_block_shape .as_list ()[1 :]
512513 if output_layout == "NHCWB16" :
513514 split_order = (a for a in [1 , 3 , 2 ])
514515 output_block = [
515516 output_shape [0 ],
516- min (output_shape [1 ], self ._max_block_shape .height ),
517- min (output_shape [2 ] * output_shape [4 ], self . _max_block_shape . depth ),
518- min (output_shape [3 ], self ._max_block_shape .width ),
517+ _round_up ( min (output_shape [1 ], max_height ), self ._micro_block .height ),
518+ min (output_shape [2 ] * output_shape [4 ], max_depth ),
519+ _round_up ( min (output_shape [3 ], max_width ), self ._micro_block .width ),
519520 16 ,
520521 ]
521522 else :
522523 split_order = (a for a in [1 , 2 , 3 ])
523524 output_block = [
524525 output_shape [0 ],
525- min (output_shape [1 ], self ._max_block_shape .height ),
526- min (output_shape [2 ], self ._max_block_shape .width ),
527- min (output_shape [3 ], self ._max_block_shape .depth ),
526+ _round_up ( min (output_shape [1 ], max_height ), self ._micro_block .height ),
527+ _round_up ( min (output_shape [2 ], max_width ), self ._micro_block .width ),
528+ _round_up ( min (output_shape [3 ], max_depth ), self ._micro_block .depth ),
528529 ]
529530 split_axis = next (split_order )
531+
532+ offset = [0 ] * len (output_block )
533+ stripes = [1 ] * len (output_block )
534+ order = [1 , 2 , 4 , 3 , 0 ] if output_layout == "NHCWB16" else [1 , 2 , 3 , 4 ]
530535 while True :
531536 # Create stripe config for output block
532- offset = [0 ] * len (output_block )
533- stripes = [1 ] * len (output_block )
534- order = [1 , 2 , 4 , 3 , 0 ] if output_layout == "NHCWB16" else [1 , 2 , 3 , 4 ]
535537 output_stripe_config = StripeConfig (
536538 output_block , output_block , output_block , order , stripes , offset
537539 )
@@ -564,10 +566,12 @@ def get_elementwise_block_config(
564566 block_config .append (BlockConfig (output_block , output_block , 0 , output_cycles ))
565567 break
566568
567- if output_block [split_axis ] == 1 :
569+ if output_block [split_axis ] == self . _micro_block . as_list ()[ split_axis ] :
568570 split_axis = next (split_order )
569571
570- output_block [split_axis ] = _round_up_div (output_block [split_axis ], 2 )
572+ output_block [split_axis ] = _round_up (
573+ _round_up_div (output_block [split_axis ], 2 ), self ._micro_block .as_list ()[split_axis ]
574+ )
571575
572576 return block_config
573577
@@ -670,9 +674,9 @@ def get_valid_block_configs(
670674
671675 # Input block depth has additional limitations for operators that require full input depth
672676 input_block_depth = 0
673- is_partkernel = self .is_partkernel (op_type , ifm_channels , ifm_dtype , kernel_h * kernel_w )
677+ partkernel = self .is_partkernel (op_type , ifm_channels , ifm_dtype , kernel_h * kernel_w )
674678 if op_type == "ethosu_conv2d" :
675- if is_partkernel :
679+ if partkernel :
676680 input_block_depth = min (ifm_channels , 16 )
677681 else :
678682 input_block_depth = min (ifm_channels , 32 )
@@ -745,7 +749,8 @@ def get_valid_block_configs(
745749 kernel_h ,
746750 kernel_w ,
747751 ifm_channels ,
748- is_partkernel ,
752+ "int8" ,
753+ partkernel ,
749754 )
750755 block_config = BlockConfig (
751756 input_block_shape .as_list (), output_block , compute_cycles , output_cycles
@@ -767,15 +772,15 @@ def _estimate_compute_cycles_per_block(
767772 kernel_w : int ,
768773 input_channels : int ,
769774 ifm_dtype : str ,
770- is_partkernel : bool = False ,
775+ partkernel : bool = False ,
771776 ) -> Tuple [int , int ]:
772777 # Calculate the amount of micro blocks per block, per axis
773778 num_quantum_x = _round_up_div (block_shape .width , self ._micro_block .width )
774779 num_quantum_y = _round_up_div (block_shape .height , self ._micro_block .height )
775780 num_quantum_z = _round_up_div (block_shape .depth , self ._micro_block .depth )
776781 num_quantum_xy = num_quantum_x * num_quantum_y
777782
778- kernel_steps = self .get_kernel_steps (op_type , kernel_h , kernel_w , ifm_dtype , is_partkernel )
783+ kernel_steps = self .get_kernel_steps (op_type , kernel_h , kernel_w , ifm_dtype , partkernel )
779784
780785 wd_cycles = self ._get_weight_decoder_cycles (op_type )
781786 delay_cycles = self ._get_delay_cycles (op_type , ifm_dtype )
@@ -794,7 +799,7 @@ def _estimate_compute_cycles_per_block(
794799 elif subkernel_steps > 1 :
795800 compute_cycles += delay_cycles * (subkernel_steps - 1 ) * num_quantum_z
796801
797- if is_partkernel :
802+ if partkernel :
798803 compute_cycles *= _round_up_div (input_block_shape .depth , 8 )
799804
800805 if op_type == "ethosu_conv2d" :
0 commit comments