import pathlib
import tempfile

from keras.src import backend
from keras.src import tree
from keras.src.export.export_utils import convert_spec_to_tensor
from keras.src.export.export_utils import get_input_signature
from keras.src.export.saved_model import export_saved_model
from keras.src.utils.module_utils import tensorflow as tf


def export_onnx(model, filepath, verbose=True, input_signature=None, **kwargs):
    """Export the model as an ONNX artifact for inference.

    This method lets you export a model to a lightweight ONNX artifact
    that contains the model's forward pass only (its `call()` method)
    and can be served via e.g. ONNX Runtime.

    The original code of the model (including any custom layers you may
    have used) is *no longer* necessary to reload the artifact -- it is
    entirely standalone.

    Args:
        filepath: `str` or `pathlib.Path` object. The path to save the
            artifact.
        verbose: `bool`. Whether to print a message during export.
            Defaults to `True`.
        input_signature: Optional. Specifies the shape and dtype of the
            model inputs. Can be a structure of `keras.InputSpec`,
            `tf.TensorSpec`, `backend.KerasTensor`, or backend tensor. If
            not provided, it will be automatically computed. Defaults to
            `None`.
        **kwargs: Additional keyword arguments.

    **Note:** This feature is currently supported only with the
    TensorFlow, JAX and Torch backends.

    **Note:** The model's dtype policy must be "float32". You can further
    optimize the ONNX artifact using the ONNX toolkit. Learn more here:
    [https://onnxruntime.ai/docs/performance/](https://onnxruntime.ai/docs/performance/).

    **Note:** Dynamic shapes are not yet supported with the Torch
    backend. As a result, you must fully define the shapes of the inputs
    using `input_signature`. If `input_signature` is not provided, all
    instances of `None` (such as the batch size) will be replaced with
    `1`.

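    For example, you can fully define the input shapes by passing an
    explicit `input_signature` to this function; a minimal sketch, where
    the shape and dtype are illustrative and should match your model's
    actual inputs:

    ```python
    export_onnx(
        model,
        "path/to/location",
        input_signature=[tf.TensorSpec((1, 28, 28, 1), tf.float32)],
    )
    ```
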
    Example:

    ```python
    # Export the model as an ONNX artifact
    model.export("path/to/location", format="onnx")

    # Load the artifact in a different process/environment
    ort_session = onnxruntime.InferenceSession("path/to/location")
    ort_inputs = {
        k.name: v for k, v in zip(ort_session.get_inputs(), input_data)
    }
    predictions = ort_session.run(None, ort_inputs)
    ```
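
    As noted above, the artifact can be further optimized offline; a
    minimal sketch using ONNX Runtime's graph optimizations (paths are
    placeholders):

    ```python
    sess_options = onnxruntime.SessionOptions()
    sess_options.graph_optimization_level = (
        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    )
    # Write the optimized graph to a separate file.
    sess_options.optimized_model_filepath = "path/to/optimized_location"
    onnxruntime.InferenceSession("path/to/location", sess_options)
    ```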
    """
    if input_signature is None:
        input_signature = get_input_signature(model)
        if not input_signature or not model._called:
            raise ValueError(
                "The model provided has never been called. "
                "It must be called at least once before export."
            )

    if backend.backend() in ("tensorflow", "jax"):
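        # Two-step export: first write an intermediate TF SavedModel
        # into a temporary directory created next to `filepath`, then
        # convert that artifact to ONNX.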
        working_dir = pathlib.Path(filepath).parent
        with tempfile.TemporaryDirectory(dir=working_dir) as temp_dir:
            if backend.backend() == "jax":
                kwargs = _check_jax_kwargs(kwargs)
            export_saved_model(
                model,
                temp_dir,
                verbose,
                input_signature,
                **kwargs,
            )
            saved_model_to_onnx(temp_dir, filepath, model.name)

    elif backend.backend() == "torch":
        import torch

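        # Materialize concrete sample tensors from the input signature
        # so the TorchScript-based exporter can trace the model; any
        # `None` dimension (e.g. the batch size) is fixed to 1.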
        sample_inputs = tree.map_structure(
            lambda x: convert_spec_to_tensor(x, replace_none_number=1),
            input_signature,
        )
        sample_inputs = tuple(sample_inputs)
        # TODO: Make dict model exportable.
        if any(isinstance(x, dict) for x in sample_inputs):
            raise ValueError(
                "Currently, `export_onnx` in the torch backend doesn't "
                "support dictionaries as inputs."
            )

        # Convert to ONNX using the TorchScript-based ONNX exporter.
        # TODO: Use the TorchDynamo-based ONNX exporter once
        # `torch.onnx.dynamo_export()` supports Keras models.
        torch.onnx.export(model, sample_inputs, filepath, verbose=verbose)
    else:
        raise NotImplementedError(
            "`export_onnx` is only compatible with TensorFlow, JAX and "
            "Torch backends."
        )


def _check_jax_kwargs(kwargs):
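    """Default and validate `kwargs` for ONNX export on the JAX backend.

    The export goes through a TF SavedModel, which requires a static
    jax2tf conversion with `enable_xla=False` and
    `native_serialization=False` (see the TODO below).
    """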
    kwargs = kwargs.copy()
    if "is_static" not in kwargs:
        kwargs["is_static"] = True
    if "jax2tf_kwargs" not in kwargs:
        # TODO: These options will be deprecated in JAX. We need to
        # find another way to export ONNX.
        kwargs["jax2tf_kwargs"] = {
            "enable_xla": False,
            "native_serialization": False,
        }
    if kwargs["is_static"] is not True:
        raise ValueError(
            "`is_static` must be `True` in `kwargs` when using the jax "
            "backend."
        )
    if kwargs["jax2tf_kwargs"]["enable_xla"] is not False:
        raise ValueError(
            "`enable_xla` must be `False` in `kwargs['jax2tf_kwargs']` "
            "when using the jax backend."
        )
    if kwargs["jax2tf_kwargs"]["native_serialization"] is not False:
        raise ValueError(
            "`native_serialization` must be `False` in "
            "`kwargs['jax2tf_kwargs']` when using the jax backend."
        )
    return kwargs


def saved_model_to_onnx(saved_model_dir, filepath, name):
    from keras.src.utils.module_utils import tf2onnx

    # Convert to ONNX using the `tf2onnx` library.
    (graph_def, inputs, outputs, initialized_tables, tensors_to_rename) = (
        tf2onnx.tf_loader.from_saved_model(
            saved_model_dir,
            None,
            None,
            return_initialized_tables=True,
            return_tensors_to_rename=True,
        )
    )

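    # Pin the conversion to the CPU, keeping graph rewriting off any
    # accelerator device.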
    with tf.device("/cpu:0"):
        _ = tf2onnx.convert._convert_common(
            graph_def,
            name=name,
            target=[],
            custom_op_handlers={},
            extra_opset=[],
            input_names=inputs,
            output_names=outputs,
            tensors_to_rename=tensors_to_rename,
            initialized_tables=initialized_tables,
            output_path=filepath,
        )