Add ONNX support for MarianMT models #14586
Conversation
…like BartModel.forward()
Using a sequence length of 1 in `generate_dummy_inputs()` produces large discrepancies, presumably due to some hidden optimisations.
There seems to be some sort of race condition happening. This issue has similar problems - perhaps a solution lies there.
docs/source/serialization.rst
Outdated
- GPT Neo
- LayoutLM
- Longformer
- Marian
I'm not sure whether .rst files are still allowed with the new .mdx doc - does this need updating / changing?
Letting @LysandreJik answer this one.
I saw that Sylvain recently converted all the RST files to MDX, so I'll rebase and this file should disappear :)
        self._setup_normalizer()

-    def num_special_tokens_to_add(self, **unused):
+    def num_special_tokens_to_add(self, *args, **kwargs):
This change is required to accommodate the use of positional arguments like tokenizer.num_special_tokens_to_add(is_pair) in _generate_dummy_inputs_for_sequence_classification_and_question_answering().
I'm not sure why we had **unused in the first place, but the change also seems more conventional IMO.
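As a minimal stand-alone illustration of the difference (these stand-ins are hypothetical, not the real tokenizer methods), `**unused` rejects positional arguments while `*args, **kwargs` absorbs them:

```python
# Stand-in for the old signature: keyword-only catch-all.
def num_special_tokens_to_add_old(**unused):
    return 2  # e.g. BOS + EOS


# Stand-in for the new signature: absorbs positional and keyword args.
def num_special_tokens_to_add_new(*args, **kwargs):
    return 2


# A positional call like tokenizer.num_special_tokens_to_add(is_pair):
try:
    num_special_tokens_to_add_old(True)
except TypeError as exc:
    print(f"old signature fails: {exc}")

print(num_special_tokens_to_add_new(True))  # prints 2: the positional arg is absorbed
```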
        ]
        return common_inputs

    def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
Technically, Marian doesn't have heads for sequence classification or question answering and this function is here due to the copy-paste from the BART config.
If you think this is confusing, I can remove this function and refactor the other dummy generation functions accordingly.
I think it can be done, you'll just have to remove the # Copied from comment at the top of the class declaration.
You could remove the # Copied from which is at the top of the class declaration and add it only to methods. It supports methods as well as classes.
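As a hypothetical sketch of that suggestion (the method names and bodies here are illustrative, not the actual file contents), the marker moves from the class to the individual copied methods:

```python
class MarianOnnxConfig:
    # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.inputs with Bart->Marian
    @property
    def inputs(self):
        return {}

    # Marian-specific helper: no "Copied from" marker, so the consistency
    # check script won't compare it against the BART original.
    def _generate_dummy_inputs_for_encoder_and_decoder(self):
        return {}
```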
        )


# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig with Bart->Marian
Since the Marian model is copied from BART (see modeling_marian.py), I adopted a similar approach for the ONNX config.
Yes, nice!
LysandreJik
left a comment
Looks good, thank you @lewtun!
Feel free to merge once you have taken care of the docs and the
        ]
        return common_inputs

    def _generate_dummy_inputs_for_encoder_and_decoder(
I renamed this function from _generate_dummy_inputs_for_sequence_classification_and_question_answering() to something that more closely reflects its usage in the other dummy input functions.
As noted earlier, Marian models don't have sequence classification or question answering heads, so this change is aimed at minimizing confusion for those inspecting the source code.
Thanks for the reviews @LysandreJik and @michaelbenayoun 🙏 ! I've fixed the docs by rebasing on `master`. Will merge once all the tests pass :)
Decoding the outputs doesn't give the correct result. How do you get the translation result?
What does this PR do?
This PR adds support to export MarianMT models in the ONNX format. The underlying logic builds on the awesome refactor / feature enhancement that @michaelbenayoun has implemented in #14358 & #14700 - we should rebase this branch on `master` once that PR is merged to simplify the diff in this PR (done).

Currently, this PR supports ONNX exports for the following "tasks" (i.e. uses):

- `default`, `default-with-past` => equivalent to exporting a pretrained `MarianModel`
- `seq2seq-lm`, `seq2seq-lm-with-past` => equivalent to exporting a pretrained `MarianMTModel`
- `causal-lm`, `causal-lm-with-past` => equivalent to exporting a pretrained `MarianForCausalLM`

Note that in each case, the end user will have to implement their own `generate()` method with the ONNX model - see this BART example for what's involved.

I've also checked locally that the "slow" tests pass.
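Since the exported graph only computes logits, a user-side `generate()` boils down to a decoding loop. The sketch below shows the greedy variant with a toy stand-in for the ONNX session (a real implementation would call `onnxruntime.InferenceSession.run` with the encoder/decoder inputs); everything here is illustrative, not the PR's code:

```python
EOS_TOKEN_ID = 0
MAX_LENGTH = 10
VOCAB_SIZE = 8


def fake_session_run(decoder_input_ids):
    """Toy stand-in for an ONNX seq2seq session: returns next-token logits.

    It favours (last_token + 1) and emits EOS once four tokens have been
    generated, so the loop below terminates deterministically.
    """
    logits = [0.0] * VOCAB_SIZE
    if len(decoder_input_ids) >= 4:
        logits[EOS_TOKEN_ID] = 10.0
    else:
        logits[(decoder_input_ids[-1] + 1) % VOCAB_SIZE] = 10.0
    return logits


def greedy_generate(decoder_start_token_id=1):
    """Greedy decoding: append the argmax token until EOS or MAX_LENGTH."""
    tokens = [decoder_start_token_id]
    for _ in range(MAX_LENGTH):
        logits = fake_session_run(tokens)
        next_token = max(range(VOCAB_SIZE), key=lambda i: logits[i])
        tokens.append(next_token)
        if next_token == EOS_TOKEN_ID:
            break
    return tokens


print(greedy_generate())  # [1, 2, 3, 4, 0]
```

The BART example linked above does the same thing against a real session, plus the attention masks and past key values needed for the `-with-past` variants.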
Usage
Here's a quick example to show how this works:
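The snippet itself was lost from this capture; as a rough sketch, exporting a Marian checkpoint with the `transformers.onnx` CLI would look something like the following (the checkpoint name is just an example):

```shell
# Export a pretrained MarianMT checkpoint (example: Helsinki-NLP/opus-mt-en-de)
# with the seq2seq-lm feature; the ONNX graph is written to onnx/model.onnx.
python -m transformers.onnx --model=Helsinki-NLP/opus-mt-en-de --feature=seq2seq-lm onnx/
```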
TODO
Closes #13823, #13854
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.