4 changes: 2 additions & 2 deletions src/transformers/models/marian/__init__.py
@@ -28,7 +28,7 @@


_import_structure = {
"configuration_marian": ["MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "MarianConfig"],
"configuration_marian": ["MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP", "MarianConfig", "MarianOnnxConfig"],
}

if is_sentencepiece_available():
@@ -49,7 +49,7 @@
if is_flax_available():
_import_structure["modeling_flax_marian"] = ["FlaxMarianModel", "FlaxMarianMTModel", "FlaxMarianPreTrainedModel"]
if TYPE_CHECKING:
-    from .configuration_marian import MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP, MarianConfig
+    from .configuration_marian import MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP, MarianConfig, MarianOnnxConfig

if is_sentencepiece_available():
from .tokenization_marian import MarianTokenizer
15 changes: 15 additions & 0 deletions src/transformers/models/marian/configuration_marian.py
@@ -13,6 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Marian model configuration """
+from collections import OrderedDict
+from typing import Mapping
+
+from transformers.onnx import OnnxConfig

from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -159,3 +163,14 @@ def __init__(
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)


+class MarianOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+            ]
+        )
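
To see what this config exposes, a minimal sketch (assuming the from_model_config constructor helper on OnnxConfig, and using Helsinki-NLP/opus-mt-en-de purely as an illustrative checkpoint):

from transformers import MarianConfig
from transformers.models.marian import MarianOnnxConfig

config = MarianConfig.from_pretrained("Helsinki-NLP/opus-mt-en-de")
onnx_config = MarianOnnxConfig.from_model_config(config)

# The dynamic-axes mapping the ONNX exporter consumes:
print(onnx_config.inputs)
# OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}),
#              ('attention_mask', {0: 'batch', 1: 'sequence'})])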
7 changes: 7 additions & 0 deletions src/transformers/models/marian/modeling_marian.py
@@ -1132,6 +1132,13 @@ def forward(

>>> last_hidden_states = outputs.last_hidden_state
"""
+        # different to other models, Marian automatically creates decoder_input_ids from
+        # input_ids if no decoder_input_ids are provided
+        if decoder_input_ids is None and decoder_inputs_embeds is None:
+            decoder_input_ids = shift_tokens_right(
+                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
+            )
+
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

@michaelbenayoun (Member) commented on Oct 26, 2021:

@LysandreJik @patil-suraj What do you think?

Just saw this comment, will investigate and come back to you.

@Maxinho96 (Contributor, Author) commented on Oct 26, 2021:

Thank you; anyway, I just copied what BART does here.

@patil-suraj (Contributor) commented on Nov 1, 2021:

This shouldn't be added here. This is done for BART because of the denoising pre-training objective, so only BART and mBART prepare the decoder_input_ids from input_ids. For Marian, users should pass either decoder_input_ids or labels, and this is already handled by the MarianMTModel class.

@Maxinho96 (Contributor, Author) replied:

> This shouldn't be added here. This is done for BART because of the denoising pre-training objective, so only BART and mBART prepare the decoder_input_ids from input_ids. For Marian, users should pass either decoder_input_ids or labels, and this is already handled by the MarianMTModel class.

Thank you @patil-suraj for the reply, but as I said in this comment, transformers.onnx.convert.export() calls MarianMTModel.forward() without passing decoder_input_ids or labels, so how is this supposed to be handled?
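
For context on the block under discussion: a minimal sketch of what shift_tokens_right does, mirroring the helper already defined in modeling_marian.py (an illustrative restatement, not code added by this PR):

import torch


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    # Shift the target ids one position to the right and prepend the decoder
    # start token, so the decoder predicts token t from tokens < t.
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id
    # Labels use -100 for ignored positions; map those back to the pad token.
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
    return shifted_input_ids

Under patil-suraj's suggestion, a caller (such as the ONNX export path) would build decoder_input_ids this way, or pass labels, rather than relying on MarianModel.forward() to derive them from input_ids.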
2 changes: 1 addition & 1 deletion src/transformers/models/marian/tokenization_marian.py
@@ -310,7 +310,7 @@ def __setstate__(self, d: Dict) -> None:
        self.current_spm = self.spm_source
        self._setup_normalizer()

-    def num_special_tokens_to_add(self, **unused):
+    def num_special_tokens_to_add(self, *args, **kwargs):
        """Just EOS"""
        return 1

A reviewer (Member) commented:

What is the reason for this?

@Maxinho96 (Contributor, Author) replied:

num_special_tokens_to_add is called here with a positional argument, which causes an error if the function is defined to accept only keyword arguments (**unused).
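
The failure mode, as a hypothetical minimal repro (not code from the PR):

def takes_only_kwargs(**unused):
    return 1


def takes_args_and_kwargs(*args, **kwargs):
    return 1


takes_args_and_kwargs(True)  # fine

try:
    takes_only_kwargs(True)
except TypeError as e:
    print(e)  # takes_only_kwargs() takes 0 positional arguments but 1 was given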
2 changes: 2 additions & 0 deletions src/transformers/onnx/features.py
@@ -10,6 +10,7 @@
from ..models.gpt_neo import GPTNeoOnnxConfig
from ..models.layoutlm import LayoutLMOnnxConfig
from ..models.longformer import LongformerOnnxConfig
+from ..models.marian import MarianOnnxConfig
from ..models.mbart import MBartOnnxConfig
from ..models.roberta import RobertaOnnxConfig
from ..models.t5 import T5OnnxConfig
@@ -60,6 +61,7 @@ class FeaturesManager:
    _SUPPORTED_MODEL_KIND = {
        "albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
        "bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
+        "marian": supported_features_mapping("seq2seq-lm", onnx_config_cls=MarianOnnxConfig),
        "mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig),
        "bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
        "distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),

A reviewer (Member) commented on the added "marian" entry:

You can also add the "default" task.
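
A sketch of that suggestion, assuming supported_features_mapping accepts several feature names as positional arguments (the surrounding entries each pass one):

"marian": supported_features_mapping("default", "seq2seq-lm", onnx_config_cls=MarianOnnxConfig),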