-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[microNPU] Add support for TFLite FULLY_CONNECTED #10345
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
8c0ea73
dab5c5e
b79fec2
ef1d576
d80302a
7b177bd
a4741b9
83d9ee1
97cd5ce
3bc81a5
ae6827c
18bd546
efbe30e
fcca2d2
06aa4d9
04fd825
82dc516
f924021
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,7 @@ | |
|
|
||
| import tvm # type: ignore | ||
| from tvm import relay | ||
| from tvm.relay.expr import Constant, Call # type: ignore | ||
| from tvm.relay.expr import Constant, Call | ||
|
||
| from tvm.relay.op.contrib.register import register_pattern_table # type: ignore | ||
| from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant, is_tuple # type: ignore | ||
| from tvm.relay.build_module import bind_params_by_name # type: ignore | ||
|
|
@@ -1103,7 +1103,10 @@ def is_valid(self): | |
| """ | ||
| if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]): | ||
| return False | ||
| return True | ||
| return True # optional_bias_add = ( | ||
|
|
||
| # is_op("nn.bias_add")(dense, is_constant()) | dense | ||
| # ) | ||
|
||
|
|
||
|
|
||
| class TanhParams(LutActivationParams): | ||
|
|
@@ -1537,6 +1540,112 @@ def squeeze_pattern(): | |
| return is_op("squeeze")(wildcard()) | ||
|
|
||
|
|
||
class FullyConnectedParams:
    """
    This class will parse a call to an ethos-u.fully_connected composite
    function and extract the parameter information.
    """

    composite_name = "ethos-u.fully_connected"

    @requires_vela
    def __init__(self, func_body):
        """Parse the matched composite body.

        Parameters
        ----------
        func_body : tvm.relay.Expr
            The body of the matched composite: an optional clip on top of a
            qnn.requantize, whose input is either nn.bias_add(qnn.dense, ...)
            or a bare qnn.dense (see qnn_fc_pattern).
        """
        from tvm.relay.backend.contrib.ethosu.util import QDenseArgs  # type: ignore
        from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs
        from tvm.relay.backend.contrib.ethosu.util import RequantArgs

        # Peel off the optional clip activation to reach the requantize op.
        self.activation = None
        if str(func_body.op) == "clip":
            self.activation = func_body
            requantize_op = self.activation.args[0]
        else:
            requantize_op = func_body

        call = requantize_op.args[0]
        # Reuse the already-bound `call` rather than re-deriving it from
        # requantize_op.args[0].
        if str(call.op) == "nn.bias_add":
            bias_add = call
            qnn_dense = call.args[0]
        else:
            bias_add = None
            qnn_dense = call

        # weights & biases are params as they should be constant
        self.weights = TensorParams(
            qnn_dense.args[QDenseArgs.WEIGHTS.value],
            None,
            qnn_dense.args[QDenseArgs.WEIGHTS_SCALE.value],
            qnn_dense.args[QDenseArgs.WEIGHTS_ZERO_POINT.value],
        )
        self.biases = (
            TensorParams(
                bias_add.args[BiasAddArgs.BIASES.value],
                None,
                requantize_op.args[RequantArgs.IFM_SCALE.value],
                requantize_op.args[RequantArgs.IFM_ZERO_POINT.value],
            )
            if bias_add
            else None
        )
        self.ifm = TensorParams(
            qnn_dense.args[QDenseArgs.IFM.value],
            None,
            qnn_dense.args[QDenseArgs.IFM_SCALE.value],
            qnn_dense.args[QDenseArgs.IFM_ZERO_POINT.value],
        )
        self.ofm = TensorParams(
            func_body,
            None,
            requantize_op.args[RequantArgs.OFM_SCALE.value],
            requantize_op.args[RequantArgs.OFM_ZERO_POINT.value],
        )

    def is_valid(self) -> bool:
        """
        Checks whether Fully Connected has compatible attributes with HW
        """

        def check_weights_fc(weights) -> bool:
            """Checks whether weight tensor is compatible with HW"""
            weights_limit = 127 * 65536
            # A saturation upper bound check for accumulators.
            # NOTE: compute on a local copy instead of assigning back to
            # weights.values — the original in-place update would shift the
            # weights by the zero point again on every is_valid() call.
            zero_adjusted = weights.values - weights.q_params.zero_point
            sum_weights = np.amax(np.sum(np.absolute(zero_adjusted), axis=1))
            return sum_weights <= weights_limit

        if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
            return False
        if not check_weights_fc(self.weights):
            return False
        if not check_bias(self.biases):
            return False
        if not check_batch_size(self.ifm):
            return False
        # Check input shape
        if len(self.ifm.shape) < 2:
            return False
        if not np.all(np.array(self.ifm.shape[:-1]) == 1):
            # As we reshape the ifm from
            # [n0, n1, ... , n_m] to [n0 * n1 * ... * n_{m-1}, n_m]
            # all except the last dims need to be 1.
            return False

        return True
|
|
||
|
|
||
def qnn_fc_pattern():
    """Create the pattern for a quantized fully connected operation.

    Matches qnn.dense with an optional nn.bias_add, followed by
    qnn.requantize and an optional clip.
    """
    dense = is_op("qnn.dense")(
        wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
    )
    optional_bias_add = is_op("nn.bias_add")(dense, is_constant()) | dense
    # optional_bias_add already includes the bare `dense` alternative, so the
    # original `dense | optional_bias_add` union was redundant.
    req = is_op("qnn.requantize")(
        optional_bias_add, is_constant(), is_constant(), is_constant(), is_constant()
    )
    optional_clip = req.optional(is_op("clip"))
    return optional_clip
|
|
||
|
|
||
| @register_pattern_table("ethos-u") | ||
| def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: | ||
| return [ | ||
|
|
@@ -1555,6 +1664,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal | |
| qnn_conv2d_transpose_pattern(), | ||
| lambda pat: QnnConv2DTransposeParams(pat).is_valid(), | ||
| ), | ||
| ( | ||
| FullyConnectedParams.composite_name, | ||
| qnn_fc_pattern(), | ||
| lambda pat: FullyConnectedParams(pat).is_valid(), | ||
| ), | ||
| ( | ||
| MaxPool2DParams.composite_name, | ||
| qnn_maxpool2d_pattern(), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect there isn't a test case that exercises this path: on line 1700 this pass runs after the no-op legalizer, so the last reshape won't have a following identity op and will fall over in TE.