82 changes: 82 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1121,6 +1121,87 @@ def callback(
return reduced_op


class SumRewriter(DFPatternCallback):
"""
Convert ethosu.sum composite functions to pooling operations
"""

def __init__(self):
super().__init__(require_type=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.SumParams.composite_name})
)(wildcard())

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:

params = ethosu_patterns.SumParams(post.op.body)

ifm_shape = params.ifm.shape
ofm_shape = params.ofm.shape
lut = relay.const([], "int8")
reduced_op = post.args[0]

# Enforce 4d input
if len(ifm_shape) == 3:
ifm_shape = [1, params.height, params.width, ifm_shape[2]]
reduced_op = relay.reshape(reduced_op, ifm_shape)

activation_map = {"clip": "CLIP"}
if params.activation:
activation = activation_map[params.activation.op.name]
clip_min = int(params.activation.attrs.a_min)
clip_max = int(params.activation.attrs.a_max)
else:
activation = "NONE"
clip_min = 0
clip_max = 0

reduced_op = ethosu_ops.ethosu_pooling(
ifm=reduced_op,
lut=lut,
pooling_type="SUM",
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point),
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=0,
pool_shape=(1, 1),
ofm_channels=1,
activation=activation,
clip_min=clip_min,
clip_max=clip_max,
ifm_layout=params.ifm.layout,
ofm_layout=params.ofm.layout,
rounding_mode="NATURAL",
)

# Convert tensor dtype from int32 to int8
scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="int32"), dtype="int32")
reduced_op = ethosu_ops.ethosu_binary_elementwise(
ifm=reduced_op,
ifm2=scalar_tensor,
lut=lut,
operator_type="MUL",
ifm_scale=0.0,
ifm_zero_point=0,
ifm2_scale=0.0,
ifm2_zero_point=0,
ofm_scale=0.0,
ofm_zero_point=int(params.ofm.q_params.zero_point),
ifm_channels=1,
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int8",
)

# Reshape to original ofm shape
if len(ofm_shape) < 4:
reduced_op = relay.reshape(reduced_op, ofm_shape)

return reduced_op


class ConcatRewriter(DFPatternCallback):
"""The newer versions of TFLite converters return a concatenate operator that concatenates
tensors with same QNN params (if the QNN params of tensors were initially different,
@@ -1443,6 +1524,7 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
HardSwishRewriter(),
LeakyReLURewriter(),
MeanRewriter(),
SumRewriter(),
ConcatRewriter(),
SigmoidRewriter(),
RequantizeRewriter(),
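Taken together, SumRewriter lowers a quantized reduce_sum to two microNPU primitives: an ethosu_pooling op with pooling_type="SUM" that accumulates a 1x1 window across all input channels into a single int32 channel, then an ethosu_binary_elementwise MUL by a ones tensor whose only effect is converting the int32 accumulator to int8 and applying the output zero point. A minimal NumPy sketch of the end-to-end arithmetic the chain is meant to reproduce (the shapes and quantization parameters below are hypothetical):

import numpy as np

def quantized_reduce_sum_reference(ifm, ifm_scale, ifm_zp, ofm_scale, ofm_zp):
    """Reference semantics of a quantized reduce_sum over the channel axis."""
    # Dequantize the int8 IFM to real values.
    real = (ifm.astype(np.int32) - ifm_zp) * ifm_scale
    # Reduce over the innermost axis, the only axis SumParams.is_valid accepts.
    summed = real.sum(axis=-1, keepdims=True)
    # Requantize to the int8 OFM.
    quantized = np.round(summed / ofm_scale) + ofm_zp
    return np.clip(quantized, -128, 127).astype(np.int8)

ifm = np.random.randint(-128, 128, size=(1, 4, 2, 8), dtype=np.int8)
ofm = quantized_reduce_sum_reference(ifm, ifm_scale=0.5, ifm_zp=0, ofm_scale=0.25, ofm_zp=-128)
print(ofm.shape)  # (1, 4, 2, 1)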
11 changes: 7 additions & 4 deletions python/tvm/relay/backend/contrib/ethosu/te/pooling.py
@@ -110,18 +110,21 @@ def pooling_compute(
padding = [int(v) for v in padding]
stride_h, stride_w = [int(v) for v in strides]
pool_shape_h, pool_shape_w = [int(v) for v in pool_shape]
+    ifm_channels = ofm_channels if pooling_type != "SUM" else ifm.shape[-1]

[Inline review comment from a contributor on the line above: "Won't block on this as the functionality remains the same, but wondering if ofm_channels ever made sense here and whether we can just use ifm.shape[-1] for all pooling types?"]

upscale_factor = 2 if upscale != "NONE" else 1

# Compute operation for the IFM DMA pipeline
dmaed_ifm = dma_ifm_compute(
-        ifm, ifm_layout, ifm_zero_point, ifm_scale, ofm_channels, padding, upscale_factor
+        ifm, ifm_layout, ifm_zero_point, ifm_scale, ifm_channels, padding, upscale_factor
)

# Pooling compute operation
ofm_height = (dmaed_ifm.shape[1] - pool_shape_h) // stride_h + 1
ofm_width = (dmaed_ifm.shape[2] - pool_shape_w) // stride_w + 1
rh = te.reduce_axis((0, pool_shape_h), name="ry")
rw = te.reduce_axis((0, pool_shape_w), name="rx")
+    rc = te.reduce_axis((0, 1 if pooling_type != "SUM" else ifm_channels), name="rc")
+    ofm_dtype = ifm.dtype if pooling_type != "SUM" else "int32"

pooling_attrs = {
"op": "ethosu_pooling",
@@ -149,10 +152,10 @@
pooling = te.compute(
(1, ofm_height, ofm_width, ofm_channels),
lambda nn, hh, ww, cc: te.max(
-            (dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc) + lut_expr).astype(
-                ifm.dtype
+            (dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc + rc) + lut_expr).astype(
+                ofm_dtype
            ),
-            axis=[rh, rw],
+            axis=[rh, rw, rc],
),
name="ethosu_pooling",
attrs=pooling_attrs,
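For SUM pooling the TE compute above changes in three ways: the IFM DMA pipeline is given the true input depth (ifm_channels) instead of ofm_channels, the reduce domain gains a channel axis rc spanning all input channels, and the accumulator widens to int32. The combiner remains te.max, which appears to serve only as a placeholder, since this expression is pattern-matched and replaced by NPU operations rather than executed directly. A standalone TE sketch of the intended reduction structure, with hypothetical shapes:

import tvm
from tvm import te

# Hypothetical NHWC int8 input: a 1x1 window summed across all 8 channels
# produces a single int32 output channel.
ifm = te.placeholder((1, 4, 2, 8), dtype="int8", name="ifm")
rc = te.reduce_axis((0, 8), name="rc")
sum_pool = te.compute(
    (1, 4, 2, 1),
    lambda nn, hh, ww, cc: te.sum(ifm[nn, hh, ww, cc + rc].astype("int32"), axis=rc),
    name="sum_pool",
)
print(tvm.lower(te.create_schedule(sum_pool.op), [ifm, sum_pool], simple_mode=True))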
python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
@@ -822,7 +822,10 @@ def _create_npu_quantization(
"""This is a helper function to capture a list
of arguments to create Vela NpuQuantization object.
"""
-    return vapi.NpuQuantization(scale_f32=float(scale), zero_point=int(zero_point))
+    scale = float(scale)
+    if scale == 0.0:
+        scale = None
+    return vapi.NpuQuantization(scale_f32=scale, zero_point=int(zero_point))


def _create_npu_weights_zero_point(
@@ -960,6 +963,8 @@ def _create_npu_op_pooling(serial_pooling: spec.SerialPooling):
npu_pooling_op = vapi.NpuPoolingOp.AVERAGE
elif pooling_type == "MAX":
npu_pooling_op = vapi.NpuPoolingOp.MAX
+    elif pooling_type == "SUM":
+        npu_pooling_op = vapi.NpuPoolingOp.REDUCE_SUM

npu_pooling_op = vapi.NpuPoolingOperation(npu_pooling_op)
npu_pooling_op.ifm = _create_npu_feature_map(serial_pooling.ifm)
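This change supports the SUM legalization above: the dtype-converting MUL is emitted with every scale set to 0.0, a sentinel for "no scaling", and Vela represents the absence of a scale as scale_f32=None rather than a literal zero. A small sketch of the mapping, assuming the ethosu.vela package is installed:

from ethosu.vela import api as vapi

def create_npu_quantization(scale, zero_point):
    # Mirror of the change above: fold the 0.0 sentinel into None so Vela
    # does not interpret it as a real (and invalid) quantization scale.
    scale = float(scale)
    if scale == 0.0:
        scale = None
    return vapi.NpuQuantization(scale_f32=scale, zero_point=int(zero_point))

print(create_npu_quantization(0.0, -128))   # scale_f32=None
print(create_npu_quantization(0.25, -128))  # scale_f32=0.25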
90 changes: 90 additions & 0 deletions python/tvm/relay/op/contrib/ethosu.py
@@ -1375,6 +1375,91 @@ def mean_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
return pattern


class SumParams:
"""
    This class will parse a call to an ethos-u.sum composite function
    and extract the parameter information.
"""

composite_name = "ethos-u.sum"

def __init__(self, func_body: Call):
from tvm.relay.backend.contrib.ethosu.util import RequantArgs

clip = None
if str(func_body.op.name) == "clip":
clip = func_body
requantize = clip.args[0]
else:
requantize = func_body

sum_op = requantize.args[0]
attrs = sum_op.attrs
cast = sum_op.args[0]

layout = "NHWC"
self.ifm = TensorParams(
cast.args[0],
layout,
requantize.args[RequantArgs.IFM_SCALE.value],
requantize.args[RequantArgs.IFM_ZERO_POINT.value],
)
self.ofm = TensorParams(
requantize,
layout,
requantize.args[RequantArgs.OFM_SCALE.value],
requantize.args[RequantArgs.OFM_ZERO_POINT.value],
)

self.activation = clip

ifm_shape = self.ifm.shape
self.height = ifm_shape[0] if len(ifm_shape) in (2, 3) else ifm_shape[1]
self.width = ifm_shape[1] if len(ifm_shape) in (2, 3) else ifm_shape[2]
self.keepdims = attrs.keepdims

self.axis = list(sorted(attrs.axis))
if attrs.exclude:
self.axis = [i for i in range(len(self.ifm.shape)) if i not in self.axis]

def is_valid(self) -> bool:
"""
        Checks whether sum has attributes that are compatible with the hardware.
"""

ifm_shape_len = len(self.ifm.shape)

if not check_valid_dtypes([self.ifm], [np.uint8, np.int8, np.int16, np.int32]):
return False
if not check_valid_dtypes([self.ofm], [np.int8]):
return False
        if ifm_shape_len not in (3, 4):
return False
if ifm_shape_len == 3 and self.axis not in [[2]]:
return False
if ifm_shape_len == 4 and self.axis not in [[3]]:
return False

return True


def sum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
"""
This function creates the pattern for sum.
"""
pattern = is_op("cast")(wildcard())
pattern = is_op("sum")(pattern)
pattern = is_op("qnn.requantize")(
pattern,
is_constant(),
is_constant(),
is_constant(),
is_constant(),
)
pattern = pattern.optional(is_op("clip"))
return pattern


class ConcatParams:
"""
This class will parse a call to a ethos-u.concat composite function
@@ -1995,6 +2080,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
mean_pattern(),
lambda pat: MeanParams(pat).is_valid(),
),
(
SumParams.composite_name,
sum_pattern(),
lambda pat: SumParams(pat).is_valid(),
),
(
LeakyReLUParams.composite_name,
leaky_relu_pattern(),
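For reference, a Relay expression of the shape sum_pattern matches (cast to int32, sum over the channel axis, requantize, optional clip) can be constructed by hand; the shapes and quantization constants below are invented for illustration:

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 4, 2, 8), dtype="int8")
summed = relay.sum(relay.cast(x, "int32"), axis=3, keepdims=False)
requant = relay.qnn.op.requantize(
    summed,
    input_scale=relay.const(0.5, "float32"),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.25, "float32"),
    output_zero_point=relay.const(-128, "int32"),
    out_dtype="int8",
)
func = relay.Function([x], relay.clip(requant, a_min=-128.0, a_max=127.0))
print(tvm.IRModule.from_expr(func))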
4 changes: 3 additions & 1 deletion src/relay/op/contrib/ethosu/op_attrs.h
@@ -361,7 +361,9 @@ struct EthosuPoolingAttrs : public tvm::AttrsNode<EthosuPoolingAttrs> {

TVM_DECLARE_ATTRS(EthosuPoolingAttrs, "relay.attrs.EthosuPoolingAttrs") {
TVM_ATTR_FIELD(pooling_type)
.describe("The type of the pooling. 'AVG' - average pool, 'MAX' - max pool.");
.describe(
"The type of the pooling. 'AVG' - average pool, 'MAX' - max pool, "
"'SUM' - reduce sum pool.");
TVM_ATTR_FIELD(ifm_scale).describe("The quantization scale for the Input Feature Map tensor.");
TVM_ATTR_FIELD(ifm_zero_point)
.describe("The quantization zero point for the Input Feature Map tensor.");
23 changes: 19 additions & 4 deletions src/relay/op/contrib/ethosu/pooling.cc
@@ -46,14 +46,28 @@ bool EthosuPoolingRel(const Array<Type>& types, int num_inputs, const Attrs& att

const String operator_name = "ethosu_pooling";

if (param->pooling_type != "AVG" && param->pooling_type != "MAX") {
if (param->pooling_type != "AVG" && param->pooling_type != "MAX" &&
param->pooling_type != "SUM") {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected " << operator_name
<< " type 'AVG' or 'MAX' but was " << param->pooling_type);
<< " type 'AVG', 'MAX', or 'SUM' but was "
<< param->pooling_type);
return false;
}

-  CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm",
+  std::initializer_list<DataType> max_avg_pooling_ifm_dtypes = {DataType::UInt(8), DataType::Int(8),
+                                                                DataType::Int(16)};
+  std::initializer_list<DataType> sum_pooling_ifm_dtypes = {DataType::UInt(8), DataType::Int(8),
+                                                            DataType::Int(16), DataType::Int(32)};
+
+  std::initializer_list<DataType>& allowed_ifm_dtypes = max_avg_pooling_ifm_dtypes;
+  auto ofm_dtype = ifm->dtype;
+  if (param->pooling_type == "SUM") {
+    allowed_ifm_dtypes = sum_pooling_ifm_dtypes;
+    ofm_dtype = DataType::Int(32);
+  }
+
+  CheckDataType(reporter, ifm->dtype, allowed_ifm_dtypes, operator_name, "ifm",
                param->pooling_type);

CheckUpscaleMethod(reporter, param->upscale, {"NONE", "ZEROS", "NEAREST"}, operator_name);
@@ -67,7 +81,8 @@ bool EthosuPoolingRel(const Array<Type>& types, int num_inputs, const Attrs& att
auto ofm_shape = EthosuInferKernelOutput(
ifm_shape, param->ifm_layout, param->ofm_layout, param->pool_shape, param->ofm_channels,
Array<IndexExpr>({1, 1}), param->strides, param->padding);
-  reporter->Assign(types[result_index], TensorType(ofm_shape, ifm->dtype));
+
+  reporter->Assign(types[result_index], TensorType(ofm_shape, ofm_dtype));
return true;
}

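With the relaxed type relation, a SUM pooling infers an int32 OFM whatever the IFM dtype, which is what allows the legalization to append the int32-to-int8 MUL. One way to observe the inferred type from Python, assuming a TVM build with microNPU support and ethosu.vela installed:

import tvm
from tvm import relay
from tvm.relay.backend.contrib.ethosu import op as ethosu_ops

# Hypothetical 8-channel IFM reduced to one int32 channel by SUM pooling.
ifm = relay.var("ifm", shape=(1, 4, 2, 8), dtype="int8")
pool = ethosu_ops.ethosu_pooling(
    ifm=ifm,
    lut=relay.const([], "int8"),
    pooling_type="SUM",
    ifm_scale=1.0,
    ifm_zero_point=0,
    ofm_scale=1.0,
    ofm_zero_point=0,
    pool_shape=(1, 1),
    ofm_channels=1,
    rounding_mode="NATURAL",
)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([ifm], pool)))
print(mod["main"].ret_type)  # Tensor[(1, 4, 2, 1), int32]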
30 changes: 30 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -490,6 +490,36 @@ def create_mod_from_relay():
infra.verify_source(compiled_models, test_runner)


@pytest.mark.parametrize(
"accel_type",
ACCEL_TYPES,
)
@pytest.mark.parametrize(
"ifm_shape, axis, keepdims, relu",
[
[(1, 4, 2, 8), 3, False, False],
[(1, 4, 4, 1), 3, False, True],
[(3, 5, 7), 2, False, True],
[(1, 4, 2, 8), 3, True, False],
[(3, 5, 7), 2, True, False],
],
)
def test_ethosu_sum(accel_type, ifm_shape, axis, keepdims, relu):
np.random.seed(0)

@tf.function
def sum_func(x):
op = tf.math.reduce_sum(x, axis=axis, keepdims=keepdims)
return tf.nn.relu(op) if relu else op

infra.compare_tvm_with_tflite(
sum_func,
[ifm_shape],
accel_type,
enable_cascader=is_u55_accel_type(accel_type),
)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize("dtype", ["int8", "uint8"])
@pytest.mark.parametrize("constant", [np.ones((1, 1, 1, 1)), np.array(1)])