diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index cbd84aa3d350..8f5d6c24f0f6 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -652,8 +652,8 @@ def callback(
             ifm2_zero_point=int(params.ifm2.q_params.zero_point),
             ofm_scale=float(params.ofm.q_params.scale_f32),
             ofm_zero_point=int(params.ofm.q_params.zero_point),
-            ifm_channels=params.ifm.shape[-1],
-            ifm2_channels=params.ifm2.shape[-1],
+            ifm_channels=params.ifm.shape[-1] if params.ifm.shape else 1,
+            ifm2_channels=params.ifm2.shape[-1] if params.ifm2.shape else 1,
             reversed_operands=params.reversed_operands,
             ofm_dtype=params.ofm.dtype,
             activation=activation,
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
index c5bfa5cf92ef..bcd785ddbbd8 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
@@ -123,12 +123,11 @@ def __init__(self):
 
     def visit_constant(self, const):
         if isinstance(const.checked_type, relay.ty.TensorType):
-            if const.checked_type.concrete_shape != ():
-                self.constants.append(const.data.asnumpy())
-                name = "p" + str(len(self.constants))
-                var = relay.var(type_annotation=const.checked_type, name_hint=name)
-                self.const_vars.append(var)
-                return var
+            self.constants.append(const.data.asnumpy())
+            name = "p" + str(len(self.constants))
+            var = relay.var(type_annotation=const.checked_type, name_hint=name)
+            self.const_vars.append(var)
+            return var
 
         return const
 
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py
index 572057452602..20a8ff85ee2f 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py
@@ -136,7 +136,10 @@ def _visit(tensor, reader, lut):
             if tensor not in planned:
                 planned.add(tensor)
                 if isinstance(tensor.op, tvm.te.PlaceholderOp) and tensor != lut:
-                    index = list(cached_func.inputs).index(tensor)
+                    # Find index of input using 'same_as' check to prevent equality
+                    # ambiguity when encountering a scalar.
+                    is_same = [var.same_as(tensor) for var in cached_func.inputs]
+                    index = is_same.index(True)
                     if index in const_dict:
                         sch.cache_read(tensor, "global", [reader])
 
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index 23fd74dc486d..4042bb057bd3 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -629,12 +629,13 @@ def create_mod_from_relay():
 
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
 @pytest.mark.parametrize("dtype", ["int8", "uint8"])
-def test_elementwise_add_from_constant_scalar(accel_type, dtype):
+@pytest.mark.parametrize("constant", [np.ones((1, 1, 1, 1)), np.array(1)])
+def test_elementwise_add_from_constant_scalar(accel_type, dtype, constant):
     ifm_shape = (1, 4, 4, 8)
 
     def create_relay_graph():
         inp = relay.var("input", shape=ifm_shape, dtype=dtype)
-        scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype)
+        scalar = relay.const(constant, dtype=dtype)
         add = relay.qnn.op.add(
             inp,
             scalar,
diff --git a/tests/python/contrib/test_ethosu/test_compiler.py b/tests/python/contrib/test_ethosu/test_compiler.py
index 4df6311a230c..e1688b8aa512 100644
--- a/tests/python/contrib/test_ethosu/test_compiler.py
+++ b/tests/python/contrib/test_ethosu/test_compiler.py
@@ -34,8 +34,7 @@ def test_lower_to_tir():
         kernel_layout="HWIO",
         out_dtype="int32",
     )
-    multiply = relay.multiply(relay.const(-22, dtype="int32"), p2)
-    tile = relay.tile(multiply, reps=(1, 1, 1, 1001))
+    tile = relay.tile(p2, reps=(1, 1, 1, 1001))
     subtract = relay.subtract(conv, tile)
     func = subtract
     expr = relay.Function(relay.analysis.free_vars(func), func)
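Note on the scheduler.py hunk: the sketch below (standalone, not part of the patch; the placeholder names are illustrative) shows why the plain `list.index` lookup breaks once scalar constants are extracted into rank-0 inputs. `te.Tensor` overloads `__eq__`, and comparing two rank-0 tensors raises ValueError ("Equal == comparison among rank-0 tensor is ambiguous ..."), so `index()`, which relies on `==`, can blow up; `same_as` compares object references and is unambiguous.

    from tvm import te

    # Two distinct rank-0 placeholders, e.g. scalar constants turned into
    # parameters by ExtractConstants above.
    p1 = te.placeholder((), dtype="int8", name="p1")
    p2 = te.placeholder((), dtype="int8", name="p2")
    inputs = [p1, p2]

    # list.index falls back to __eq__ once the identity check fails, and
    # te.Tensor.__eq__ raises when both operands are rank-0:
    try:
        inputs.index(p2)  # compares p1 == p2 first -> ValueError
    except ValueError as err:
        print(err)

    # The patched lookup uses reference identity instead:
    index = [var.same_as(p2) for var in inputs].index(True)
    assert index == 1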