-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[microNPU] Add support for TFLite FULLY_CONNECTED #10345
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
8c0ea73
dab5c5e
b79fec2
ef1d576
d80302a
7b177bd
a4741b9
83d9ee1
97cd5ce
3bc81a5
ae6827c
18bd546
efbe30e
fcca2d2
06aa4d9
04fd825
82dc516
f924021
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -1577,6 +1577,88 @@ def __call__(self, *args, **kwargs): | |||||
| pass | ||||||
|
|
||||||
|
|
||||||
| class FullyConnectedRewriter(DFPatternCallback): | ||||||
| """Legalize Fully Connected (with bias and clip) to an EthosU operator""" | ||||||
|
|
||||||
| def __init__(self): | ||||||
| super().__init__(require_type=True) | ||||||
| self.pattern = ( | ||||||
| wildcard().has_attr({"Composite": ethosu_patterns.FullyConnectedParams.composite_name}) | ||||||
| )(wildcard()) | ||||||
|
|
||||||
| def callback(self, pre, post, node_map): | ||||||
| params = ethosu_patterns.FullyConnectedParams(post.op.body) | ||||||
| params.ifm.tensor = post.args[0] | ||||||
| activation_map = {"clip": "CLIP"} | ||||||
|
|
||||||
| # IFM reshapes | ||||||
| ifm = post.args[0] | ||||||
| if len(params.ifm.shape) != 4 or not params.ifm.shape[1] == params.ifm.shape[2] == 1: | ||||||
| ifm = relay.reshape(ifm, (-1, 1, 1, params.ifm.shape[-1])) | ||||||
|
||||||
| ifm = relay.reshape(ifm, (-1, 1, 1, params.ifm.shape[-1])) | |
| ifm = relay.reshape(ifm, (1, 1, 1, params.ifm.shape[-1])) |
should be safer since the NPU doesn't support IFMs with a batch size other than 1, and this kind of fully connected wouldn't be offloaded anyway
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect there isn't a test case that exercises this case since on line 1700 this pass runs after the no op legalizer, so the last reshape won't have a following identity op and will fall over in TE
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| """This is the pass that wraps the AddRewriter""" | |
| """This is the pass that wraps the FullyConnectedRewriter""" |
manupak marked this conversation as resolved.
Show resolved
Hide resolved
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1537,6 +1537,106 @@ def squeeze_pattern(): | |
| return is_op("squeeze")(wildcard()) | ||
|
|
||
|
|
||
class FullyConnectedParams:
    """
    This class will parse a call to an ethos-u.fully_connected composite
    function and extract the parameter information.
    """

    composite_name = "ethosu.fully_connected"
    # Maps the Relay activation op name onto the NPU activation attribute.
    activation_map = {"clip": "CLIP"}

    @requires_vela
    def __init__(self, func_body):
        from tvm.relay.backend.contrib.ethosu.util import QDenseArgs  # type: ignore
        from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs
        from tvm.relay.backend.contrib.ethosu.util import RequantArgs

        # The composite body is qnn.dense -> nn.bias_add -> qnn.requantize,
        # optionally followed by an activation (clip); walk inwards from the
        # outermost call.
        activation = None
        if str(func_body.op) in self.activation_map:
            activation = func_body
            requantize_op = activation.args[0]
        else:
            requantize_op = func_body

        bias_add = requantize_op.args[0]
        qnn_dense = bias_add.args[0]

        # We consider the weights & biases as params as they should be constant
        self.weights = TensorParams(
            qnn_dense.args[QDenseArgs.weights.value],
            "OI",
            qnn_dense.args[QDenseArgs.weights_scale.value],
            qnn_dense.args[QDenseArgs.weights_zero_point.value],
        )
        self.biases = TensorParams(
            bias_add.args[BiasAddArgs.BIASES.value],
            None,
            requantize_op.args[RequantArgs.IFM_SCALE.value],
            requantize_op.args[RequantArgs.IFM_ZERO_POINT.value],
        )
        self.ifm = TensorParams(
            qnn_dense.args[QDenseArgs.ifm.value],
            None,
            qnn_dense.args[QDenseArgs.ifm_scale.value],
            qnn_dense.args[QDenseArgs.ifm_zero_point.value],
        )
        self.ofm = TensorParams(
            func_body,
            None,
            requantize_op.args[RequantArgs.OFM_SCALE.value],
            requantize_op.args[RequantArgs.OFM_ZERO_POINT.value],
        )

        self.activation = activation

    def is_valid(self):
        """
        Checks whether Fully Connected has compatible attributes with HW
        """

        def check_weights_fc(weights):
            """Checks whether weight tensor is compatible with HW"""
            # A saturation upper bound check for accumulators
            weights_limit = 127 * 65536
            # Work on a local array rather than mutating weights.values in
            # place: the original assignment made this check non-idempotent
            # (a second call to is_valid would subtract the zero point twice).
            values = weights.values - weights.q_params.zero_point
            sum_weights = np.amax(np.sum(np.absolute(values), axis=1))
            return sum_weights <= weights_limit

        # Fix: __init__ assigns self.ifm / self.ofm; the original referenced
        # self.input / self.output, which are never set and would raise
        # AttributeError here.
        if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
            return False
        if not check_weights_fc(self.weights):
            return False
        if not check_bias(self.biases):
            return False
        if not check_batch_size(self.ifm):
            return False
        # Check input shape
        if len(self.ifm.shape) < 2:
            return False
        if not np.all(np.array(self.ifm.shape[:-1]) == 1):
            # As we reshape the ifm from
            # [n0, n1, ... , n_m] to [n0 * n1 * ... * n_{m-1}, n_m]
            # all except the last dims need to be 1.
            return False

        return True
