Merged
Changes from 12 commits
87 changes: 87 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1577,6 +1577,92 @@ def __call__(self, *args, **kwargs):
pass


class FullyConnectedRewriter(DFPatternCallback):
"""Legalize Fully Connected (with bias and clip) to an EthosU operator"""

def __init__(self):
super().__init__(require_type=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.FullyConnectedParams.composite_name})
)(wildcard())

def callback(self, pre, post, node_map):
params = ethosu_patterns.FullyConnectedParams(post.op.body)
params.ifm.tensor = post.args[0]

# IFM reshapes
ifm = post.args[0]
if len(params.ifm.shape) != 4 or not params.ifm.shape[1] == params.ifm.shape[2] == 1:
ifm = relay.reshape(ifm, (1, 1, 1, params.ifm.shape[-1]))

# Weight transformations
weights_values = params.weights.values
weights_values_ohwi = np.expand_dims(weights_values, axis=(1, 2))
if params.activation:
activation = "CLIP"
clip_min = int(params.activation.attrs.a_min)
clip_max = int(params.activation.attrs.a_max)
else:
activation = "NONE"
clip_min = 0
clip_max = 0
bias_values = (
params.biases.tensor.data.asnumpy()
if params.biases
else np.zeros((params.ofm.shape[-1]))
)
scale_bias = vela_api.pack_biases(
biases=bias_values,
ifm_scale=params.ifm.q_params.scale_f32,
ifm_dtype=np.dtype(params.ifm.dtype),
weight_scales=params.weights.q_params.scale_f32,
ofm_scale=params.ofm.q_params.scale_f32,
is_activation_tanh_or_sigmoid=False,
)
ethosu_fc = ethosu_ops.ethosu_conv2d(
ifm=ifm,
weight=relay.const(weights_values_ohwi, params.weights.values.dtype),
scale_bias=relay.const(scale_bias, "uint8"),
lut=relay.const([], dtype="int8"),
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point),
weight_zero_point=int(params.weights.q_params.zero_point),
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point),
kernel_shape=[1, 1],
ofm_channels=params.weights.shape[0],
strides=(1, 1),
padding=(0, 0, 0, 0),
dilation=(1, 1),
activation=activation,
clip_min=clip_min,
clip_max=clip_max,
upscale="NONE",
ifm_layout="NHWC",
ofm_layout="NHWC",
)

if len(params.ofm.shape) != 4 or not params.ofm.shape[1] == params.ofm.shape[2] == 1:
ethosu_fc = relay.reshape(ethosu_fc, params.ofm.shape)
Comment on lines +1652 to +1653
Contributor:
I suspect there isn't a test case that exercises this path, since on line 1700 this pass runs after the no-op legalizer, so the last reshape won't have a following identity op and will fall over in TE

return ethosu_fc


@ir.transform.module_pass(opt_level=1)
class LegalizeFullyConnected:
"""This is the pass that wraps the FullyConnectedRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(FullyConnectedRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


@ir.transform.module_pass(opt_level=1)
class LegalizeEthosU:
"""This is the pass to call graph-rewrites to perform graph transformation
@@ -1614,6 +1700,7 @@ def transform_module
mod = LegalizeSqueeze()(mod)
mod = LegalizeReshape()(mod)
mod = LegalizeStridedSlice()(mod)
mod = LegalizeFullyConnected()(mod)
mod = LegalizeNoOps()(mod)
return mod

14 changes: 14 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/util.py
@@ -129,6 +129,20 @@ class DequantizeArgs(Enum):
IFM_ZERO_POINT = 2


class QDenseArgs(Enum):
"""
This is a helper enum to access the correct index of
qnn.dense arguments
"""

IFM = 0
WEIGHTS = 1
IFM_ZERO_POINT = 2
WEIGHTS_ZERO_POINT = 3
IFM_SCALE = 4
WEIGHTS_SCALE = 5
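
For context, a hedged usage sketch (not part of the diff): the enum values mirror the positional argument order of a qnn.dense call, so they can be used to pull operands out of a matched call expression.

# Hypothetical usage sketch; `call` is assumed to be a relay.Call to qnn.dense
# whose arguments are (data, weight, input_zero_point, kernel_zero_point,
# input_scale, kernel_scale)
ifm = call.args[QDenseArgs.IFM.value]
weights = call.args[QDenseArgs.WEIGHTS.value]
ifm_scale = call.args[QDenseArgs.IFM_SCALE.value]
weights_zero_point = call.args[QDenseArgs.WEIGHTS_ZERO_POINT.value]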


def is_composite_func(func: relay.Function, name: str) -> bool:
"""
This method checks whether the call is to
118 changes: 116 additions & 2 deletions python/tvm/relay/op/contrib/ethosu.py
@@ -23,7 +23,7 @@

import tvm # type: ignore
from tvm import relay
from tvm.relay.expr import Constant, Call # type: ignore
from tvm.relay.expr import Constant, Call
Contributor:
Nit: I don't think we need to change this

from tvm.relay.op.contrib.register import register_pattern_table # type: ignore
from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant, is_tuple # type: ignore
from tvm.relay.build_module import bind_params_by_name # type: ignore
@@ -1103,7 +1103,10 @@ def is_valid(self):
"""
if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
return False
return True
return True # optional_bias_add = (

# is_op("nn.bias_add")(dense, is_constant()) | dense
# )
Contributor:
Nit: remove comments :)



class TanhParams(LutActivationParams):
@@ -1537,6 +1540,112 @@ def squeeze_pattern():
return is_op("squeeze")(wildcard())


class FullyConnectedParams:
"""
This class will parse a call to an ethos-u.fully_connected composite
function and extract the parameter information.
"""

composite_name = "ethos-u.fully_connected"

@requires_vela
def __init__(self, func_body):
from tvm.relay.backend.contrib.ethosu.util import QDenseArgs # type: ignore
from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs
from tvm.relay.backend.contrib.ethosu.util import RequantArgs

self.activation = None
if str(func_body.op) == "clip":
self.activation = func_body
requantize_op = self.activation.args[0]
else:
requantize_op = func_body

call = requantize_op.args[0]
if str(requantize_op.args[0].op) == "nn.bias_add":
bias_add = call
qnn_dense = call.args[0]
else:
bias_add = None
qnn_dense = call

# weights & biases are params as they should be constant
self.weights = TensorParams(
qnn_dense.args[QDenseArgs.WEIGHTS.value],
None,
qnn_dense.args[QDenseArgs.WEIGHTS_SCALE.value],
qnn_dense.args[QDenseArgs.WEIGHTS_ZERO_POINT.value],
)
self.biases = (
TensorParams(
bias_add.args[BiasAddArgs.BIASES.value],
None,
requantize_op.args[RequantArgs.IFM_SCALE.value],
requantize_op.args[RequantArgs.IFM_ZERO_POINT.value],
)
if bias_add
else None
)
self.ifm = TensorParams(
qnn_dense.args[QDenseArgs.IFM.value],
None,
qnn_dense.args[QDenseArgs.IFM_SCALE.value],
qnn_dense.args[QDenseArgs.IFM_ZERO_POINT.value],
)
self.ofm = TensorParams(
func_body,
None,
requantize_op.args[RequantArgs.OFM_SCALE.value],
requantize_op.args[RequantArgs.OFM_ZERO_POINT.value],
)

def is_valid(self) -> bool:
"""
Checks whether Fully Connected has compatible attributes with HW
"""

def check_weights_fc(weights):
"""Checks whether weight tensor is compatible with HW"""
weights_limit = 127 * 65536
# A saturation upper bound check for accumulators
weights.values = weights.values - weights.q_params.zero_point
axis = 1
sum_weights = np.amax(np.sum(np.absolute(weights.values), axis=axis))
if not sum_weights <= weights_limit:
return False
return True

if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
return False
if not check_weights_fc(self.weights):
return False
if not check_bias(self.biases):
return False
if not check_batch_size(self.ifm):
return False
# Check input shape
if len(self.ifm.shape) < 2:
return False
if not np.all(np.array(self.ifm.shape[:-1]) == 1):
# As we reshape the ifm from
# [n0, n1, ... , n_m] to [n0 * n1 * ... * n_{m-1}, n_m]
# all except the last dims need to be 1.
return False
Contributor:
I don't think we need this check, for the reasons in the comment above: we already check that the batch size == 1 with check_batch_size, and we know that the ifm must be 2D

return True


def qnn_fc_pattern():
dense = is_op("qnn.dense")(
wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
)
optional_bias_add = is_op("nn.bias_add")(dense, is_constant()) | dense
Contributor:
I think this should just be optional_bias_add = is_op("nn.bias_add")(dense, is_constant())
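
For illustration only, a sketch of how the whole function might read with that suggestion applied (the dense | bias_add alternative in the requantize arguments is kept, so both the biased and unbiased forms still match):

from tvm.relay.dataflow_pattern import is_constant, is_op, wildcard

def qnn_fc_pattern():
    dense = is_op("qnn.dense")(
        wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
    )
    # bias_add is matched on its own here, per the suggestion above
    bias_add = is_op("nn.bias_add")(dense, is_constant())
    req = is_op("qnn.requantize")(
        dense | bias_add, is_constant(), is_constant(), is_constant(), is_constant()
    )
    return req.optional(is_op("clip"))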

req = is_op("qnn.requantize")(
dense | optional_bias_add, is_constant(), is_constant(), is_constant(), is_constant()
)
optional_clip = req.optional(is_op("clip"))
return optional_clip


@register_pattern_table("ethos-u")
def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
return [
@@ -1555,6 +1664,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
qnn_conv2d_transpose_pattern(),
lambda pat: QnnConv2DTransposeParams(pat).is_valid(),
),
(
FullyConnectedParams.composite_name,
qnn_fc_pattern(),
lambda pat: FullyConnectedParams(pat).is_valid(),
),
(
MaxPool2DParams.composite_name,
qnn_maxpool2d_pattern(),
30 changes: 30 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -1167,5 +1167,35 @@ def leaky_relu_func(x):
_compare_tvm_with_tflite(leaky_relu_func, [ifm_shape], accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)])
@pytest.mark.parametrize("ofm_channels", [32, 64])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("activation_function", ["RELU", "NONE"])
def test_tflite_fully_connected(
accel_type,
ifm_shape,
ofm_channels,
use_bias,
activation_function,
):
@tf.function
def fully_connected(x):
bias_shape = ofm_channels
bias = tf.constant(np.random.uniform(size=bias_shape), dtype=tf.float32)
w = tf.constant(
np.random.uniform(size=[ifm_shape[1], ofm_channels]),
dtype=tf.float32,
)
x = tf.matmul(x, w)
if use_bias:
x = tf.nn.bias_add(x, bias)
if activation_function == "RELU":
x = tf.nn.relu(x)
return x

_compare_tvm_with_tflite(fully_connected, [ifm_shape], accel_type)
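
For reference (an assumption about typical usage, not stated in the diff), the new test can be run on its own with, for example, pytest tests/python/contrib/test_ethosu/test_codegen.py -k fully_connected.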


if __name__ == "__main__":
pytest.main([__file__])