Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
1b927d2
Initial (unfinished) QKeras to QONNX converter
vloncar Nov 10, 2021
19f6cab
Proper name for QKeras package
vloncar Nov 10, 2021
2f1fd54
Update minimum python version in github actions
vloncar Nov 10, 2021
d42c5e8
Make pre-commit pass.
heborras Nov 19, 2021
26654e0
Add TF as a dependency to allow tests to run.
heborras Nov 19, 2021
b552a29
Moved some Quant node attributes into inputs.
heborras Nov 19, 2021
85e13de
Fixed unknown shape in tensor shapes
heborras Nov 19, 2021
0b87806
Fixed undefined datatypes at Quant outputs.
heborras Nov 19, 2021
3779fe1
Set default names for intermediate outputs in keras converter.
heborras Nov 22, 2021
9ae60ed
Save returned clean-up model.
heborras Nov 26, 2021
de9f2a0
Add fix attaching shape of initializers for proper shape inference
thesps Mar 8, 2022
24ed08f
Add single layer QKeras tests and checks of numerical consistency
thesps Mar 8, 2022
449cb12
Tidy up. Use RandomUniform initializers in tests
thesps Mar 8, 2022
c830b01
Fix: copy the weights to the stripped Keras model, otherwise they are…
thesps Mar 8, 2022
923477b
Fixes to quantizer conversion for numerical correctness, and protecti…
thesps Mar 8, 2022
642a94c
reorder imports
jmduarte Apr 21, 2022
ad48742
run black
jmduarte Apr 21, 2022
dfd3346
flake8
jmduarte Apr 21, 2022
805b47f
update some names
jmduarte Apr 28, 2022
0609ff7
fix quant tests
jmduarte Jun 24, 2022
8dfb2b8
update to finn
jmduarte Jun 24, 2022
7a9d861
flake8
jmduarte Jun 24, 2022
66ef0bd
pre-commit hook working
jmduarte Jun 24, 2022
d5a0c22
Docstrings added
thephysicsboi Aug 19, 2022
8bae3a5
More comments added
thephysicsboi Aug 25, 2022
b3bedd7
corrected transpose issue with conv node
Nov 30, 2022
446a55b
removed onnx models
Nov 30, 2022
9dea347
pre-commit
jmduarte Nov 30, 2022
44a72a0
reseed tensorflow
jmduarte Nov 30, 2022
d959bd9
Added new tests
Dec 14, 2022
9921ef9
rebased with main
Dec 16, 2022
cae85a9
pre-commit
jmduarte Dec 22, 2022
61195a8
try more recent version of tf2onnx
jmduarte Dec 22, 2022
3b11ebf
add pyparsing
jmduarte Dec 22, 2022
6c438be
revert to tf2onnx==1.9.2
jmduarte Dec 22, 2022
ad2e884
fix for tf2onnx>=1.93
jmduarte Dec 22, 2022
e716b00
np.int -> int
jmduarte Dec 22, 2022
4d08653
mv test file
jmduarte Dec 22, 2022
a4ed110
updating pull-request
Jan 23, 2023
85ee643
resolve dependency conflict
Jan 24, 2023
47b03c5
add re-seed
Jan 24, 2023
ffee1f4
fixing reseed
Jan 25, 2023
95df5cf
Merge branch 'main' into keras2qonnx_dev
jmduarte Feb 1, 2023
89124d1
update
jmduarte Feb 1, 2023
7e209ce
more linting
jmduarte Feb 1, 2023
f9a5c9b
mv pyparsing
jmduarte Feb 1, 2023
dba9821
rm stuff from init
jmduarte Feb 1, 2023
367e085
wrap in try except
jmduarte Feb 1, 2023
5c1c544
Merge branch 'main' into keras2qonnx_dev
jmduarte Feb 1, 2023
a58da45
Update setup.cfg
jmduarte Feb 8, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ jobs:
name: Lint PR or Push to DEV
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [3.8]

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[testing]
pip install -e .[testing,qkeras]

- name: Run tests
run: |
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,6 @@ dmypy.json

# Cython debug symbols
cython_debug/

# IDE stuff
.vscode
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ default_language_version:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0
rev: v4.4.0
hooks:
- id: trailing-whitespace
- id: check-added-large-files
Expand All @@ -21,19 +21,19 @@ repos:
args: ['--fix=no']

