[WIP] Add MarianMT to models exportable with ONNX #13854
Conversation
michaelbenayoun left a comment
Thanks for your PR!
```python
_SUPPORTED_MODEL_KIND = {
    "albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
    "bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
    "marian": supported_features_mapping("seq2seq-lm", onnx_config_cls=MarianOnnxConfig),
```
You can also add the "default" task.
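For illustration, registering both features for Marian might look like the fragment below. This is a sketch against the feature-mapping shape shown in the diff above; `supported_features_mapping` accepting multiple feature names this way is an assumption based on the surrounding entries, and the exact signature may have changed since this PR.

```python
# Hypothetical edit: register both the "default" and "seq2seq-lm" features.
"marian": supported_features_mapping(
    "default", "seq2seq-lm", onnx_config_cls=MarianOnnxConfig
),
```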
```diff
     self._setup_normalizer()

-    def num_special_tokens_to_add(self, **unused):
+    def num_special_tokens_to_add(self, *args, **kwargs):
```
What is the reason for this?
`num_special_tokens_to_add` is called here with a positional argument, which causes an error if the function is defined to accept only keyword arguments via `**unused`.
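A minimal standalone illustration of the failure mode (hypothetical function names, not the transformers code itself): a function declared with only `**kwargs` rejects any positional argument, while `*args, **kwargs` accepts both calling styles.

```python
def strict(**unused):
    # Accepts keyword arguments only; a positional call raises TypeError.
    return 0

def lenient(*args, **kwargs):
    # Accepts both positional and keyword arguments.
    return 0

lenient(True)       # fine
lenient(pair=True)  # fine

try:
    strict(True)    # positional call
except TypeError:
    print("strict() rejects positional arguments")
```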
```python
    >>> last_hidden_states = outputs.last_hidden_state
    """
    # different to other models, Marian automatically creates decoder_input_ids from
```
@LysandreJik @patil-suraj What do you think?
Just saw this comment, will investigate and come back to you.
Thank you; anyway, I just copied what BART does here.
This shouldn't be added here. This is done for BART because of the denoising pre-training objective, so only BART and mBART prepare the `decoder_input_ids` from `input_ids`. For Marian, users should pass either `decoder_input_ids` or `labels`, and this is already handled by the `MarianMTModel` class.
> This shouldn't be added here. This is done for BART because of the denoising pre-training objective, so only BART and mBART prepare the `decoder_input_ids` from `input_ids`. For Marian, users should pass either `decoder_input_ids` or `labels`, and this is already handled by the `MarianMTModel` class.
Thank you @patil-suraj for the reply, but as I said in this comment, `transformers.onnx.convert.export()` calls `MarianMTModel.forward()` without passing `decoder_input_ids` or `labels`, so how is this supposed to be handled?
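For context, the BART-style behavior under discussion derives `decoder_input_ids` by shifting `input_ids` one position to the right and prepending the decoder start token. Below is a simplified pure-Python sketch of that shifting; the real transformers implementation operates on torch tensors and also remaps `-100` label values to the pad token.

```python
def shift_tokens_right(input_ids, decoder_start_token_id):
    """Prepend the decoder start token and drop the last token of each row.

    Simplified sketch of the shifting BART/mBART use to build
    decoder_input_ids from input_ids; Marian instead expects the caller
    to pass decoder_input_ids or labels explicitly.
    """
    return [[decoder_start_token_id] + row[:-1] for row in input_ids]

# Example: a batch of one sequence, with decoder start token id 2.
shift_tokens_right([[5, 6, 7, 1]], 2)  # -> [[2, 5, 6, 7]]
```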
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Resolves #13823
@patil-suraj @michaelbenayoun