98 changes: 97 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1401,6 +1401,101 @@ def callback(self, pre, post, node_map):
return ethosu_fc


class MatMulRewriter(DFPatternCallback):
"""Legalize matrix multiplication to an NPU operator"""

def __init__(self):
super().__init__(require_type=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.MatMulParams.composite_name})
)(wildcard(), wildcard())

def callback(self, pre, post, node_map):
params = ethosu_patterns.MatMulParams(post.op.body)
ifm = post.args[0]
ifm2 = post.args[1]
lut = relay.const([], dtype="int8")
activation_map = {"clip": "CLIP"}
if params.activation:
activation = activation_map[params.activation.op.name]
clip_min = int(params.activation.attrs.a_min)
clip_max = int(params.activation.attrs.a_max)
else:
activation = "NONE"
clip_min = 0
clip_max = 0

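        # The NPU has no dedicated matmul primitive, so the product is built
        # column by column: each output column ofm[:, j] is an elementwise MUL
        # of the IFM with row j of the transposed weights, reduced by a SUM
        # pooling over the channel axis.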
# Reshape ifm to NHWC
ifm = relay.reshape(ifm, (1, 1, *params.ifm.shape))
# Split the second matrix to get columns
columns = list(relay.op.split(ifm2, params.ofm.shape[-1], axis=0))

res_columns = []
for column in columns:
ifm2 = relay.reshape(column, (1, 1, 1, params.ifm.shape[-1]))
# Multiplying the first matrix by a column
ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise(
ifm=ifm,
ifm2=ifm2,
lut=lut,
operator_type="MUL",
ifm_zero_point=int(params.ifm.q_params.zero_point),
ifm_scale=0.0,
ifm2_zero_point=int(params.weights.q_params.zero_point),
ifm2_scale=0.0,
ofm_scale=0.0,
ofm_zero_point=0,
ifm_channels=params.ifm.shape[-1],
ifm2_channels=params.ifm.shape[-1],
reversed_operands=False,
ofm_dtype="int32",
)

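            # Keeping the products in int32 lets the following SUM reduction
            # accumulate without overflowing the 8-bit range.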
# Use reduce sum to get result column
reduce_sum = ethosu_ops.ethosu_pooling(
ifm=ethosu_binary_elementwise,
lut=lut,
pooling_type="SUM",
ifm_zero_point=0,
ifm_scale=float(params.weights.q_params.scale_f32)
* float(params.ifm.q_params.scale_f32),
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=0,
pool_shape=(1, 1),
ofm_channels=1,
ofm_dtype="int32",
activation=activation,
clip_min=clip_min,
clip_max=clip_max,
rounding_mode="NATURAL",
)

# Convert tensor dtype from int32 to int8
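            # (multiplying by a constant one acts as an identity op whose output
            # quantization parameters perform the int32 -> int8 requantize)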
scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="int32"), dtype="int32")
reduce_sum = ethosu_ops.ethosu_binary_elementwise(
ifm=reduce_sum,
ifm2=scalar_tensor,
lut=lut,
operator_type="MUL",
ifm_scale=0.0,
ifm_zero_point=0,
ifm2_scale=0.0,
ifm2_zero_point=0,
ofm_scale=0.0,
ofm_zero_point=int(params.ofm.q_params.zero_point),
ifm_channels=1,
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int8",
)

res_columns.append(reduce_sum)

# Concatenate result columns
concat = relay.op.concatenate(relay.Tuple(res_columns), axis=3)
return relay.reshape(concat, params.ofm.shape)


class PadRewriter(DFPatternCallback):
"""Convert ethos-u.pad2d composite function to ethosu_depthwise_conv2d
operator"""
@@ -1546,12 +1641,13 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
"""
rewriters = [
PartitionedSplitRewriter(),
FullyConnectedRewriter(),
MatMulRewriter(),
SplitRewriter(),
ChannelPadRewriter(),
Conv2DRewriter(),
Conv2DTransposeRewriter(),
DepthwiseConv2DRewriter(),
MaxPoolingRewriter(),
AvgPoolingRewriter(),
PadRewriter(),
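For intuition, the rewrite above builds the product column by column. A minimal
NumPy sketch of the same decomposition (quantization scales and zero points
omitted; names are illustrative, not part of the PR):

    import numpy as np

    def column_wise_matmul(ifm, ifm2):
        # ifm: (M, K) activations; ifm2: (N, K) transposed weights.
        columns = []
        for j in range(ifm2.shape[0]):
            # Elementwise MUL of the IFM with one row of ifm2 (broadcast over
            # rows), mirroring ethosu_binary_elementwise with operator_type="MUL".
            prod = ifm.astype(np.int32) * ifm2[j].astype(np.int32)
            # SUM reduction over the channel axis, mirroring ethosu_pooling
            # with pooling_type="SUM".
            columns.append(prod.sum(axis=-1, keepdims=True))
        # Concatenate the result columns, as the rewriter does on axis=3.
        return np.concatenate(columns, axis=-1)

For integer inputs this agrees with np.matmul(ifm, ifm2.T).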
43 changes: 43 additions & 0 deletions python/tvm/relay/op/contrib/ethosu.py
@@ -1900,6 +1900,44 @@ def qnn_fc_pattern():
return optional_clip


class MatMulParams(FullyConnectedParams):
"""
This class will parse a call to an ethos-u.matmul composite
function and extract the parameter information.
"""

composite_name = "ethos-u.matmul"

@requires_vela
def __init__(self, func_body):
FullyConnectedParams.__init__(self, func_body)

def is_valid(self) -> bool:
"""
Checks whether matrix multiplication has compatible attributes with HW
"""

if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
return False
        if len(self.ifm.shape) != 2:
            return False
        if len(self.ofm.shape) != 2:
            return False
        # The weights must be transposed, i.e. laid out as
        # (ofm_channels, ifm_channels), so the reduction dimensions must match
        if self.ifm.shape[1] != self.weights.shape[1]:
            return False
return True


def matmul_pattern():
dense = is_op("qnn.dense")(
wildcard(), wildcard(), is_constant(), is_constant(), is_constant(), is_constant()
)
req = is_op("qnn.requantize")(dense, is_constant(), is_constant(), is_constant(), is_constant())
optional_clip = req.optional(is_op("clip"))
return optional_clip


class HardSwishParams:
"""
This class will parse a call to a ethos-u.hard_swish composite function
@@ -2185,6 +2223,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
qnn_fc_pattern(),
lambda pat: FullyConnectedParams(pat).is_valid(),
),
(
MatMulParams.composite_name,
matmul_pattern(),
lambda pat: MatMulParams(pat).is_valid(),
),
(
MaxPool2DParams.composite_name,
qnn_maxpool2d_pattern(),
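Concretely, is_valid admits only 2-D int8 operands whose reduction dimensions
agree. A small illustration with made-up shapes:

    ifm_shape = (4, 8)       # (batch, ifm_channels)
    weights_shape = (32, 8)  # (ofm_channels, ifm_channels), already transposed
    assert len(ifm_shape) == 2
    assert ifm_shape[-1] == weights_shape[-1]  # reduction dimensions match
    ofm_shape = (ifm_shape[0], weights_shape[0])  # (4, 32)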
24 changes: 24 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -1564,6 +1564,30 @@ def fully_connected(x):
)


@pytest.mark.parametrize("accel_type", ["ethos-u55-256", "ethos-u65-256"])
@pytest.mark.parametrize("ifm_shape", [(1, 16), (4, 8)])
@pytest.mark.parametrize("ofm_channels", [8, 32])
@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
def test_tflite_matmul(
accel_type,
ifm_shape,
ofm_channels,
activation_function,
):
np.random.seed(0)

@tf.function
def matmul(x, y):
x = tf.matmul(x, y, transpose_b=True)
if activation_function == "RELU":
x = tf.nn.relu(x)
return x

infra.compare_tvm_with_tflite(
matmul, [ifm_shape, [ofm_channels, ifm_shape[-1]]], accel_type, enable_cascader=False
)


@pytest.mark.parametrize("accel_type", ["ethos-u55-256", "ethos-u65-256"])
def test_tflite_subtract_sigmoid(accel_type):
np.random.seed(0)
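The new codegen test can be run on its own with, for example (assuming an
environment with Vela and the Ethos-U test infrastructure available):

    pytest tests/python/contrib/test_ethosu/test_codegen.py -k test_tflite_matmul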
117 changes: 117 additions & 0 deletions tests/python/contrib/test_ethosu/test_legalize.py
@@ -3806,5 +3806,122 @@ def representative_dataset():
assert mod["tvmgen_default_ethos_u_main_1"].body.op.name == "contrib.ethosu.conv2d"


def test_tflite_matmul():
ifm_shape = [1, 4]
ifm2_shape = [2, 4]
ifm_shapes = [ifm_shape, ifm2_shape]
ofm_shape = [ifm_shape[0], ifm2_shape[0]]
dtype = "int8"

def create_tflite_graph():
class Model(tf.Module):
@tf.function
def matmul(self, x, y):
res = tf.matmul(x, y, transpose_b=True)
return res

model = Model()
concrete_func = model.matmul.get_concrete_function(
*[tf.TensorSpec(shape, tf.float32) for shape in ifm_shapes]
)
# Convert the model
def representative_dataset():
for _ in range(100):
datas = [np.random.rand(*shape) for shape in ifm_shapes]
yield [data.astype(np.float32) for data in datas]

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
return tflite_model

def verify(ext_func):
ofm = ext_func.body
ops = []

def _visit(stmt):
if isinstance(stmt, relay.expr.Call):
ops.append(stmt)

relay.analysis.post_order_visit(ofm, _visit)
ofm_checked_type = ofm.checked_type
ofm_channels = ofm_shape[-1]

# check IFM
ifm = ops[1].checked_type
assert list(ifm.shape) == ifm_shape
assert str(ifm.dtype) == dtype

# check IFM2
ifm2 = ops[3].checked_type
assert list(ifm2.shape) == ifm2_shape
assert str(ifm2.dtype) == dtype

# check split
split = ops[4]
split_checked_types = list(split.checked_type.fields)
assert split.op.name == "split"
assert split.attrs.axis == 0
assert int(split.attrs.indices_or_sections) == ofm_channels
for split_checked_type in split_checked_types:
assert list(split_checked_type.shape) == ifm_shape
assert str(split_checked_type.dtype) == dtype

# check MUL
mul_ops = [ops[6], ops[10]]
for mul_op in mul_ops:
assert mul_op.op.name == "contrib.ethosu.binary_elementwise"
assert mul_op.attrs.operator_type == "MUL"
assert mul_op.attrs.ofm_dtype == "int32"

# check reduce sum
reduce_sum_ops = [ops[7], ops[11]]
for reduce_sum_op in reduce_sum_ops:
assert reduce_sum_op.op.name == "contrib.ethosu.pooling"
assert reduce_sum_op.attrs.pooling_type == "SUM"
assert list(reduce_sum_op.checked_type.shape) == [1, 1, 1, 1]

# check concatenation
concatenation = ofm.args[0]
concatenation_shape = concatenation.checked_type.shape
assert concatenation.op.name == "concatenate"
assert list(concatenation_shape) == [1, 1, 1, ofm_channels]

# check OFM
assert ofm.op.name == "reshape"
assert list(ofm_checked_type.shape) == ofm_shape
assert str(ofm_checked_type.dtype) == dtype

matmul_pattern_table = [
(
ethosu.MatMulParams.composite_name,
ethosu.matmul_pattern(),
lambda pat: ethosu.MatMulParams(pat).is_valid(),
)
]

tflite_graph = create_tflite_graph()
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

mod, params = relay.frontend.from_tflite(
tflite_model,
shape_dict={("ifm" + str(i)): shape for i, shape in enumerate(ifm_shapes)},
dtype_dict={("ifm" + str(i)): dtype for i, _ in enumerate(ifm_shapes)},
)

mod["main"] = bind_params_by_name(mod["main"], params)
mod = partition_ethosu_by_table(mod, matmul_pattern_table)

mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
legalize.MatMulRewriter(), mod["tvmgen_default_ethos_u_main_0"]
)

verify(mod["tvmgen_default_ethos_u_main_0"])


if __name__ == "__main__":
tvm.testing.main()