- repo: https://github.com/PyCQA/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort

- repo: https://github.com/psf/black
rev: 22.3.0
rev: 23.1.0
hooks:
- id: black
language_version: python3
args: [--line-length=125]

- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
rev: 6.0.0
hooks:
- id: flake8
# black-compatible flake-8 config
Expand Down
9 changes: 8 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ exclude =
# Add here additional requirements for extra features, to install with:
# `pip install qonnx[PDF]` like:
# PDF = ReportLab; RXP
# Note: tf2onnx 1.12.1 is the first version that supports numpy 1.24.1
# Note: pyparsing is actually needed by QKeras, but missing as dependency
qkeras =
pyparsing
tf2onnx>=1.12.1
tensorflow==2.7.0
QKeras==0.9.0

# Add here test requirements (semicolon/line-separated)
testing =
Expand All @@ -84,7 +91,7 @@ console_scripts =
qonnx-to-channels-last = qonnx.util.to_channels_last:main
qonnx-inference-cost = qonnx.util.inference_cost:main
pytest_randomly.random_seeder =
qonnx = qonnx:reseed
qonnx = qonnx.util.random_reseed:reseed
# Add here console scripts like:
# console_scripts =
# script_name = qonnx.module:function
Expand Down
5 changes: 0 additions & 5 deletions src/qonnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +0,0 @@
import onnxruntime


def reseed(newseed):
    """Forward ``newseed`` to ``onnxruntime.set_seed``.

    Args:
        newseed: the seed value handed to onnxruntime unchanged
    """
    onnxruntime.set_seed(newseed)
4 changes: 4 additions & 0 deletions src/qonnx/converters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# The keras converter module imports tensorflow, tf2onnx and qkeras at its top level.
# An ImportError from it is swallowed here -- presumably so that this package remains
# importable when those optional dependencies are not installed; `from_keras` is then
# simply not exported.
try:
    from .keras import from_keras  # noqa: F401
except ImportError:
    pass
261 changes: 261 additions & 0 deletions src/qonnx/converters/keras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
import onnx
import tensorflow as tf
import tf2onnx
from collections import OrderedDict
from qkeras.utils import REGISTERED_LAYERS as QKERAS_LAYERS

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.cleanup import cleanup_model

from .qkeras.onnx import get_qkeras_onnx_handlers
from .qkeras.qlayers import extract_quantizers_from_layer

# Class names of QKeras layers that the converter refuses to handle for now.
_unsupported_layers = [
    # These require some extra work
    "QBatchNormalization",
    "QConv2DBatchnorm",
    "QDepthwiseConv2DBatchnorm",
]

# Skip remove_identity optimizer.
# pop() with a default is used instead of `del` so that this import-time tweak is
# idempotent: it does not raise KeyError when the module body is executed a second
# time (e.g. importlib.reload) or when the installed tf2onnx registers no optimizer
# under this name.
tf2onnx.optimizer._optimizers.pop("remove_identity", None)


def add_value_info_for_constants(model: onnx.ModelProto):
    """
    Record type and shape of every constant initializer as a ValueInfoProto.

    onnx.shape_inference currently does not use the shape of initializers, so
    that information is added explicitly. Graphs stored in node attributes
    (subgraphs) are processed as well.
    Mutates the model.
    Args:
        model: The ModelProto to update.
    Returns:
        The ModelProto object that was passed in.
    """
    if model.ir_version < 4:
        # Before IRv4 all (top-level) constants are inputs and therefore already have ValueInfos
        return model

    def _annotate_graph(graph: onnx.GraphProto):
        input_names = {graph_input.name for graph_input in graph.input}
        known_infos = {info.name: info for info in graph.value_info}

        for initializer in graph.initializer:
            # Initializers that are also graph inputs are not constants
            if initializer.name in input_names:
                continue

            # Reuse the value info of this constant if there is one, otherwise create it
            value_info = known_infos.get(initializer.name)
            if value_info is None:
                value_info = graph.value_info.add()
                value_info.name = initializer.name

            # Only fill in what is missing; existing info is never overwritten, even if it disagrees
            tensor_type = value_info.type.tensor_type
            if tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
                tensor_type.elem_type = initializer.data_type
            if not tensor_type.HasField("shape"):
                # Extending by an empty list makes sure a scalar constant (zero dims) gets an empty shape
                tensor_type.shape.dim.extend([])
                for extent in initializer.dims:
                    tensor_type.shape.dim.add().dim_value = extent

        # Descend into subgraphs
        for node in graph.node:
            for attribute in node.attribute:
                # Ref attrs only point at other attrs, nothing to do for them
                if attribute.ref_attr_name != "":
                    continue

                if attribute.type == onnx.AttributeProto.GRAPH:
                    _annotate_graph(attribute.g)
                if attribute.type == onnx.AttributeProto.GRAPHS:
                    for subgraph in attribute.graphs:
                        _annotate_graph(subgraph)

    _annotate_graph(model.graph)
    return model


