
Commit fdfd16c

[microNPU][ETHOSU] MatMul legalization support (#15780)
The NPU has a restriction that weights must be constant, so the matrix multiplication operation is expressed using split, elementwise-multiplication, reduce-sum, and concatenate operations.
1 parent 71caa19 commit fdfd16c
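As a quick intuition for the decomposition (a minimal NumPy sketch with illustrative shapes, not code from the commit): splitting the transposed weight matrix into rows, multiplying each row elementwise with the input, reducing with a sum, and concatenating the per-column results reproduces x @ w.T.

import numpy as np

x = np.random.randint(-5, 5, size=(1, 4))  # ifm, shape (M, K)
w = np.random.randint(-5, 5, size=(2, 4))  # transposed weights, shape (N, K)

# One result column per weight row: elementwise MUL, then reduce-sum over K.
columns = [np.sum(x * row, axis=-1, keepdims=True) for row in w]
result = np.concatenate(columns, axis=-1)  # shape (M, N)

assert np.array_equal(result, x @ w.T)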

File tree

4 files changed: +281 -1 lines changed

  python/tvm/relay/backend/contrib/ethosu/legalize.py
  python/tvm/relay/op/contrib/ethosu.py
  tests/python/contrib/test_ethosu/test_codegen.py
  tests/python/contrib/test_ethosu/test_legalize.py

python/tvm/relay/backend/contrib/ethosu/legalize.py

Lines changed: 97 additions & 1 deletion
@@ -1401,6 +1401,101 @@ def callback(self, pre, post, node_map):
         return ethosu_fc
 
 
+class MatMulRewriter(DFPatternCallback):
+    """Legalize matrix multiplication to an NPU operator"""
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": ethosu_patterns.MatMulParams.composite_name})
+        )(wildcard(), wildcard())
+
+    def callback(self, pre, post, node_map):
+        params = ethosu_patterns.MatMulParams(post.op.body)
+        ifm = post.args[0]
+        ifm2 = post.args[1]
+        lut = relay.const([], dtype="int8")
+        activation_map = {"clip": "CLIP"}
+        if params.activation:
+            activation = activation_map[params.activation.op.name]
+            clip_min = int(params.activation.attrs.a_min)
+            clip_max = int(params.activation.attrs.a_max)
+        else:
+            activation = "NONE"
+            clip_min = 0
+            clip_max = 0
+
+        # Reshape ifm to NHWC
+        ifm = relay.reshape(ifm, (1, 1, *params.ifm.shape))
+        # Split the second matrix to get columns
+        columns = list(relay.op.split(ifm2, params.ofm.shape[-1], axis=0))
+
+        res_columns = []
+        for column in columns:
+            ifm2 = relay.reshape(column, (1, 1, 1, params.ifm.shape[-1]))
+            # Multiplying the first matrix by a column
+            ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise(
+                ifm=ifm,
+                ifm2=ifm2,
+                lut=lut,
+                operator_type="MUL",
+                ifm_zero_point=int(params.ifm.q_params.zero_point),
+                ifm_scale=0.0,
+                ifm2_zero_point=int(params.weights.q_params.zero_point),
+                ifm2_scale=0.0,
+                ofm_scale=0.0,
+                ofm_zero_point=0,
+                ifm_channels=params.ifm.shape[-1],
+                ifm2_channels=params.ifm.shape[-1],
+                reversed_operands=False,
+                ofm_dtype="int32",
+            )
+
+            # Use reduce sum to get result column
+            reduce_sum = ethosu_ops.ethosu_pooling(
+                ifm=ethosu_binary_elementwise,
+                lut=lut,
+                pooling_type="SUM",
+                ifm_zero_point=0,
+                ifm_scale=float(params.weights.q_params.scale_f32)
+                * float(params.ifm.q_params.scale_f32),
+                ofm_scale=float(params.ofm.q_params.scale_f32),
+                ofm_zero_point=0,
+                pool_shape=(1, 1),
+                ofm_channels=1,
+                ofm_dtype="int32",
+                activation=activation,
+                clip_min=clip_min,
+                clip_max=clip_max,
+                rounding_mode="NATURAL",
+            )
+
+            # Convert tensor dtype from int32 to int8
+            scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="int32"), dtype="int32")
+            reduce_sum = ethosu_ops.ethosu_binary_elementwise(
+                ifm=reduce_sum,
+                ifm2=scalar_tensor,
+                lut=lut,
+                operator_type="MUL",
+                ifm_scale=0.0,
+                ifm_zero_point=0,
+                ifm2_scale=0.0,
+                ifm2_zero_point=0,
+                ofm_scale=0.0,
+                ofm_zero_point=int(params.ofm.q_params.zero_point),
+                ifm_channels=1,
+                ifm2_channels=1,
+                reversed_operands=False,
+                ofm_dtype="int8",
+            )
+
+            res_columns.append(reduce_sum)
+
+        # Concatenate result columns
+        concat = relay.op.concatenate(relay.Tuple(res_columns), axis=3)
+        return relay.reshape(concat, params.ofm.shape)
+
+
 class PadRewriter(DFPatternCallback):
     """Convert ethos-u.pad2d composite function to ethosu_depthwise_conv2d
     operator"""
@@ -1546,12 +1641,13 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
         """
         rewriters = [
             PartitionedSplitRewriter(),
+            FullyConnectedRewriter(),
+            MatMulRewriter(),
             SplitRewriter(),
             ChannelPadRewriter(),
             Conv2DRewriter(),
             Conv2DTransposeRewriter(),
             DepthwiseConv2DRewriter(),
-            FullyConnectedRewriter(),
             MaxPoolingRewriter(),
             AvgPoolingRewriter(),
             PadRewriter(),
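A note on the quantization arithmetic in MatMulRewriter.callback above: the MUL stage is created with zero scales, deferring all rescaling to the SUM pooling, whose input scale is the product of the ifm and weights scales. This follows the usual affine-quantization identity real = scale * (q - zero_point). A minimal sketch of that identity, with assumed scales and zero points rather than values from the commit:

import numpy as np

# Hypothetical quantization parameters, chosen for illustration only.
s_a, z_a = 0.05, 3   # ifm scale / zero point
s_w, z_w = 0.02, 0   # weights scale / zero point

qa = np.array([[10, 20, 30, 40]])            # quantized ifm, shape (M, K)
qw = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])  # quantized transposed weights, shape (N, K)

# Dequantized reference: real = scale * (q - zero_point), then matmul.
ref = (s_a * (qa - z_a)) @ (s_w * (qw - z_w)).T

# Integer path mirroring the lowering: zero points removed in the MUL,
# per-column reduce-sum, combined scale s_a * s_w applied afterwards.
acc = np.stack([np.sum((qa - z_a) * (row - z_w), axis=-1) for row in qw], axis=-1)
assert np.allclose(s_a * s_w * acc, ref)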

python/tvm/relay/op/contrib/ethosu.py

Lines changed: 43 additions & 0 deletions
@@ -1900,6 +1900,44 @@ def qnn_fc_pattern():
     return optional_clip
 
 
+class MatMulParams(FullyConnectedParams):
+    """
+    This class will parse a call to an ethos-u.matmul composite
+    function and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.matmul"
+
+    @requires_vela
+    def __init__(self, func_body):
+        FullyConnectedParams.__init__(self, func_body)
+
+    def is_valid(self) -> bool:
+        """
+        Checks whether matrix multiplication has compatible attributes with HW
+        """
+
+        if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
+            return False
+        if not len(self.ifm.shape) == 2:
+            return False
+        if not len(self.ofm.shape) == 2:
+            return False
+        # The weights must be transposed
+        if self.ifm.shape[1] != self.weights.shape[1]:
+            return False
+        return True
+
+
+def matmul_pattern():
+    dense = is_op("qnn.dense")(
+        wildcard(), wildcard(), is_constant(), is_constant(), is_constant(), is_constant()
+    )
+    req = is_op("qnn.requantize")(dense, is_constant(), is_constant(), is_constant(), is_constant())
+    optional_clip = req.optional(is_op("clip"))
+    return optional_clip
+
+
 class HardSwishParams:
     """
     This class will parse a call to a ethos-u.hard_swish composite function
@@ -2185,6 +2223,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             qnn_fc_pattern(),
             lambda pat: FullyConnectedParams(pat).is_valid(),
         ),
+        (
+            MatMulParams.composite_name,
+            matmul_pattern(),
+            lambda pat: MatMulParams(pat).is_valid(),
+        ),
         (
             MaxPool2DParams.composite_name,
             qnn_maxpool2d_pattern(),
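To make the transposed-weights check in is_valid() concrete (shapes are illustrative, not from the commit): the second operand arrives as (N, K), as produced by tf.matmul(..., transpose_b=True) in the tests below, so its inner dimension must equal the inner dimension of the (M, K) input.

# Illustrative shapes for the is_valid() transpose check.
ifm_shape = (1, 16)      # (M, K)
weights_shape = (8, 16)  # (N, K), i.e. weights arrive transposed
ofm_shape = (1, 8)       # (M, N)

# The reduction dimensions must line up, as is_valid() requires.
assert ifm_shape[1] == weights_shape[1]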

tests/python/contrib/test_ethosu/test_codegen.py

Lines changed: 24 additions & 0 deletions
@@ -1564,6 +1564,30 @@ def fully_connected(x):
     )
 
 
+@pytest.mark.parametrize("accel_type", ["ethos-u55-256", "ethos-u65-256"])
+@pytest.mark.parametrize("ifm_shape", [(1, 16), (4, 8)])
+@pytest.mark.parametrize("ofm_channels", [8, 32])
+@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
+def test_tflite_matmul(
+    accel_type,
+    ifm_shape,
+    ofm_channels,
+    activation_function,
+):
+    np.random.seed(0)
+
+    @tf.function
+    def matmul(x, y):
+        x = tf.matmul(x, y, transpose_b=True)
+        if activation_function == "RELU":
+            x = tf.nn.relu(x)
+        return x
+
+    infra.compare_tvm_with_tflite(
+        matmul, [ifm_shape, [ofm_channels, ifm_shape[-1]]], accel_type, enable_cascader=False
+    )
+
+
 @pytest.mark.parametrize("accel_type", ["ethos-u55-256", "ethos-u65-256"])
 def test_tflite_subtract_sigmoid(accel_type):
     np.random.seed(0)

tests/python/contrib/test_ethosu/test_legalize.py

Lines changed: 117 additions & 0 deletions
@@ -3806,5 +3806,122 @@ def representative_dataset():
     assert mod["tvmgen_default_ethos_u_main_1"].body.op.name == "contrib.ethosu.conv2d"
 
 
+def test_tflite_matmul():
+    ifm_shape = [1, 4]
+    ifm2_shape = [2, 4]
+    ifm_shapes = [ifm_shape, ifm2_shape]
+    ofm_shape = [ifm_shape[0], ifm2_shape[0]]
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def matmul(self, x, y):
+                res = tf.matmul(x, y, transpose_b=True)
+                return res
+
+        model = Model()
+        concrete_func = model.matmul.get_concrete_function(
+            *[tf.TensorSpec(shape, tf.float32) for shape in ifm_shapes]
+        )
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                datas = [np.random.rand(*shape) for shape in ifm_shapes]
+                yield [data.astype(np.float32) for data in datas]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        ofm = ext_func.body
+        ops = []
+
+        def _visit(stmt):
+            if isinstance(stmt, relay.expr.Call):
+                ops.append(stmt)
+
+        relay.analysis.post_order_visit(ofm, _visit)
+        ofm_checked_type = ofm.checked_type
+        ofm_channels = ofm_shape[-1]
+
+        # check IFM
+        ifm = ops[1].checked_type
+        assert list(ifm.shape) == ifm_shape
+        assert str(ifm.dtype) == dtype
+
+        # check IFM2
+        ifm2 = ops[3].checked_type
+        assert list(ifm2.shape) == ifm2_shape
+        assert str(ifm2.dtype) == dtype
+
+        # check split
+        split = ops[4]
+        split_checked_types = list(split.checked_type.fields)
+        assert split.op.name == "split"
+        assert split.attrs.axis == 0
+        assert int(split.attrs.indices_or_sections) == ofm_channels
+        for split_checked_type in split_checked_types:
+            assert list(split_checked_type.shape) == ifm_shape
+            assert str(split_checked_type.dtype) == dtype
+
+        # check MUL
+        mul_ops = [ops[6], ops[10]]
+        for mul_op in mul_ops:
+            assert mul_op.op.name == "contrib.ethosu.binary_elementwise"
+            assert mul_op.attrs.operator_type == "MUL"
+            assert mul_op.attrs.ofm_dtype == "int32"
+
+        # check reduce sum
+        reduce_sum_ops = [ops[7], ops[11]]
+        for reduce_sum_op in reduce_sum_ops:
+            assert reduce_sum_op.op.name == "contrib.ethosu.pooling"
+            assert reduce_sum_op.attrs.pooling_type == "SUM"
+            assert list(reduce_sum_op.checked_type.shape) == [1, 1, 1, 1]
+
+        # check concatenation
+        concatenation = ofm.args[0]
+        concatenation_shape = concatenation.checked_type.shape
+        assert concatenation.op.name == "concatenate"
+        assert list(concatenation_shape) == [1, 1, 1, ofm_channels]
+
+        # check OFM
+        assert ofm.op.name == "reshape"
+        assert list(ofm_checked_type.shape) == ofm_shape
+        assert str(ofm_checked_type.dtype) == dtype
+
+    matmul_pattern_table = [
+        (
+            ethosu.MatMulParams.composite_name,
+            ethosu.matmul_pattern(),
+            lambda pat: ethosu.MatMulParams(pat).is_valid(),
+        )
+    ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={("ifm" + str(i)): shape for i, shape in enumerate(ifm_shapes)},
+        dtype_dict={("ifm" + str(i)): dtype for i, _ in enumerate(ifm_shapes)},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], params)
+    mod = partition_ethosu_by_table(mod, matmul_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.MatMulRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()
