Skip to content

Commit 94977dd

Browse files
Refactor keras/src/export/export_lib and add export_onnx (#20710)
* Refactor export_lib and add export_onnx Add tf2onnx requirements * Add onnxruntime dep * Update numpy dep * Resolve comments
1 parent 41c429e commit 94977dd

26 files changed

+943
-447
lines changed

.kokoro/github/ubuntu/gpu/build.sh

-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ then
7272
# Raise error if GPU is not detected.
7373
python3 -c 'import torch;assert torch.cuda.is_available()'
7474

75-
# TODO: keras/src/export/export_lib_test.py update LD_LIBRARY_PATH
7675
pytest keras --ignore keras/src/applications \
7776
--cov=keras \
7877
--cov-config=pyproject.toml

keras/api/_tf_keras/keras/export/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
since your modifications would be overwritten.
55
"""
66

7-
from keras.src.export.export_lib import ExportArchive
7+
from keras.src.export.saved_model import ExportArchive

keras/api/_tf_keras/keras/layers/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
since your modifications would be overwritten.
55
"""
66

7-
from keras.src.export.export_lib import TFSMLayer
7+
from keras.src.export.tfsm_layer import TFSMLayer
88
from keras.src.layers import deserialize
99
from keras.src.layers import serialize
1010
from keras.src.layers.activations.activation import Activation

keras/api/export/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
since your modifications would be overwritten.
55
"""
66

7-
from keras.src.export.export_lib import ExportArchive
7+
from keras.src.export.saved_model import ExportArchive

keras/api/layers/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
since your modifications would be overwritten.
55
"""
66

7-
from keras.src.export.export_lib import TFSMLayer
7+
from keras.src.export.tfsm_layer import TFSMLayer
88
from keras.src.layers import deserialize
99
from keras.src.layers import serialize
1010
from keras.src.layers.activations.activation import Activation

keras/src/backend/torch/export.py

+5-19
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33

44
import torch
55

6-
from keras.src import backend
7-
from keras.src import ops
86
from keras.src import tree
7+
from keras.src.export.export_utils import convert_spec_to_tensor
98
from keras.src.utils.module_utils import tensorflow as tf
109
from keras.src.utils.module_utils import torch_xla
1110

@@ -36,23 +35,10 @@ def track_and_add_endpoint(self, name, resource, input_signature, **kwargs):
3635
f"Received: resource={resource} (of type {type(resource)})"
3736
)
3837

39-
def _check_input_signature(input_spec):
40-
for s in tree.flatten(input_spec.shape):
41-
if s is None:
42-
raise ValueError(
43-
"The shape in the `input_spec` must be fully "
44-
f"specified. Received: input_spec={input_spec}"
45-
)
46-
47-
def _to_torch_tensor(x, replace_none_number=1):
48-
shape = backend.standardize_shape(x.shape)
49-
shape = tuple(
50-
s if s is not None else replace_none_number for s in shape
51-
)
52-
return ops.ones(shape, x.dtype)
53-
54-
tree.map_structure(_check_input_signature, input_signature)
55-
sample_inputs = tree.map_structure(_to_torch_tensor, input_signature)
38+
sample_inputs = tree.map_structure(
39+
lambda x: convert_spec_to_tensor(x, replace_none_number=1),
40+
input_signature,
41+
)
5642
sample_inputs = tuple(sample_inputs)
5743

5844
# Ref: torch_xla.tf_saved_model_integration

keras/src/export/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1-
from keras.src.export.export_lib import ExportArchive
1+
from keras.src.export.onnx import export_onnx
2+
from keras.src.export.saved_model import ExportArchive
3+
from keras.src.export.saved_model import export_saved_model
4+
from keras.src.export.tfsm_layer import TFSMLayer

keras/src/export/export_utils.py