def _is_qkeras_model(model):
    """Tell whether a model contains at least one QKeras layer.

    Layers that are themselves `tf.keras.Model` instances are searched recursively;
    any other layer counts as a QKeras layer when its class name is listed in
    QKeras' registered layers.

    Args:
        model: the model we want to convert

    Returns:
        True if the model contains any qkeras layer, False otherwise
    """

    def _contains_qkeras(container):
        # any() stops at the first hit, just like an early return from a loop would
        return any(
            _contains_qkeras(sublayer)
            if isinstance(sublayer, tf.keras.Model)
            else sublayer.__class__.__name__ in QKERAS_LAYERS
            for sublayer in container.layers
        )

    return _contains_qkeras(model)


def _check_supported_layers(model):
    """Verify that the model contains no layer the converter cannot handle yet.

    Nested `tf.keras.Model` layers are checked recursively.

    Args:
        model: the tf.keras model we want to convert

    Raises:
        Exception: if a layer whose class name is listed in `_unsupported_layers` is found
    """

    def _walk(container):
        for sublayer in container.layers:
            if isinstance(sublayer, tf.keras.Model):
                _walk(sublayer)
                continue

            cls_name = sublayer.__class__.__name__
            if cls_name in _unsupported_layers:
                raise Exception("Currently unsupported layer found in QKeras model: {}".format(cls_name))

    _walk(model)


def _strip_qkeras_model(model):
    """Clone a QKeras model into a plain Keras model and collect its quantizers.

    Every layer is rebuilt from the Keras class name and config reported by
    `extract_quantizers_from_layer`; the weights of the original model are then
    copied onto the clone.

    Args:
        model: the tf.keras model we want to convert

    Returns:
        The stripped model, and an ordered dictionary mapping layer names to the
        quantizers found on that layer (plus the name of the layer's input tensor
        under the key "input")

    Raises:
        Exception: if `tf.keras.layers` has no class with the reported Keras class name
    """
    collected = OrderedDict()

    def _to_plain_layer(layer):
        cls_name, config, found = extract_quantizers_from_layer(layer)
        if found:
            cleaned = {}
            for key, value in found.items():
                # Get rid of 'None' strings
                cleaned[key] = None if value == "None" else value
            cleaned["input"] = layer.input.name
            collected[layer.name] = cleaned

        plain_cls = tf.keras.layers.__dict__.get(cls_name, None)
        if plain_cls is None:
            raise Exception("Cannot create Keras layer from QKeras class {}".format(cls_name))

        return plain_cls.from_config(config)

    plain_model = tf.keras.models.clone_model(model, clone_function=_to_plain_layer)
    # clone_model does not carry the weights over, so copy them explicitly
    plain_model.set_weights(model.get_weights())
    return plain_model, collected


# tests run without this function
def _convert_quantizers_to_nodes(onnx_model, quantizers_dict):
    """Print the collected quantizers and the graph nodes, then return the wrapped model.

    Args:
        onnx_model: object exposing `graph.node` and `model`
        quantizers_dict: dictionary mapping node names to their quantizers

    Returns:
        `onnx_model.model`
    """
    for entry in quantizers_dict.items():
        # entry is a (node name, quantizers) pair, printed as two arguments
        print(*entry)

    for graph_node in onnx_model.graph.node:
        print(graph_node)

    return onnx_model.model


