Skip to content

Commit 0a4cc89

Browse files
authored
[microNPU] Support binary elementwise with non-4D inputs (#9521)
Reshapes non-4D inputs to become 4D, then reshapes the output back to the non-4D input shape.
1 parent c3f5271 commit 0a4cc89

File tree

4 files changed

+173
-14
lines changed

4 files changed

+173
-14
lines changed

python/tvm/relay/backend/contrib/ethosu/legalize.py

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,60 @@ def __init__(
426426
self.params_class = params_class
427427
self.pattern = pattern
428428

429+
@staticmethod
def reshape_input(
    inputs: List["TensorParams"],
) -> List[tvm.relay.Expr]:
    """Pad every input up to rank 4 so that the binary elementwise
    operator that follows only ever sees 4-dimensional tensors.

    Inputs with fewer than 4 dimensions are wrapped in a reshape whose
    target shape is the original shape with leading 1s prepended.
    Inputs that already have 4 (or more) dimensions are passed through
    untouched.

    Parameters
    ----------
    inputs: List[TensorParams]
        The inputs to reshape.

    Returns
    -------
    reshaped_inputs: List[tvm.relay.Expr]
        One expression per input, in the same order as `inputs`.
    """
    target_rank = 4

    def to_4d(param):
        missing_dims = target_rank - len(param.shape)
        if missing_dims <= 0:
            # Nothing to pad: forward the original tensor expression.
            return param.tensor
        # Prepend unit dimensions, e.g. [4, 4] -> [1, 1, 4, 4].
        return relay.reshape(param.tensor, [1] * missing_dims + param.shape)

    return [to_4d(param) for param in inputs]
457+
458+
@staticmethod
def reshape_output(output: tvm.relay.Expr, ifm_input_shape: List[int]) -> tvm.relay.Expr:
    """Restore the output to the dimensionality of the original input.

    Since the NPU must have the broadcastable tensor as the second
    operand, the original shape of the first ifm must be the output
    shape.

    Parameters
    ----------
    output: tvm.relay.Expr
        The output to reshape.

    ifm_input_shape: List[int]
        The shape of the non-reshaped ifm tensor.

    Returns
    -------
    reshaped_output: tvm.relay.Expr
        `output` unchanged when `ifm_input_shape` has exactly 4
        dimensions, otherwise `output` reshaped to `ifm_input_shape`.
    """
    needs_reshape = len(ifm_input_shape) != 4
    return relay.reshape(output, ifm_input_shape) if needs_reshape else output
482+
429483
def callback(
430484
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
431485
) -> tvm.relay.Expr:
@@ -451,9 +505,12 @@ def callback(
451505
# We don't yet support activation functions that need to get legalized to LUTs.
452506
lut = relay.const([], dtype="int8")
453507

454-
return ethosu_ops.ethosu_binary_elementwise(
455-
ifm=params.ifm.tensor,
456-
ifm2=params.ifm2.tensor,
508+
inputs = [params.ifm, params.ifm2]
509+
inputs = self.reshape_input(inputs)
510+
511+
ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise(
512+
ifm=inputs[0],
513+
ifm2=inputs[1],
457514
lut=lut,
458515
operator_type=params.operator_type,
459516
ifm_scale=float(params.ifm.q_params.scale_f32),
@@ -462,8 +519,8 @@ def callback(
462519
ifm2_zero_point=int(params.ifm2.q_params.zero_point),
463520
ofm_scale=float(params.ofm.q_params.scale_f32),
464521
ofm_zero_point=int(params.ofm.q_params.zero_point),
465-
ifm_channels=params.ifm.shape[3],
466-
ifm2_channels=params.ifm2.shape[3],
522+
ifm_channels=params.ifm.shape[-1],
523+
ifm2_channels=params.ifm2.shape[-1],
467524
reversed_operands=params.reversed_operands,
468525
ofm_dtype=params.ofm.dtype,
469526
activation=activation,
@@ -473,6 +530,8 @@ def callback(
473530
ifm2_layout=str(params.ifm2.layout),
474531
ofm_layout=str(params.ofm.layout),
475532
)
533+
output = self.reshape_output(ethosu_binary_elementwise, params.ifm.shape)
534+
return output
476535

477536

478537
class AddRewriter(BinaryElementwiseRewriter):

python/tvm/relay/op/contrib/ethosu.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -516,11 +516,12 @@ def __init__(self, func_body: Call, operator_type: str, has_quantization_paramet
516516
self.activation = clip
517517
self.operator_type = operator_type
518518

519-
def can_broadcast(x, y):
520-
for i in range(1, 4):
521-
if x.shape[i] == y.shape[i] or y.shape[i] == 1:
522-
continue
519+
def can_broadcast(ifm, ifm2):
520+
if len(ifm.shape) < len(ifm2.shape):
523521
return False
522+
for m, n in zip(ifm.shape[::-1], ifm2.shape[::-1]):
523+
if m != n and m == 1:
524+
return False
524525
return True
525526

526527
if can_broadcast(self.ifm, self.ifm2):
@@ -539,9 +540,14 @@ def is_valid(self):
539540
"""
540541
if np.dtype(self.ofm) == np.int32 and self.activation is not None:
541542
return False
542-
if len(self.ifm.shape) != 4 or len(self.ifm2.shape) != 4:
543+
# Non-4D int32 outputs are rejected for now because the identity operator requires ofm != int32
544+
if np.dtype(self.ofm) == np.int32 and len(self.ofm.shape) < 4:
543545
return False
544-
if self.ifm.shape[0] != 1 or self.ifm2.shape[0] != 1:
546+
if len(self.ifm.shape) > 4 or len(self.ifm2.shape) > 4:
547+
return False
548+
if len(self.ifm.shape) == 4 and self.ifm.shape[0] != 1:
549+
return False
550+
if len(self.ifm2.shape) == 4 and self.ifm2.shape[0] != 1:
545551
return False
546552
if not self.valid_broadcast:
547553
return False

tests/python/contrib/test_ethosu/test_codegen.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ def representative_dataset():
349349
([1, 2, 3, 4], [1, 2, 3, 4]),
350350
([1, 2, 3, 4], [1, 1, 1, 1]),
351351
([1, 1, 1, 1], [1, 2, 3, 4]),
352+
([1, 4, 4], [4, 1]),
352353
],
353354
)
354355
@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
@@ -435,6 +436,84 @@ def representative_dataset():
435436
infra.verify_source(compiled_models, accel_type)
436437

437438

439+
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
    "ifm_shape, ifm2_shape",
    [
        ([4], [4]),
        ([4], [1, 2, 3, 4]),
        ([1, 4, 4], [4, 1]),
    ],
)
def test_binary_add_with_non_4d_shapes(
    accel_type,
    ifm_shape,
    ifm2_shape,
):
    """Build and verify an int8 TFLite add where at least one operand is not 4D.

    The TensorFlow model is converted to a quantized TFLite graph, imported
    into Relay, partitioned with `partition_for_ethosu`, built with
    `infra.build_source` and finally checked with `infra.verify_source`.
    """
    dtype = "int8"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def tf_function(self, lhs, rhs):
                return tf.math.add(lhs, rhs)

        def representative_dataset():
            # Handed to the converter below as its representative dataset.
            for _ in range(100):
                lhs_sample = np.random.rand(*ifm_shape)
                rhs_sample = np.random.rand(*ifm2_shape) * 2
                yield [lhs_sample.astype(np.float32), rhs_sample.astype(np.float32)]

        model = Model()
        lhs_spec = tf.TensorSpec(ifm_shape, dtype=tf.float32)
        rhs_spec = tf.TensorSpec(ifm2_shape, dtype=tf.float32)
        concrete_func = model.tf_function.get_concrete_function(lhs_spec, rhs_spec)

        # Convert the model
        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        return converter.convert()

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

    shape_dict = {"ifm": ifm_shape, "ifm2": ifm2_shape}
    dtype_dict = {"ifm": dtype, "ifm2": dtype}
    relay_mod, relay_params = relay.frontend.from_tflite(
        tflite_model,
        shape_dict=shape_dict,
        dtype_dict=dtype_dict,
    )
    relay_mod = partition_for_ethosu(relay_mod, relay_params)

    # Generate reference data
    input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)

    compiled_models = infra.build_source(
        relay_mod,
        input_data,
        output_data,
        accel_type,
        output_tolerance=0,
    )

    # Assumes only two runtime.Modules are created -- i.e. single offload module
    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
    assert len(imported_modules) == 2
    ethosu_module = imported_modules[0]

    # Verify generated C source
    get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
    command_stream = bytes.fromhex(get_cs(ethosu_module))

    infra.print_payload(command_stream)
    infra.verify_source(compiled_models, accel_type)
515+
516+
438517
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
439518
def test_binary_add_from_constant_scalar(accel_type):
440519
dtype = "uint8"

tests/python/contrib/test_ethosu/test_legalize.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,10 @@ def verify(ext_func):
565565
([1, 2, 3, 4], [1, 2, 3, 4], False),
566566
([1, 2, 3, 4], [1, 1, 3, 1], False),
567567
([1, 1, 3, 1], [1, 2, 3, 4], True),
568+
([1, 4, 4], [4, 1], False),
569+
([4], [4], False),
570+
([4], [1, 2, 3, 4], True),
571+
([1, 4, 4], [4, 1], False),
568572
],
569573
)
570574
@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
@@ -621,16 +625,27 @@ def verify(ext_func):
621625
shapes = [ifm_shape, ifm2_shape]
622626
ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1)
623627
op = ext_func.body
624-
assert list(op.args[0].checked_type.shape) == shapes[ifm_index]
625-
assert list(op.args[1].checked_type.shape) == shapes[ifm2_index]
628+
629+
has_reshaped_output = False
630+
shapes_padded = [[1] * (4 - len(s)) + s for s in shapes]
631+
out_padded = [1] * (4 - len(out_shape)) + out_shape
632+
if op.op.name != "contrib.ethosu.binary_elementwise":
633+
has_reshaped_output = True
634+
op = op.args[0]
635+
636+
assert list(op.args[0].checked_type.shape) == shapes_padded[ifm_index]
637+
assert list(op.args[1].checked_type.shape) == shapes_padded[ifm2_index]
626638
assert op.args[0].checked_type.dtype == dtype
627-
assert list(op.checked_type.shape) == out_shape
639+
assert list(op.checked_type.shape) == out_padded
628640
assert op.checked_type.dtype == dtype
629641
assert op.attrs.operator_type == operator_type
630642
assert op.attrs.reversed_operands == reversed_operands
631643
if activation_function == "RELU":
632644
assert str(op.attrs.activation) == "CLIP"
633645

646+
if has_reshaped_output:
647+
assert list(ext_func.body.checked_type.shape) == out_shape
648+
634649
if operator_type == "ADD":
635650
rewriter = legalize.AddRewriter()
636651
pattern_table = [

0 commit comments

Comments
 (0)