+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from keras.src import backend
2+
from keras.src import layers
3+
from keras.src import models
4+
from keras.src import ops
5+
from keras.src import tree
6+
from keras.src.utils.module_utils import tensorflow as tf
7+
8+
9+
def get_input_signature(model):
10+
if not isinstance(model, models.Model):
11+
raise TypeError(
12+
"The model must be a `keras.Model`. "
13+
f"Received: model={model} of the type {type(model)}"
14+
)
15+
if not model.built:
16+
raise ValueError(
17+
"The model provided has not yet been built. It must be built "
18+
"before export."
19+
)
20+
if isinstance(model, (models.Functional, models.Sequential)):
21+
input_signature = tree.map_structure(make_input_spec, model.inputs)
22+
if isinstance(input_signature, list) and len(input_signature) > 1:
23+
input_signature = [input_signature]
24+
else:
25+
input_signature = _infer_input_signature_from_model(model)
26+
if not input_signature or not model._called:
27+
raise ValueError(
28+
"The model provided has never called. "
29+
"It must be called at least once before export."
30+
)
31+
return input_signature
32+
33+
34+
def _infer_input_signature_from_model(model):
35+
shapes_dict = getattr(model, "_build_shapes_dict", None)
36+
if not shapes_dict:
37+
return None
38+
39+
def _make_input_spec(structure):
40+
# We need to turn wrapper structures like TrackingDict or _DictWrapper
41+
# into plain Python structures because they don't work with jax2tf/JAX.
42+
if isinstance(structure, dict):
43+
return {k: _make_input_spec(v) for k, v in structure.items()}
44+
elif isinstance(structure, tuple):
45+
if all(isinstance(d, (int, type(None))) for d in structure):
46+
return layers.InputSpec(
47+
shape=(None,) + structure[1:], dtype=model.input_dtype
48+
)
49+
return tuple(_make_input_spec(v) for v in structure)
50+
elif isinstance(structure, list):
51+
if all(isinstance(d, (int, type(None))) for d in structure):
52+
return layers.InputSpec(
53+
shape=[None] + structure[1:], dtype=model.input_dtype
54+
)
55+
return [_make_input_spec(v) for v in structure]
56+
else:
57+
raise ValueError(
58+
f"Unsupported type {type(structure)} for {structure}"
59+
)
60+
61+
return [_make_input_spec(value) for value in shapes_dict.values()]
62+
63+
64+
def make_input_spec(x):
65+
if isinstance(x, layers.InputSpec):
66+
if x.shape is None or x.dtype is None:
67+
raise ValueError(
68+
"The `shape` and `dtype` must be provided. " f"Received: x={x}"
69+
)
70+
input_spec = x
71+
elif isinstance(x, backend.KerasTensor):
72+
shape = (None,) + backend.standardize_shape(x.shape)[1:]
73+
dtype = backend.standardize_dtype(x.dtype)
74+
input_spec = layers.InputSpec(dtype=dtype, shape=shape, name=x.name)
75+
elif backend.is_tensor(x):
76+
shape = (None,) + backend.standardize_shape(x.shape)[1:]
77+
dtype = backend.standardize_dtype(x.dtype)
78+
input_spec = layers.InputSpec(dtype=dtype, shape=shape, name=None)
79+
else:
80+
raise TypeError(
81+
f"Unsupported x={x} of the type ({type(x)}). Supported types are: "
82+
"`keras.InputSpec`, `keras.KerasTensor` and backend tensor."
83+
)
84+
return input_spec
85+
86+
87+
def make_tf_tensor_spec(x):
88+
if isinstance(x, tf.TensorSpec):
89+
tensor_spec = x
90+
else:
91+
input_spec = make_input_spec(x)
92+
tensor_spec = tf.TensorSpec(
93+
input_spec.shape, dtype=input_spec.dtype, name=input_spec.name
94+
)
95+
return tensor_spec
96+
97+
98+
def convert_spec_to_tensor(spec, replace_none_number=None):
99+
shape = backend.standardize_shape(spec.shape)
100+
if replace_none_number is not None:
101+
replace_none_number = int(replace_none_number)
102+
shape = tuple(
103+
s if s is not None else replace_none_number for s in shape
104+
)
105+
return ops.ones(shape, spec.dtype)

keras/src/export/onnx.py

