Add support ORT whisper #420
Merged: michaelbenayoun merged 15 commits into huggingface:main from mht-sharma:add-support-onnx-whisper on Nov 15, 2022.
Commits (15)
- 69b0be2 added support onnxruntime whisper (mht-sharma)
- b59ac10 Updated decoder export model (mht-sharma)
- 945f235 Updated docstring (mht-sharma)
- ccbafbc updated tests for whisper (mht-sharma)
- 88e6053 add whisper onnx configs (mht-sharma)
- 3bf3c99 Added Whisper model to exporters (mht-sharma)
- 8677559 Removed unused imports (mht-sharma)
- 06632aa Added tests for exporters and iobinding (mht-sharma)
- 2b84fff Removed redundant line (mht-sharma)
- 9520897 Updated input generator and config (mht-sharma)
- 4d6e313 Updated tests (mht-sharma)
- 065a93e added sample audio input (mht-sharma)
- 512c781 Removed redundant code to fix test (mht-sharma)
- 14358a0 Updated iobinding (mht-sharma)
- dbba1f9 Fix tests (mht-sharma)
```diff
@@ -120,6 +120,7 @@ class OnnxConfig(ExportConfig, ABC):
         "seq2seq-lm": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}),
         "sequence-classification": OrderedDict({"logits": {0: "batch_size"}}),
         "token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
+        "speech2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
     }

     def __init__(
@@ -206,6 +207,17 @@ def is_torch_support_available(self) -> bool:
             return TORCH_VERSION >= self.MIN_TORCH_VERSION
         return False

+    @property
+    def torch_to_onnx_input_map(self) -> Mapping[str, str]:
+        """
+        Dictionary of keys to update the ONNX input name for export. Override the function when
+        the dummy input names and the exported ONNX input names need to be different.
+
+        Returns:
+            `Mapping[str, str]`: A dictionary specifying the dummy input name to exported ONNX input name map.
+        """
+        return {}
+
     def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int, str]]:
         """
         Re-orders the inputs using the model forward pass signature.
@@ -218,6 +230,7 @@ def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int,
             `Mapping[str, Mappingp[int, str]]`: The properly ordered inputs.
         """
         inputs = self.inputs
+
         ordered_inputs = {}
         sig = inspect.signature(model.forward)
         for param in sig.parameters:
@@ -229,6 +242,7 @@ def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int,
             # TODO: figure out a smart way of re-ordering potential nested structures.
             # to_insert = sorted(to_insert, key=lambda t: t[0])
             for name, dynamic_axes in to_insert:
+                name = self.torch_to_onnx_input_map.get(name, name)
                 ordered_inputs[name] = dynamic_axes
         return ordered_inputs
```

Review comments on the new `torch_to_onnx_input_map` property:

Member: I would make it clear that it is needed when the dummy input names and the exported input names do not match.

Contributor (Author): Updated the docstring.
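The mechanics of the new hook can be sketched outside the library. The class names and the `input_features` -> `inputs` mapping below are illustrative assumptions, not the actual optimum code; the point is the `dict.get(name, name)` fallback used in `ordered_inputs`, which renames only the inputs a subclass explicitly maps and leaves all others untouched.

```python
from collections import OrderedDict
from typing import Mapping


class BaseConfigSketch:
    """Minimal stand-in for the base ONNX config."""

    @property
    def torch_to_onnx_input_map(self) -> Mapping[str, str]:
        # Default: no renaming between dummy (PyTorch) and ONNX input names.
        return {}

    def rename_inputs(self, inputs: "OrderedDict") -> "OrderedDict":
        # Mirrors the loop added to `ordered_inputs`: fall back to the
        # original name when the map has no entry for it.
        renamed = OrderedDict()
        for name, dynamic_axes in inputs.items():
            renamed[self.torch_to_onnx_input_map.get(name, name)] = dynamic_axes
        return renamed


class WhisperLikeConfigSketch(BaseConfigSketch):
    # Hypothetical override: expose the PyTorch `input_features` argument
    # under a different name in the exported ONNX graph.
    @property
    def torch_to_onnx_input_map(self) -> Mapping[str, str]:
        return {"input_features": "inputs"}


config = WhisperLikeConfigSketch()
dummy_inputs = OrderedDict({
    "input_features": {0: "batch", 1: "feature_size", 2: "sequence"},
    "decoder_input_ids": {0: "batch", 1: "decoder_sequence"},
})
print(list(config.rename_inputs(dummy_inputs).keys()))
# ['inputs', 'decoder_input_ids']
```

Because the base property returns an empty dict, existing configs are unaffected; only models whose dummy input names diverge from the exported ONNX names need to override it.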
Review discussion on the `speech2seq-lm` task name:

Comment: Is it the "official" name? We could take `automatic-speech-recognition` to match the pipelines, or `speech2text`. @lewtun wdyt?

Reply: I think the idea was to partially align with the underlying autoclass, but I agree `automatic-speech-recognition` would be more intuitive. In general (not for this PR), I think we should take the opportunity to align more closely with the Hub tasks, e.g. `seq2seq-lm` could also be `text2text-generation`, right?

Reply: Alright then, I guess we can keep `speech2seq-lm` for now since the other names are aligned to the AutoClass, and maybe change that (if needed) for all the tasks in another PR.