Add Tensorflow handling of ONNX conversion #13831
Changes from 250 commits
@@ -21,7 +21,7 @@
 from packaging.version import Version, parse

 from transformers import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
-from transformers.file_utils import is_torch_onnx_dict_inputs_support_available
+from transformers.file_utils import is_tf_available, is_torch_onnx_dict_inputs_support_available
 from transformers.onnx.config import OnnxConfig
 from transformers.utils import logging
@@ -62,90 +62,190 @@ def check_onnxruntime_requirements(minimum_version: Version):
     )


-def export(
-    tokenizer: PreTrainedTokenizer, model: PreTrainedModel, config: OnnxConfig, opset: int, output: Path
+def export_pytorch(
+    tokenizer: PreTrainedTokenizer,
+    model: PreTrainedModel,
+    config: OnnxConfig,
+    opset: int,
+    output: Path,
 ) -> Tuple[List[str], List[str]]:

Member: Note: we decided to refactor the single […]

Member: As long as the method […]
| """ | ||
| Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR | ||
| Export a PyTorch model to an ONNX Intermediate Representation (IR) | ||
|
|
||
| Args: | ||
| tokenizer: | ||
| model: | ||
| config: | ||
| opset: | ||
| output: | ||
| tokenizer ([`PreTrainedTokenizer`]): | ||
| The tokenizer used for encoding the data. | ||
| model ([`PreTrainedModel`]): | ||
| The model to export. | ||
| config ([`~onnx.config.OnnxConfig`]): | ||
| The ONNX configuration associated with the exported model. | ||
| opset (`int`): | ||
| The version of the ONNX operator set to use. | ||
| output (`Path`): | ||
| Directory to store the exported ONNX model. | ||
|
|
||
| Returns: | ||
| `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from | ||
| the ONNX configuration. | ||
| """ | ||
+    if is_torch_available():
+        from transformers.file_utils import torch_version
+
+        if not is_torch_onnx_dict_inputs_support_available():
+            raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
+
+    if issubclass(type(model), PreTrainedModel):
+        import torch
+        from torch.onnx import export
+
+        logger.info(f"Using framework PyTorch: {torch.__version__}")
+        with torch.no_grad():
+            model.config.return_dict = True
+            model.eval()
+
+            # Check if we need to override certain configuration item
+            if config.values_override is not None:
+                logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
+                for override_config_key, override_config_value in config.values_override.items():
+                    logger.info(f"\t- {override_config_key} -> {override_config_value}")
+                    setattr(model.config, override_config_key, override_config_value)
+
+            # Ensure inputs match
+            # TODO: Check when exporting QA we provide "is_pair=True"
+            model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
+            inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
+            onnx_outputs = list(config.outputs.keys())
+
+            if not inputs_match:
+                raise ValueError("Model and config inputs doesn't match")
+
+            config.patch_ops()
+
+            # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
+            # so we check the torch version for backwards compatibility
+            if parse(torch.__version__) <= parse("1.10.99"):
+                # export can work with named args but the dict containing named args
+                # has to be the last element of the args tuple.
+                export(
+                    model,
+                    (model_inputs,),
+                    f=output.as_posix(),
+                    input_names=list(config.inputs.keys()),
+                    output_names=onnx_outputs,
+                    dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
+                    do_constant_folding=True,
+                    use_external_data_format=config.use_external_data_format(model.num_parameters()),
+                    enable_onnx_checker=True,
+                    opset_version=opset,
+                )
+            else:
+                export(
+                    model,
+                    (model_inputs,),
+                    f=output.as_posix(),
+                    input_names=list(config.inputs.keys()),
+                    output_names=onnx_outputs,
+                    dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
+                    do_constant_folding=True,
+                    opset_version=opset,
+                )
+
+            config.restore_ops()
+
+    return matched_inputs, onnx_outputs
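
For context, here is a minimal sketch of how the new export_pytorch function can be driven end to end. It is not part of the diff; the checkpoint name, BertOnnxConfig, and default_onnx_opset are assumptions taken from the library's usual ONNX workflow, not from this PR.

    # Illustrative sketch only: checkpoint and config class are assumptions.
    from pathlib import Path

    from transformers import AutoModel, AutoTokenizer
    from transformers.models.bert import BertOnnxConfig

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    model = AutoModel.from_pretrained("bert-base-cased")   # PyTorch model
    onnx_config = BertOnnxConfig(model.config)             # declares inputs, outputs, and dummy-input generation

    # export_pytorch as defined in the diff above; output is the path of the resulting .onnx file
    onnx_inputs, onnx_outputs = export_pytorch(
        tokenizer, model, onnx_config, opset=onnx_config.default_onnx_opset, output=Path("model.onnx")
    )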

+def export_tensorflow(
+    tokenizer: PreTrainedTokenizer,
+    model: TFPreTrainedModel,
+    config: OnnxConfig,
+    opset: int,
+    output: Path,
+) -> Tuple[List[str], List[str]]:
+    """
+    Export a TensorFlow model to an ONNX Intermediate Representation (IR)
+
+    Args:
+        tokenizer ([`PreTrainedTokenizer`]):
+            The tokenizer used for encoding the data.
+        model ([`TFPreTrainedModel`]):
+            The model to export.
+        config ([`~onnx.config.OnnxConfig`]):
+            The ONNX configuration associated with the exported model.
+        opset (`int`):
+            The version of the ONNX operator set to use.
+        output (`Path`):
+            Directory to store the exported ONNX model.
+
+    Returns:
+        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
+        the ONNX configuration.
+    """
-    if not is_torch_available():
-        raise ImportError("Cannot convert because PyTorch is not installed. Please install torch first.")
-
-    import torch
-    from torch.onnx import export
-
-    from ..file_utils import torch_version
-
-    if not is_torch_onnx_dict_inputs_support_available():
-        raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
-
-    logger.info(f"Using framework PyTorch: {torch.__version__}")
-    with torch.no_grad():
-        model.config.return_dict = True
-        model.eval()
-
-        # Check if we need to override certain configuration item
-        if config.values_override is not None:
-            logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
-            for override_config_key, override_config_value in config.values_override.items():
-                logger.info(f"\t- {override_config_key} -> {override_config_value}")
-                setattr(model.config, override_config_key, override_config_value)
-
-        # Ensure inputs match
-        # TODO: Check when exporting QA we provide "is_pair=True"
-        model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
-        inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
-        onnx_outputs = list(config.outputs.keys())
-
-        if not inputs_match:
-            raise ValueError("Model and config inputs doesn't match")
-
-        config.patch_ops()
-
-        # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
-        # so we check the torch version for backwards compatibility
-        if parse(torch.__version__) <= parse("1.10.99"):
-            # export can work with named args but the dict containing named args
-            # has to be the last element of the args tuple.
-            export(
-                model,
-                (model_inputs,),
-                f=output.as_posix(),
-                input_names=list(config.inputs.keys()),
-                output_names=onnx_outputs,
-                dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
-                do_constant_folding=True,
-                use_external_data_format=config.use_external_data_format(model.num_parameters()),
-                enable_onnx_checker=True,
-                opset_version=opset,
-            )
-        else:
-            export(
-                model,
-                (model_inputs,),
-                f=output.as_posix(),
-                input_names=list(config.inputs.keys()),
-                output_names=onnx_outputs,
-                dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
-                do_constant_folding=True,
-                opset_version=opset,
-            )
+    import tensorflow as tf
+
+    import onnx
+    import tf2onnx
+
+    model.config.return_dict = True
+
-        config.restore_ops()
+    # Check if we need to override certain configuration item
+    if config.values_override is not None:
+        logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
+        for override_config_key, override_config_value in config.values_override.items():
+            logger.info(f"\t- {override_config_key} -> {override_config_value}")
+            setattr(model.config, override_config_key, override_config_value)
+
+    # Ensure inputs match
+    model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)
+    inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
+    onnx_outputs = list(config.outputs.keys())
+
+    input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in model_inputs.items()]
+    onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=opset)
+    onnx.save(onnx_model, output.as_posix())
+    config.restore_ops()

Member: I think ops patching might need some changes to work for both backends, but I guess that can be figured out later.

     return matched_inputs, onnx_outputs
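
The patch_ops()/restore_ops() calls above temporarily swap selected operations for export-friendly versions and then put the originals back. The sketch below shows that pattern in isolation; it is a conceptual illustration, not the library's actual implementation, and the PatchingSpec fields shown are assumptions. Because such specs typically target framework-specific ops, patching may well need per-backend handling, which seems to be the concern raised in the comment above.

    # Conceptual sketch of op patching/restoring, not the transformers implementation.
    from dataclasses import dataclass
    from typing import Any, Callable


    @dataclass
    class PatchingSpec:
        o: Any               # object owning the attribute to swap (e.g. a module)
        name: str            # attribute name
        custom_op: Callable  # export-friendly replacement
        orig_op: Callable    # original implementation, restored after export


    def patch_ops(specs):
        # Install the replacements before tracing/exporting the model
        for spec in specs:
            setattr(spec.o, spec.name, spec.custom_op)


    def restore_ops(specs):
        # Put the original ops back once the export is done
        for spec in specs:
            setattr(spec.o, spec.name, spec.orig_op)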
+def export(
+    tokenizer: PreTrainedTokenizer,
+    model: Union[PreTrainedModel, TFPreTrainedModel],
+    config: OnnxConfig,
+    opset: int,
+    output: Path,
+) -> Tuple[List[str], List[str]]:
+    """
+    Export a Pytorch or TensorFlow model to an ONNX Intermediate Representation (IR)
+
+    Args:
+        tokenizer ([`PreTrainedTokenizer`]):
+            The tokenizer used for encoding the data.
+        model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
+            The model to export.
+        config ([`~onnx.config.OnnxConfig`]):
+            The ONNX configuration associated with the exported model.
+        opset (`int`):
+            The version of the ONNX operator set to use.
+        output (`Path`):
+            Directory to store the exported ONNX model.
+
+    Returns:
+        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
+        the ONNX configuration.
+    """
+    if not (is_torch_available() or is_tf_available()):
+        raise ImportError(
+            "Cannot convert because neither PyTorch nor TensorFlow are not installed. "
+            "Please install torch or tensorflow first."
+        )
+
+    if is_torch_available():
+        return export_pytorch(tokenizer, model, config, opset, output)
+    elif is_tf_available():
+        return export_tensorflow(tokenizer, model, config, opset, output)
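
Note that this dispatcher picks a backend from what is installed, not from the model's type: when both PyTorch and TensorFlow are available, the export_pytorch branch is always taken. Below is a minimal sketch of the TensorFlow path; it is not part of the diff, assumes a TensorFlow-only environment, and the checkpoint name and BertOnnxConfig are illustrative assumptions.

    # Illustrative sketch only: assumes TensorFlow is installed and PyTorch is not,
    # so export() falls through to export_tensorflow().
    from pathlib import Path

    from transformers import AutoTokenizer, TFAutoModel
    from transformers.models.bert import BertOnnxConfig
    from transformers.onnx import export

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    tf_model = TFAutoModel.from_pretrained("bert-base-cased")  # TFPreTrainedModel subclass
    onnx_config = BertOnnxConfig(tf_model.config)

    onnx_inputs, onnx_outputs = export(
        tokenizer, tf_model, onnx_config, opset=onnx_config.default_onnx_opset, output=Path("tf_model.onnx")
    )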


 def validate_model_outputs(
     config: OnnxConfig,
     tokenizer: PreTrainedTokenizer,
@@ -160,7 +260,10 @@ def validate_model_outputs(

     # TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
     # dynamic input shapes.
-    reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
+    if issubclass(type(reference_model), PreTrainedModel):
+        reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
+    else:
+        reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)

     # Create ONNX Runtime session
     options = SessionOptions()
@@ -210,7 +313,10 @@ def validate_model_outputs(

     # Check the shape and values match
     for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
-        ref_value = ref_outputs_dict[name].detach().numpy()
+        if issubclass(type(reference_model), PreTrainedModel):
+            ref_value = ref_outputs_dict[name].detach().numpy()
+        else:
+            ref_value = ref_outputs_dict[name].numpy()
         logger.info(f'\t- Validating ONNX Model output "{name}":')

         # Shape
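
The rest of this loop is truncated in the view above; it compares each ONNX Runtime output against the reference framework output by shape and by value within an absolute tolerance. The snippet below is a sketch of that kind of check, not the exact function body.

    # Sketch of the per-output comparison performed by validate_model_outputs (not the exact code).
    import numpy as np


    def check_output(name, ref_value, ort_value, atol=1e-5):
        # Shapes must match exactly
        if ort_value.shape != ref_value.shape:
            raise ValueError(f"Output {name}: shape mismatch {ort_value.shape} vs {ref_value.shape}")
        # Values must agree within the requested absolute tolerance
        if not np.allclose(ref_value, ort_value, atol=atol):
            max_diff = np.abs(ref_value - ort_value).max()
            raise ValueError(f"Output {name}: max absolute difference {max_diff} exceeds tolerance {atol}")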
@@ -241,7 +347,10 @@ def ensure_model_and_config_inputs_match(

     :param model_inputs: :param config_inputs: :return:
     """
-    forward_parameters = signature(model.forward).parameters
+    if issubclass(type(model), PreTrainedModel):
+        forward_parameters = signature(model.forward).parameters
+    else:
+        forward_parameters = signature(model.call).parameters
     model_inputs_set = set(model_inputs)

     # We are fine if config_inputs has more keys than model_inputs
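
The remainder of this helper is not shown in the hunk. Conceptually, it verifies that every input produced from the ONNX config exists on the model's forward (for PyTorch) or call (for TensorFlow/Keras) signature and returns the matched names in signature order, so positional arguments line up during export. A self-contained sketch of that logic, under those assumptions and not the exact function body:

    # Sketch of signature-based input matching (illustrative, not the library code).
    from inspect import signature
    from typing import Iterable, List, Tuple


    def match_inputs(forward_callable, config_inputs: Iterable[str]) -> Tuple[bool, List[str]]:
        forward_parameters = signature(forward_callable).parameters
        config_inputs_set = set(config_inputs)
        forward_inputs_set = set(forward_parameters.keys())
        # Every config-declared input must exist on the model's signature
        is_ok = config_inputs_set.issubset(forward_inputs_set)
        # Keep signature order for the matched names
        matching = forward_inputs_set.intersection(config_inputs_set)
        ordered = [name for name in forward_parameters if name in matching]
        return is_ok, ordered


    # e.g. match_inputs(model.forward, {"input_ids", "attention_mask"}) for a PyTorch model,
    # or match_inputs(model.call, ...) for a TensorFlow model.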