+162
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import pathlib
2+
import tempfile
3+
4+
from keras.src import backend
5+
from keras.src import tree
6+
from keras.src.export.export_utils import convert_spec_to_tensor
7+
from keras.src.export.export_utils import get_input_signature
8+
from keras.src.export.saved_model import export_saved_model
9+
from keras.src.utils.module_utils import tensorflow as tf
10+
11+
12+
def export_onnx(model, filepath, verbose=True, input_signature=None, **kwargs):
13+
"""Export the model as a ONNX artifact for inference.
14+
15+
This method lets you export a model to a lightweight ONNX artifact
16+
that contains the model's forward pass only (its `call()` method)
17+
and can be served via e.g. ONNX Runtime.
18+
19+
The original code of the model (including any custom layers you may
20+
have used) is *no longer* necessary to reload the artifact -- it is
21+
entirely standalone.
22+
23+
Args:
24+
filepath: `str` or `pathlib.Path` object. The path to save the artifact.
25+
verbose: `bool`. Whether to print a message during export. Defaults to
26+
True`.
27+
input_signature: Optional. Specifies the shape and dtype of the model
28+
inputs. Can be a structure of `keras.InputSpec`, `tf.TensorSpec`,
29+
`backend.KerasTensor`, or backend tensor. If not provided, it will
30+
be automatically computed. Defaults to `None`.
31+
**kwargs: Additional keyword arguments.
32+
33+
**Note:** This feature is currently supported only with TensorFlow, JAX and
34+
Torch backends.
35+
36+
**Note:** The dtype policy must be "float32" for the model. You can further
37+
optimize the ONNX artifact using the ONNX toolkit. Learn more here:
38+
[https://onnxruntime.ai/docs/performance/](https://onnxruntime.ai/docs/performance/).
39+
40+
**Note:** The dynamic shape feature is not yet supported with Torch
41+
backend. As a result, you must fully define the shapes of the inputs using
42+
`input_signature`. If `input_signature` is not provided, all instances of
43+
`None` (such as the batch size) will be replaced with `1`.
44+
45+
Example:
46+
47+
```python
48+
# Export the model as a ONNX artifact
49+
model.export("path/to/location", format="onnx")
50+
51+
# Load the artifact in a different process/environment
52+
ort_session = onnxruntime.InferenceSession("path/to/location")
53+
ort_inputs = {
54+
k.name: v for k, v in zip(ort_session.get_inputs(), input_data)
55+
}
56+
predictions = ort_session.run(None, ort_inputs)
57+
```
58+
"""
59+
if input_signature is None:
60+
input_signature = get_input_signature(model)
61+
if not input_signature or not model._called:
62+
raise ValueError(
63+
"The model provided has never called. "
64+
"It must be called at least once before export."
65+
)
66+
67+
if backend.backend() in ("tensorflow", "jax"):
68+
working_dir = pathlib.Path(filepath).parent
69+
with tempfile.TemporaryDirectory(dir=working_dir) as temp_dir:
70+
if backend.backend() == "jax":
71+
kwargs = _check_jax_kwargs(kwargs)
72+
export_saved_model(
73+
model,
74+
temp_dir,
75+
verbose,
76+
input_signature,
77+
**kwargs,
78+
)
79+
saved_model_to_onnx(temp_dir, filepath, model.name)
80+
81+
elif backend.backend() == "torch":
82+
import torch
83+
84+
sample_inputs = tree.map_structure(
85+
lambda x: convert_spec_to_tensor(x, replace_none_number=1),
86+
input_signature,
87+
)
88+
sample_inputs = tuple(sample_inputs)
89+
# TODO: Make dict model exportable.
90+
if any(isinstance(x, dict) for x in sample_inputs):
91+
raise ValueError(
92+
"Currently, `export_onnx` in the torch backend doesn't support "
93+
"dictionaries as inputs."
94+
)
95+
96+
# Convert to ONNX using TorchScript-based ONNX Exporter.
97+
# TODO: Use TorchDynamo-based ONNX Exporter once
98+
# `torch.onnx.dynamo_export()` supports Keras models.
99+
torch.onnx.export(model, sample_inputs, filepath, verbose=verbose)
100+
else:
101+
raise NotImplementedError(
102+
"`export_onnx` is only compatible with TensorFlow, JAX and "
103+
"Torch backends."
104+
)
105+
106+
107+
def _check_jax_kwargs(kwargs):
108+
kwargs = kwargs.copy()
109+
if "is_static" not in kwargs:
110+
kwargs["is_static"] = True
111+
if "jax2tf_kwargs" not in kwargs:
112+
# TODO: These options will be deprecated in JAX. We need to
113+
# find another way to export ONNX.
114+
kwargs["jax2tf_kwargs"] = {
115+
"enable_xla": False,
116+
"native_serialization": False,
117+
}
118+
if kwargs["is_static"] is not True:
119+
raise ValueError(
120+
"`is_static` must be `True` in `kwargs` when using the jax "
121+
"backend."
122+
)
123+
if kwargs["jax2tf_kwargs"]["enable_xla"] is not False:
124+
raise ValueError(
125+
"`enable_xla` must be `False` in `kwargs['jax2tf_kwargs']` "
126+
"when using the jax backend."
127+
)
128+
if kwargs["jax2tf_kwargs"]["native_serialization"] is not False:
129+
raise ValueError(
130+
"`native_serialization` must be `False` in "
131+
"`kwargs['jax2tf_kwargs']` when using the jax backend."
132+
)
133+
return kwargs
134+
135+
136+
def saved_model_to_onnx(saved_model_dir, filepath, name):
137+
from keras.src.utils.module_utils import tf2onnx
138+
139+
# Convert to ONNX using `tf2onnx` library.
140+
(graph_def, inputs, outputs, initialized_tables, tensors_to_rename) = (
141+
tf2onnx.tf_loader.from_saved_model(
142+
saved_model_dir,
143+
None,
144+
None,
145+
return_initialized_tables=True,
146+
return_tensors_to_rename=True,
147+
)
148+
)
149+
150+
with tf.device("/cpu:0"):
151+
_ = tf2onnx.convert._convert_common(
152+
graph_def,
153+
name=name,
154+
target=[],
155+
custom_op_handlers={},
156+
extra_opset=[],
157+
input_names=inputs,
158+
output_names=outputs,
159+
tensors_to_rename=tensors_to_rename,
160+
initialized_tables=initialized_tables,
161+
output_path=filepath,
162+
)

0 commit comments

Comments
 (0)