def from_keras(
    model,
    name="qkeras_to_qonnx_converted",
    input_signature=None,
    opset=None,
    custom_ops=None,
    custom_op_handlers=None,
    custom_rewriter=None,
    inputs_as_nchw=None,
    extra_opset=None,
    shape_override=None,
    target=None,
    large_model=False,
    output_path=None,
):
    """Convert a keras model to QONNX. The API follows the `from_keras` function of tf2onnx.

    Args:
        model: the tf.keras model we want to convert
        name: label used in the file name of the intermediate model dump (`tmp_<name>.onnx`)
        input_signature: a tf.TensorSpec or a numpy array defining the shape/dtype of the input
        opset: the opset to be used for the ONNX model, default is the latest
        custom_ops: if a model contains ops not recognized by onnx runtime,
            you can tag these ops with a custom op domain so that the
            runtime can still open the model. Type is a dictionary `{op name: domain}`.
        target: list of workarounds applied to help certain platforms
        custom_op_handlers: dictionary of custom ops handlers; merged over the QKeras
            handlers, so entries given here take precedence
        custom_rewriter: list of custom graph rewriters
        extra_opset: list of extra opset's, for example the opset's used by custom ops
        shape_override: dict with inputs that override the shapes given by tensorflow
        inputs_as_nchw: transpose inputs in list from nchw to nhwc
        large_model: use the ONNX external tensor storage format (not supported yet, must be False)
        output_path: save the cleaned-up model to output_path

    Returns:
        An ONNX model_proto and an external_tensor_storage dict.

    Raises:
        AssertionError: if `large_model` is set
        ValueError: if the converted model does not have exactly one input and one output
        Exception: if the model contains a currently unsupported QKeras layer
    """

    # TODO for now, let's focus only on models that don't store tensors externally.
    # Raised explicitly instead of using `assert`, so the check is not stripped under `python -O`.
    if large_model:
        raise AssertionError("large_model=True (external tensor storage) is not supported yet")

    if _is_qkeras_model(model):
        _check_supported_layers(model)
        keras_model, quantizers = _strip_qkeras_model(model)
    else:
        keras_model, quantizers = model, {}

    qkeras_op_handlers = get_qkeras_onnx_handlers(quantizers)

    if custom_op_handlers is not None:
        qkeras_op_handlers.update(custom_op_handlers)

    # output_path is deliberately not forwarded: the model is saved further down,
    # after it has been post-processed.
    model_proto, external_storage = tf2onnx.convert.from_keras(
        keras_model,
        input_signature=input_signature,
        opset=opset,
        custom_ops=custom_ops,
        custom_op_handlers=qkeras_op_handlers,
        custom_rewriter=custom_rewriter,
        inputs_as_nchw=inputs_as_nchw,
        extra_opset=extra_opset,
        shape_override=shape_override,
        target=target,
        large_model=large_model,
        output_path=None,
    )

    onnx_model = ModelWrapper(model_proto)
    # Set the first value of input/output shape to 1, currently this is set to unknown,
    # because it is technically the batch size
    if not (len(onnx_model.graph.input) == 1 and len(onnx_model.graph.output) == 1):
        raise ValueError("Qkeras to QONNX conversion only supports models with exactly one input and output.")
    inp_name = onnx_model.graph.input[0].name
    out_name = onnx_model.graph.output[0].name
    inp_shape = onnx_model.get_tensor_shape(inp_name)
    out_shape = onnx_model.get_tensor_shape(out_name)
    inp_shape[0] = 1
    out_shape[0] = 1
    onnx_model.set_tensor_shape(inp_name, inp_shape)
    onnx_model.set_tensor_shape(out_name, out_shape)

    # Set all Quant output tensors to float32 datatype, otherwise they are undefined and crash ONNX execution.
    # The output names of all QONNX-domain nodes are gathered into one set so value_info is walked only once.
    quant_outputs = {
        quant_node.output[0]
        for q_op_type in ("Quant", "Trunc", "BipolarQuant")
        for quant_node in onnx_model.get_nodes_by_op_type(q_op_type)
    }
    for tensor in onnx_model.graph.value_info:
        if tensor.name in quant_outputs:
            tensor.type.tensor_type.elem_type = onnx.TensorProto.FLOAT

    # NOTE(review): this unconditionally writes an intermediate dump into the current working
    # directory; it looks like a debugging aid -- confirm whether anything relies on the file.
    onnx_model.save(f"tmp_{name}.onnx")

    onnx_model = cleanup_model(onnx_model)
    onnx_model.model = add_value_info_for_constants(onnx_model.model)

    if output_path is not None:
        onnx_model.save(output_path)

    return onnx_model.model, external_storage
Empty file.
Loading