|
|
||
|
|
||
def qnn_fc_pattern():
    """Return the pattern for a quantized fully connected operation:
    qnn.dense followed by an optional nn.bias_add, a qnn.requantize and an
    optional clip."""
    fc = is_op("qnn.dense")(
        wildcard(),
        is_constant(),
        is_constant(),
        is_constant(),
        is_constant(),
        is_constant(),
    )
    fc_biased = is_op("nn.bias_add")(fc, is_constant())
    requantize = is_op("qnn.requantize")(
        fc | fc_biased, is_constant(), is_constant(), is_constant(), is_constant()
    )
    return requantize.optional(is_op("clip"))
|
|
||
|
|
||
| @register_pattern_table("ethos-u") | ||
| def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: | ||
| return [ | ||
|
|
@@ -1652,6 +1752,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal | |
| squeeze_pattern(), | ||
| lambda pat: SqueezeParams(pat).is_valid(), | ||
| ), | ||
| ( | ||
| FullyConnectedParams.composite_name, | ||
| qnn_fc_pattern(), | ||
| lambda pat: FullyConnectedParams(pat).is_valid(), | ||
| ), | ||
| ] | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1167,5 +1167,26 @@ def leaky_relu_func(x): | |
| _compare_tvm_with_tflite(leaky_relu_func, [ifm_shape], accel_type) | ||
|
|
||
|
|
||
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize("units", [32, 64])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("activation_function", ["RELU", "NONE"])
def test_tflite_fully_connected(
    accel_type,
    units,
    use_bias,
    activation_function,
):
    """Compare TVM codegen output against the TFLite reference for a
    FULLY_CONNECTED operation."""

    @tf.function
    def fully_connected(x):
        # Keras activation identifiers are lowercase and "NONE" means no
        # activation; translate from the TFLite-style parametrize value.
        keras_activation = None if activation_function == "NONE" else activation_function.lower()
        # Fix: the original returned the (uncalled) Dense layer object; the
        # layer must be applied to the input so the traced graph actually
        # contains the FULLY_CONNECTED op.
        return tf.keras.layers.Dense(
            units=units,
            activation=keras_activation,
            use_bias=use_bias,
        )(x)

    # One input shape per function argument, matching the other tests in
    # this file (see leaky_relu above).
    _compare_tvm_with_tflite(fully_connected, [(1, 3, units, 1)], accel_type)
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| pytest.main([__file__]) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2346,5 +2346,87 @@ def verify(ext_func): | |
| verify(mod["tvmgen_default_ethos_u_main_0"]) | ||
|
|
||
|
|
||
@pytest.mark.parametrize("units", [32, 64])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("activation_function", ["RELU", "NONE"])
def test_tflite_fully_connected(
    units,
    use_bias,
    activation_function,
):
    """Check that a quantized TFLite FULLY_CONNECTED is legalized to an NPU
    operator with the expected IFM, scale-bias and activation attributes."""
    dtype = "int8"
    ifm_shape = [1, 3, units, 1]

    def create_tflite_graph():
        """Build and quantize a single-Dense TFLite model."""

        class Model(tf.Module):
            @tf.function
            def fully_connected(self, x):
                # Keras activation identifiers are lowercase and "NONE" means
                # no activation; translate from the TFLite-style value.
                activation = (
                    None if activation_function == "NONE" else activation_function.lower()
                )
                return tf.keras.layers.Dense(
                    units=units,
                    activation=activation,
                    use_bias=use_bias,
                )(x)

        model = Model()
        concrete_func = model.fully_connected.get_concrete_function(
            tf.TensorSpec(ifm_shape, dtype=tf.float32)
        )

        # Convert the model with full int8 post-training quantization.
        def representative_dataset():
            for _ in range(100):
                data = np.random.rand(*tuple(ifm_shape))
                yield [data.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        tflite_model = converter.convert()
        return tflite_model

    def verify(ext_func):
        op = ext_func.body
        ofm_channels = op.attrs.ofm_channels

        # check IFM
        ifm = op.args[0].checked_type
        # Fix: the original assertion compared a literal list with itself
        # (always true, checked nothing). The legalizer reshapes the IFM to a
        # 4D batch-1 tensor (see FullyConnectedRewriter), so assert that
        # instead.
        assert len(ifm.shape) == 4
        assert ifm.shape[0] == 1
        assert str(ifm.dtype) == dtype
        assert ifm.shape[3] == ofm_channels

        # Check that scale_bias matches weight tensor
        assert list(op.args[2].checked_type.shape)[0] == ofm_channels

        if activation_function == "RELU":
            assert str(op.attrs.activation) == "CLIP"

    dense_pattern_table = [
        (
            ethosu.FullyConnectedParams.composite_name,
            ethosu.qnn_fc_pattern(),
            lambda pat: ethosu.FullyConnectedParams(pat).is_valid(),
        )
    ]

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

    mod, params = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"input": ifm_shape},
        dtype_dict={"input": dtype},
    )

    mod["main"] = bind_params_by_name(mod["main"], params)
    mod = partition_ethosu_by_table(mod, dense_pattern_table)

    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        legalize.FullyConnectedRewriter(), mod["tvmgen_default_ethos_u_main_0"]
    )
    verify(mod["tvmgen_default_ethos_u_main_0"])
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| pytest.main([__file__]) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: we don't expect that dict to expand, so we can just do
`if activation == "clip":`, etc.