diff --git a/src/transformers/models/albert/configuration_albert.py b/src/transformers/models/albert/configuration_albert.py index 3a227844dbe6..2bf3171d0df8 100644 --- a/src/transformers/models/albert/configuration_albert.py +++ b/src/transformers/models/albert/configuration_albert.py @@ -167,3 +167,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: ("token_type_ids", {0: "batch", 1: "sequence"}), ] ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})]) diff --git a/src/transformers/models/bart/configuration_bart.py b/src/transformers/models/bart/configuration_bart.py index 5190cbf9dee7..86ca38a61d77 100644 --- a/src/transformers/models/bart/configuration_bart.py +++ b/src/transformers/models/bart/configuration_bart.py @@ -15,12 +15,10 @@ """ BART model configuration """ import warnings from collections import OrderedDict -from typing import Any, Mapping, Optional +from typing import Mapping -from ... import PreTrainedTokenizer from ...configuration_utils import PretrainedConfig -from ...file_utils import TensorType, is_torch_available -from ...onnx import OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast +from ...onnx import OnnxConfigWithPast from ...utils import logging @@ -182,174 +180,30 @@ def __init__( ) -class BartOnnxConfig(OnnxSeq2SeqConfigWithPast): +class BartOnnxConfig(OnnxConfigWithPast): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: - if self.task in ["default", "seq2seq-lm"]: - common_inputs = OrderedDict( - [ - ("input_ids", {0: "batch", 1: "encoder_sequence"}), - ("attention_mask", {0: "batch", 1: "encoder_sequence"}), - ] - ) - - if self.use_past: - common_inputs["decoder_input_ids"] = {0: "batch"} - common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} - else: - common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} - common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) - if self.use_past: - self.fill_with_past_key_values_(common_inputs, direction="inputs") - elif self.task == "causal-lm": - # TODO: figure this case out. 
- common_inputs = OrderedDict( + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + if self.use_past: + return OrderedDict( [ - ("input_ids", {0: "batch", 1: "encoder_sequence"}), - ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ("last_hidden_state", {0: "batch", 1: "sequence"}), + ("past_keys", {0: "batch", 2: "sequence"}), + ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), ] ) - if self.use_past: - num_encoder_layers, _ = self.num_layers - for i in range(num_encoder_layers): - common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} - common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} else: - common_inputs = OrderedDict( + return OrderedDict( [ - ("input_ids", {0: "batch", 1: "encoder_sequence"}), - ("attention_mask", {0: "batch", 1: "encoder_sequence"}), - ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}), - ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}), + ("last_hidden_state", {0: "batch", 1: "sequence"}), + ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), ] ) - - return common_inputs - - @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - if self.task in ["default", "seq2seq-lm"]: - common_outputs = super().outputs - else: - common_outputs = super(OnnxConfigWithPast, self).outputs - if self.use_past: - num_encoder_layers, _ = self.num_layers - for i in range(num_encoder_layers): - common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} - common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} - return common_outputs - - def generate_dummy_inputs( - self, - tokenizer: PreTrainedTokenizer, - batch_size: int = -1, - seq_length: int = -1, - is_pair: bool = False, - framework: Optional[TensorType] = None, - ) -> Mapping[str, Any]: - - if self.task in ["default", "seq2seq-lm"]: - encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size, seq_length, is_pair, framework - ) - - # Generate decoder inputs - decoder_seq_length = seq_length if not self.use_past else 1 - decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size, decoder_seq_length, is_pair, framework - ) - decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} - common_inputs = dict(**encoder_inputs, **decoder_inputs) - - if self.use_past: - if not is_torch_available(): - raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") - else: - import torch - batch, encoder_seq_length = common_inputs["input_ids"].shape - decoder_seq_length = common_inputs["decoder_input_ids"].shape[1] - num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads - encoder_shape = ( - batch, - num_encoder_attention_heads, - encoder_seq_length, - self._config.hidden_size // num_encoder_attention_heads, - ) - decoder_past_length = decoder_seq_length + 3 - decoder_shape = ( - batch, - num_decoder_attention_heads, - decoder_past_length, - self._config.hidden_size // num_decoder_attention_heads, - ) - - common_inputs["decoder_attention_mask"] = torch.cat( - [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1 - ) - - common_inputs["past_key_values"] = [] - # If the number of encoder and decoder layers are present in the model configuration, both are considered - num_encoder_layers, num_decoder_layers = self.num_layers - min_num_layers = 
min(num_encoder_layers, num_decoder_layers) - max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers - remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder" - - for _ in range(min_num_layers): - common_inputs["past_key_values"].append( - ( - torch.zeros(decoder_shape), - torch.zeros(decoder_shape), - torch.zeros(encoder_shape), - torch.zeros(encoder_shape), - ) - ) - - # TODO: test this. - shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape - for _ in range(min_num_layers, max_num_layers): - common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape))) - - elif self.task == "causal-lm": - common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size, seq_length, is_pair, framework - ) - - if self.use_past: - if not is_torch_available(): - raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") - else: - import torch - - batch, seqlen = common_inputs["input_ids"].shape - # Not using the same length for past_key_values - past_key_values_length = seqlen + 2 - num_encoder_layers, _ = self.num_layers - num_encoder_attention_heads, _ = self.num_attention_heads - past_shape = ( - batch, - num_encoder_attention_heads, - past_key_values_length, - self._config.hidden_size // num_encoder_attention_heads, - ) - - common_inputs["attention_mask"] = torch.cat( - [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1 - ) - common_inputs["past_key_values"] = [ - (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers) - ] - else: - common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size, seq_length, is_pair, framework - ) - - return common_inputs - - def _flatten_past_key_values_(self, flattened_output, name, idx, t): - if self.task in ["default", "seq2seq-lm"]: - flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t) - else: - flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_( - flattened_output, name, idx, t - ) diff --git a/src/transformers/models/bert/configuration_bert.py b/src/transformers/models/bert/configuration_bert.py index 8b5d5b5e262d..861cdfbc8ea6 100644 --- a/src/transformers/models/bert/configuration_bert.py +++ b/src/transformers/models/bert/configuration_bert.py @@ -169,3 +169,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: ("token_type_ids", {0: "batch", 1: "sequence"}), ] ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})]) diff --git a/src/transformers/models/distilbert/configuration_distilbert.py b/src/transformers/models/distilbert/configuration_distilbert.py index bbf7be789d30..733714e721c9 100644 --- a/src/transformers/models/distilbert/configuration_distilbert.py +++ b/src/transformers/models/distilbert/configuration_distilbert.py @@ -142,3 +142,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: ("attention_mask", {0: "batch", 1: "sequence"}), ] ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})]) diff --git a/src/transformers/models/gpt2/configuration_gpt2.py b/src/transformers/models/gpt2/configuration_gpt2.py index c19593c81aff..be4f8df0a8fb 100644 --- a/src/transformers/models/gpt2/configuration_gpt2.py +++ 
b/src/transformers/models/gpt2/configuration_gpt2.py @@ -15,12 +15,12 @@ # limitations under the License. """ OpenAI GPT-2 configuration """ from collections import OrderedDict -from typing import Any, List, Mapping, Optional +from typing import Any, Mapping, Optional from transformers import PreTrainedTokenizer, TensorType, is_torch_available from ...configuration_utils import PretrainedConfig -from ...onnx import OnnxConfigWithPast, PatchingSpec +from ...onnx import OnnxConfigWithPast from ...utils import logging @@ -194,36 +194,29 @@ def __init__( class GPT2OnnxConfig(OnnxConfigWithPast): - def __init__( - self, - config: PretrainedConfig, - task: str = "default", - patching_specs: List[PatchingSpec] = None, - use_past: bool = False, - ): - super().__init__(config, task=task, patching_specs=patching_specs) - if not getattr(self._config, "pad_token_id", None): - # TODO: how to do that better? - self._config.pad_token_id = 0 - @property def inputs(self) -> Mapping[str, Mapping[int, str]]: - common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) + common_inputs = OrderedDict({"input_ids": {0: "batch"}}) if self.use_past: - self.fill_with_past_key_values_(common_inputs, direction="inputs") - common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} + for i in range(self._config.n_layer * 2): + common_inputs[f"past_key_values.{i}"] = {0: "batch", 2: "sequence"} + + common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} else: common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} return common_inputs @property - def num_layers(self) -> int: - return self._config.n_layer + def outputs(self) -> Mapping[str, Mapping[int, str]]: + common_outputs = OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}) + if self.use_past: + for i in range(self._config.n_layer * 2): + common_outputs[f"present.{i}"] = {0: "batch", 2: "sequence"} - @property - def num_attention_heads(self) -> int: - return self._config.n_head + return common_outputs + + return common_outputs def generate_dummy_inputs( self, @@ -233,9 +226,7 @@ def generate_dummy_inputs( is_pair: bool = False, framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: - common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size, seq_length, is_pair, framework - ) + common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework) # We need to order the input in the way they appears in the forward() ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) @@ -247,27 +238,14 @@ def generate_dummy_inputs( else: import torch - batch, seqlen = common_inputs["input_ids"].shape - # Not using the same length for past_key_values - past_key_values_length = seqlen + 2 - past_shape = ( - batch, - self.num_attention_heads, - past_key_values_length, - self._config.hidden_size // self.num_attention_heads, - ) + batch = common_inputs["input_ids"].shape[0] ordered_inputs["past_key_values"] = [ - (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) + ( + torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)), + torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)), + ) + for _ in range(self._config.n_layer) ] ordered_inputs["attention_mask"] = common_inputs["attention_mask"] - if self.use_past: - ordered_inputs["attention_mask"] = torch.cat( - [ordered_inputs["attention_mask"], torch.ones(batch, 
past_key_values_length)], dim=1 - ) - return ordered_inputs - - @property - def default_onnx_opset(self) -> int: - return 13 diff --git a/src/transformers/models/gpt_neo/configuration_gpt_neo.py b/src/transformers/models/gpt_neo/configuration_gpt_neo.py index 45118dbf8958..959d0bc7def7 100644 --- a/src/transformers/models/gpt_neo/configuration_gpt_neo.py +++ b/src/transformers/models/gpt_neo/configuration_gpt_neo.py @@ -15,7 +15,7 @@ """ GPT Neo model configuration """ from collections import OrderedDict -from typing import Any, Mapping, Optional +from typing import Any, Dict, Iterable, Mapping, Optional from ... import PreTrainedTokenizer, TensorType, is_torch_available from ...configuration_utils import PretrainedConfig @@ -212,7 +212,10 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast): def inputs(self) -> Mapping[str, Mapping[int, str]]: common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) if self.use_past: - self.fill_with_past_key_values_(common_inputs, direction="inputs") + for i in range(self._config.num_layers): + common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence"} + common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence"} + common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} else: common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} @@ -220,8 +223,16 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: return common_inputs @property - def num_attention_heads(self) -> int: - return self._config.num_heads + def outputs(self) -> Mapping[str, Mapping[int, str]]: + common_outputs = super().outputs + if self.use_past: + for i in range(self._config.num_layers): + common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} + common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + + return common_outputs + + return common_outputs def generate_dummy_inputs( self, @@ -231,10 +242,7 @@ def generate_dummy_inputs( is_pair: bool = False, framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: - - common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size, seq_length, is_pair, framework - ) + common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework) # We need to order the input in the way they appears in the forward() ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) @@ -246,27 +254,28 @@ def generate_dummy_inputs( else: import torch - batch, seqlen = common_inputs["input_ids"].shape - # Not using the same length for past_key_values - past_key_values_length = seqlen + 2 - past_shape = ( - batch, - self.num_attention_heads, - past_key_values_length, - self._config.hidden_size // self.num_attention_heads, - ) + batch = common_inputs["input_ids"].shape[0] + past_shape = (batch, self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads) ordered_inputs["past_key_values"] = [ - (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self._config.num_layers) ] ordered_inputs["attention_mask"] = common_inputs["attention_mask"] if self.use_past: ordered_inputs["attention_mask"] = torch.cat( - [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1 + [ordered_inputs["attention_mask"], torch.ones(batch, 1)], dim=1 ) return ordered_inputs - @property - def default_onnx_opset(self) -> int: - return 13 + 
@staticmethod + def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]: + if name in ["present", "past_key_values"]: + flatten_output = {} + for idx, t in enumerate(field): + flatten_output[f"{name}.{idx}.key"] = t[0] + flatten_output[f"{name}.{idx}.value"] = t[1] + + return flatten_output + + return super().flatten_output_collection_property(name, field) diff --git a/src/transformers/models/mbart/configuration_mbart.py b/src/transformers/models/mbart/configuration_mbart.py index 39083a093c38..d1eb27c0e808 100644 --- a/src/transformers/models/mbart/configuration_mbart.py +++ b/src/transformers/models/mbart/configuration_mbart.py @@ -14,12 +14,11 @@ # limitations under the License. """ MBART model configuration """ from collections import OrderedDict -from typing import Any, Mapping, Optional +from typing import Mapping + +from transformers.onnx import OnnxConfigWithPast -from ... import PreTrainedTokenizer from ...configuration_utils import PretrainedConfig -from ...file_utils import TensorType, is_torch_available -from ...onnx import OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast from ...utils import logging @@ -166,175 +165,30 @@ def __init__( ) -# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig with Bart->MBart -class MBartOnnxConfig(OnnxSeq2SeqConfigWithPast): +class MBartOnnxConfig(OnnxConfigWithPast): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: - if self.task in ["default", "seq2seq-lm"]: - common_inputs = OrderedDict( - [ - ("input_ids", {0: "batch", 1: "encoder_sequence"}), - ("attention_mask", {0: "batch", 1: "encoder_sequence"}), - ] - ) - - if self.use_past: - common_inputs["decoder_input_ids"] = {0: "batch"} - common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} - else: - common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} - common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) - if self.use_past: - self.fill_with_past_key_values_(common_inputs, direction="inputs") - elif self.task == "causal-lm": - # TODO: figure this case out. 
- common_inputs = OrderedDict( + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + if self.use_past: + return OrderedDict( [ - ("input_ids", {0: "batch", 1: "encoder_sequence"}), - ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ("last_hidden_state", {0: "batch", 1: "sequence"}), + ("past_keys", {0: "batch", 2: "sequence"}), + ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), ] ) - if self.use_past: - num_encoder_layers, _ = self.num_layers - for i in range(num_encoder_layers): - common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} - common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} else: - common_inputs = OrderedDict( + return OrderedDict( [ - ("input_ids", {0: "batch", 1: "encoder_sequence"}), - ("attention_mask", {0: "batch", 1: "encoder_sequence"}), - ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}), - ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}), + ("last_hidden_state", {0: "batch", 1: "sequence"}), + ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), ] ) - - return common_inputs - - @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - if self.task in ["default", "seq2seq-lm"]: - common_outputs = super().outputs - else: - common_outputs = super(OnnxConfigWithPast, self).outputs - if self.use_past: - num_encoder_layers, _ = self.num_layers - for i in range(num_encoder_layers): - common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} - common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} - return common_outputs - - def generate_dummy_inputs( - self, - tokenizer: PreTrainedTokenizer, - batch_size: int = -1, - seq_length: int = -1, - is_pair: bool = False, - framework: Optional[TensorType] = None, - ) -> Mapping[str, Any]: - - if self.task in ["default", "seq2seq-lm"]: - encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size, seq_length, is_pair, framework - ) - - # Generate decoder inputs - decoder_seq_length = seq_length if not self.use_past else 1 - decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size, decoder_seq_length, is_pair, framework - ) - decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} - common_inputs = dict(**encoder_inputs, **decoder_inputs) - - if self.use_past: - if not is_torch_available(): - raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") - else: - import torch - batch, encoder_seq_length = common_inputs["input_ids"].shape - decoder_seq_length = common_inputs["decoder_input_ids"].shape[1] - num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads - encoder_shape = ( - batch, - num_encoder_attention_heads, - encoder_seq_length, - self._config.hidden_size // num_encoder_attention_heads, - ) - decoder_past_length = decoder_seq_length + 3 - decoder_shape = ( - batch, - num_decoder_attention_heads, - decoder_past_length, - self._config.hidden_size // num_decoder_attention_heads, - ) - - common_inputs["decoder_attention_mask"] = torch.cat( - [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1 - ) - - common_inputs["past_key_values"] = [] - # If the number of encoder and decoder layers are present in the model configuration, both are considered - num_encoder_layers, num_decoder_layers = self.num_layers - min_num_layers = 
min(num_encoder_layers, num_decoder_layers) - max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers - remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder" - - for _ in range(min_num_layers): - common_inputs["past_key_values"].append( - ( - torch.zeros(decoder_shape), - torch.zeros(decoder_shape), - torch.zeros(encoder_shape), - torch.zeros(encoder_shape), - ) - ) - - # TODO: test this. - shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape - for _ in range(min_num_layers, max_num_layers): - common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape))) - - elif self.task == "causal-lm": - common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size, seq_length, is_pair, framework - ) - - if self.use_past: - if not is_torch_available(): - raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") - else: - import torch - - batch, seqlen = common_inputs["input_ids"].shape - # Not using the same length for past_key_values - past_key_values_length = seqlen + 2 - num_encoder_layers, _ = self.num_layers - num_encoder_attention_heads, _ = self.num_attention_heads - past_shape = ( - batch, - num_encoder_attention_heads, - past_key_values_length, - self._config.hidden_size // num_encoder_attention_heads, - ) - - common_inputs["attention_mask"] = torch.cat( - [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1 - ) - common_inputs["past_key_values"] = [ - (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers) - ] - else: - common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size, seq_length, is_pair, framework - ) - - return common_inputs - - def _flatten_past_key_values_(self, flattened_output, name, idx, t): - if self.task in ["default", "seq2seq-lm"]: - flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t) - else: - flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_( - flattened_output, name, idx, t - ) diff --git a/src/transformers/models/roberta/configuration_roberta.py b/src/transformers/models/roberta/configuration_roberta.py index 1d4540deef31..25fc855bd427 100644 --- a/src/transformers/models/roberta/configuration_roberta.py +++ b/src/transformers/models/roberta/configuration_roberta.py @@ -76,3 +76,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: ("attention_mask", {0: "batch", 1: "sequence"}), ] ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})]) diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py index 36b1344bdf2d..bb16a5fb0f50 100644 --- a/src/transformers/models/t5/configuration_t5.py +++ b/src/transformers/models/t5/configuration_t5.py @@ -13,11 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """ T5 model configuration """ -from typing import Mapping +from collections import OrderedDict +from typing import Any, Dict, Iterable, Mapping, Optional -# from ... import is_torch_available +from transformers import PreTrainedTokenizer, TensorType + +from ... 
import is_torch_available from ...configuration_utils import PretrainedConfig -from ...onnx import OnnxSeq2SeqConfigWithPast +from ...onnx import OnnxConfigWithPast from ...utils import logging @@ -122,26 +125,101 @@ def __init__( ) -class T5OnnxConfig(OnnxSeq2SeqConfigWithPast): +class T5OnnxConfig(OnnxConfigWithPast): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: - common_inputs = { - "input_ids": {0: "batch", 1: "encoder_sequence"}, - "attention_mask": {0: "batch", 1: "encoder_sequence"}, - } - if self.use_past: - common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" - common_inputs["decoder_input_ids"] = {0: "batch"} - common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} - else: - common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} - common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ("decoder_input_ids", {0: "batch"}), + ("decoder_attention_mask", {0: "batch"}), + ] + ) if self.use_past: - self.fill_with_past_key_values_(common_inputs, direction="inputs") + for i in range(0, self._config.num_layers): + common_inputs[f"past_key_values.{i}.decoder.key"] = {0: "batch", 2: "past_sequence"} + common_inputs[f"past_key_values.{i}.decoder.value"] = {0: "batch", 2: "past_sequence"} + common_inputs[f"past_key_values.{i}.encoder.key"] = {0: "batch", 2: "past_sequence"} + common_inputs[f"past_key_values.{i}.encoder.value"] = {0: "batch", 2: "past_sequence"} return common_inputs @property - def default_onnx_opset(self) -> int: - return 13 + def outputs(self) -> Mapping[str, Mapping[int, str]]: + common_outputs = super().outputs + + if "last_hidden_state" in common_outputs: + common_outputs["last_hidden_state"] = {0: "batch", 1: "decoder_sequence"} + + if self.use_past: + for i in range(self._config.num_layers): + common_outputs[f"present.{i}.decoder.key"] = {0: "batch", 2: "decoder_sequence"} + common_outputs[f"present.{i}.decoder.value"] = {0: "batch", 2: "decoder_sequence"} + common_outputs[f"present.{i}.encoder.key"] = {0: "batch", 2: "encoder_sequence"} + common_outputs[f"present.{i}.encoder.value"] = {0: "batch", 2: "encoder_sequence"} + + if self.task == "default": + common_outputs["encoder_last_hidden_state"] = {0: "batch", 2: "encoder_sequence"} + + return common_outputs + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + + # Generate encoder inputs + encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework) + + # Generate decoder inputs + decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework) + decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} + + ordered_inputs = dict(**encoder_inputs, **decoder_inputs) + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + batch = encoder_inputs["input_ids"].shape[0] + encoder_seq_length = encoder_inputs["input_ids"].shape[1] + encoder_shape = ( + batch, + self._config.num_heads, + encoder_seq_length, + self._config.hidden_size // self._config.num_heads, + ) + decoder_shape = (batch, 
self._config.num_heads, 1, self._config.hidden_size // self._config.num_heads) + + ordered_inputs["past_key_values"] = [] + for _ in range(self._config.num_layers): + ordered_inputs["past_key_values"].append( + ( + torch.zeros(decoder_shape), + torch.zeros(decoder_shape), + torch.zeros(encoder_shape), + torch.zeros(encoder_shape), + ) + ) + + return ordered_inputs + + @staticmethod + def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]: + if name in ["present", "past_key_values"]: + flatten_output = {} + for idx, t in enumerate(field): + flatten_output[f"{name}.{idx}.decoder.key"] = t[0] + flatten_output[f"{name}.{idx}.decoder.value"] = t[1] + flatten_output[f"{name}.{idx}.encoder.key"] = t[2] + flatten_output[f"{name}.{idx}.encoder.value"] = t[3] + + return flatten_output + + return super().flatten_output_collection_property(name, field) diff --git a/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py b/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py index ebf9c7abcfab..9300bfcc79e9 100644 --- a/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py @@ -53,3 +53,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: ("attention_mask", {0: "batch", 1: "sequence"}), ] ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})]) diff --git a/src/transformers/onnx/__init__.py b/src/transformers/onnx/__init__.py index d08d7b2711cb..a80567e202b0 100644 --- a/src/transformers/onnx/__init__.py +++ b/src/transformers/onnx/__init__.py @@ -13,12 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config import ( - EXTERNAL_DATA_FORMAT_SIZE_LIMIT, - OnnxConfig, - OnnxConfigWithPast, - OnnxSeq2SeqConfigWithPast, - PatchingSpec, -) +from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast, PatchingSpec from .convert import export, validate_model_outputs from .utils import ParameterFormat, compute_serialized_parameters_size diff --git a/src/transformers/onnx/__main__.py b/src/transformers/onnx/__main__.py index eb5d2773b0da..be7244233166 100644 --- a/src/transformers/onnx/__main__.py +++ b/src/transformers/onnx/__main__.py @@ -32,10 +32,10 @@ def main(): help="Export the model with some additional feature.", ) parser.add_argument( - "--opset", type=int, default=None, help="ONNX opset version to export the model with (default 12)." + "--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)." ) parser.add_argument( - "--atol", type=float, default=None, help="Absolute difference tolerence when validating the model." + "--atol", type=float, default=1e-4, help="Absolute difference tolerence when validating the model." ) parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.") @@ -53,9 +53,6 @@ def main(): onnx_config = model_onnx_config(model.config) # Ensure the requested opset is sufficient - if args.opset is None: - args.opset = onnx_config.default_onnx_opset - if args.opset < onnx_config.default_onnx_opset: raise ValueError( f"Opset {args.opset} is not sufficient to export {model_kind}. 
" @@ -64,9 +61,6 @@ def main(): onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, args.opset, args.output) - if args.atol is None: - args.atol = onnx_config.atol_for_validation - validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol) logger.info(f"All good, model saved at: {args.output.as_posix()}") diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py index 8fddd2ada691..8e9e1575b1e7 100644 --- a/src/transformers/onnx/config.py +++ b/src/transformers/onnx/config.py @@ -14,9 +14,9 @@ import dataclasses from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional -from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType, is_torch_available +from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType from .utils import ParameterFormat, compute_effective_axis_dimension, compute_serialized_parameters_size @@ -58,7 +58,6 @@ class OnnxConfig(ABC): _TASKS_TO_COMMON_OUTPUTS = { "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}), - "masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), "seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}), "sequence-classification": OrderedDict({"logits": {0: "batch"}}), @@ -120,8 +119,7 @@ def outputs(self) -> Mapping[str, Mapping[int, str]]: Returns: For each output: its name associated to the axes symbolic name and the axis position within the tensor """ - common_outputs = self._TASKS_TO_COMMON_OUTPUTS[self.task] - return common_outputs + return self._TASKS_TO_COMMON_OUTPUTS[self.task] @property def values_override(self) -> Optional[Mapping[str, Any]]: @@ -167,16 +165,6 @@ def default_onnx_opset(self) -> int: """ return DEFAULT_ONNX_OPSET - @property - def atol_for_validation(self) -> float: - """ - What absolute tolerance value to use during model conversion validation. - - Returns: - Float absolute tolerance value. - """ - return 1e-5 - @staticmethod def use_external_data_format(num_parameters: int) -> bool: """ @@ -241,8 +229,8 @@ def restore_ops(self): orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op) setattr(spec.o, spec.name, orig_op) - @classmethod - def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) -> Dict[str, Any]: + @staticmethod + def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]: """ Flatten any potential nested structure expanding the name of the field with the index of the element within the structure. @@ -284,14 +272,6 @@ def with_past(cls, config: PretrainedConfig, task: str = "default") -> "OnnxConf """ return cls(config, task=task, use_past=True) - @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - common_outputs = super().outputs - if self.use_past: - self.fill_with_past_key_values_(common_outputs, direction="outputs") - - return common_outputs - @property def values_override(self) -> Optional[Mapping[str, Any]]: if hasattr(self._config, "use_cache"): @@ -299,30 +279,6 @@ def values_override(self) -> Optional[Mapping[str, Any]]: return None - @property - def num_layers(self) -> int: - """ - The number of layers attribute retrieved from the model config. 
Override this for model configs where the - number of layers attribute is not called `num_layers`. - """ - if not hasattr(self._config, "num_layers"): - raise AttributeError( - "could not find the number of layers attribute in the model configuration, override the num_layers property of the model OnnxConfig to solve this" - ) - return self._config.num_layers - - @property - def num_attention_heads(self) -> int: - """ - The number of attention heads attribute retrieved from the model config. Override this for model configs where - the number of attention heads attribute is not called `num_attention_heads`. - """ - if not hasattr(self._config, "num_attention_heads"): - raise AttributeError( - "could not find the number of attention heads attribute in the model configuration, override the num_attention_heads property of the model OnnxConfig to solve this" - ) - return self._config.num_attention_heads - def generate_dummy_inputs( self, tokenizer: PreTrainedTokenizer, @@ -331,217 +287,32 @@ def generate_dummy_inputs( is_pair: bool = False, framework: Optional[TensorType] = None, ) -> Mapping[str, Any]: + # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX + batch_size = compute_effective_axis_dimension( + batch_size, fixed_dimension=self.default_batch_size, num_token_to_add=0 + ) - # TODO: should we set seq_length = 1 when self.use_past = True? - common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework) - - if self.use_past: - if not is_torch_available(): - raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") - else: - import torch - - batch, seqlen = common_inputs["input_ids"].shape - # Not using the same length for past_key_values - past_key_values_length = seqlen + 2 - shape = ( - batch, - self.num_attention_heads, - past_key_values_length, - self._config.hidden_size // self.num_attention_heads, - ) - - if "attention_mask" in common_inputs: - common_inputs["attention_mask"] = torch.cat( - [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1 - ) - - common_inputs["past_key_values"] = [] - for _ in range(self.num_layers): - common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape))) - - return common_inputs - - def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str): - """ - Fill the input_or_ouputs mapping with past_key_values dynamic axes considering. - - Args: - inputs_or_outputs: The mapping to fill. - direction: either "inputs" or "outputs", it specifies whether input_or_outputs is the input mapping or the - output mapping, this is important for axes naming. 
- - """ - if direction not in ["inputs", "outputs"]: - raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX + token_to_add = tokenizer.num_special_tokens_to_add(is_pair) - name = "past_key_values" if direction == "inputs" else "present" - for i in range(self.num_layers): - inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} - inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + # When use_past the caching mechanism requires inputs to be only 1 single token + fixed_sequence_length = 1 if self.use_past else self.default_sequence_length + seq_length = compute_effective_axis_dimension( + seq_length, fixed_dimension=fixed_sequence_length, num_token_to_add=token_to_add + ) - def _flatten_past_key_values_(self, flattened_output, name, idx, t): - flattened_output[f"{name}.{idx}.key"] = t[0] - flattened_output[f"{name}.{idx}.value"] = t[1] + # Generate dummy inputs according to compute batch and sequence + dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size + return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework))) - def flatten_output_collection_property(self, name: str, field: Iterable[Any]) -> Dict[str, Any]: - flattened_output = {} + @staticmethod + def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]: if name in ["present", "past_key_values"]: + flatten_output = {} for idx, t in enumerate(field): - self._flatten_past_key_values_(flattened_output, name, idx, t) - else: - flattened_output = super().flatten_output_collection_property(name, field) - - return flattened_output - - -class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast): - @property - def outputs(self) -> Mapping[str, Mapping[int, str]]: - common_outputs = self._TASKS_TO_COMMON_OUTPUTS[self.task] - # Renaming the outputs axes properly. 
- for name, axes_names in common_outputs.items(): - sequence_name = "encoder_sequence" if "encoder" in name else "decoder_sequence" - for axis_idx, name in axes_names.items(): - if "sequence" in name: - axes_names[axis_idx] = sequence_name - # We reset the value as the order in common_outputs (OrderedDict) is lost otherwise - else: - axes_names[axis_idx] = name - if self.use_past: - self.fill_with_past_key_values_(common_outputs, direction="outputs") - - return common_outputs - - @property - def num_layers(self) -> Tuple[int]: - try: - num_layers = super().num_layers - num_layers = (num_layers, num_layers) - except AttributeError: - if hasattr(self._config, "encoder_layers") and hasattr(self._config, "decoder_layers"): - num_layers = (self._config.encoder_layers, self._config.decoder_layers) - else: - raise AttributeError( - "could not find the number of encoder and decoder layers attributes in the model configuration, override the num_layers property of the model OnnxConfig to solve this" - ) - - return num_layers + flatten_output[f"{name}.{idx}.key"] = t[0] + flatten_output[f"{name}.{idx}.value"] = t[1] - @property - def num_attention_heads(self) -> Tuple[int]: - try: - num_attention_heads = super().num_attention_heads - num_attention_heads = (num_attention_heads, num_attention_heads) - except AttributeError: - if hasattr(self._config, "encoder_attention_heads") and hasattr(self._config, "decoder_attention_heads"): - num_attention_heads = (self._config.encoder_attention_heads, self._config.decoder_attention_heads) - else: - raise AttributeError( - "could not find the number of attention heads for the encoder and the decoder attributes in the model configuration, override the num_attention_heads property of the model OnnxConfig to solve this" - ) - return num_attention_heads - - def generate_dummy_inputs( - self, - tokenizer: PreTrainedTokenizer, - batch_size: int = -1, - seq_length: int = -1, - is_pair: bool = False, - framework: Optional[TensorType] = None, - ) -> Mapping[str, Any]: - - encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size, seq_length, is_pair, framework - ) - - # Generate decoder inputs - decoder_seq_length = seq_length if not self.use_past else 1 - decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( - tokenizer, batch_size, decoder_seq_length, is_pair, framework - ) - decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} - common_inputs = dict(**encoder_inputs, **decoder_inputs) - - if self.use_past: - if not is_torch_available(): - raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") - else: - import torch - batch = common_inputs["input_ids"].shape[0] - encoder_seq_length = common_inputs["input_ids"].shape[1] - decoder_seq_length = common_inputs["decoder_input_ids"].shape[1] - num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads - encoder_shape = ( - batch, - num_encoder_attention_heads, - encoder_seq_length, - self._config.hidden_size // num_encoder_attention_heads, - ) - decoder_shape = ( - batch, - num_decoder_attention_heads, - # Not using the same length for past_key_values - decoder_seq_length + 3, - self._config.hidden_size // num_decoder_attention_heads, - ) + return flatten_output - common_inputs["past_key_values"] = [] - # If the number of encoder and decoder layers are present in the model configuration, both are considered - num_encoder_layers, num_decoder_layers = self.num_layers - min_num_layers 
= min(num_encoder_layers, num_decoder_layers) - max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers - remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder" - - for _ in range(min_num_layers): - # For encoder-decoder models, past_key_values contains pre-computed values for both the encoder and the - # decoder layers, hence a tuple of 4 tensors instead of 2 - common_inputs["past_key_values"].append( - ( - torch.zeros(decoder_shape), - torch.zeros(decoder_shape), - torch.zeros(encoder_shape), - torch.zeros(encoder_shape), - ) - ) - - # TODO: test this. - shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape - for _ in range(min_num_layers, max_num_layers): - common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape))) - - return common_inputs - - def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str): - if direction not in ["inputs", "outputs"]: - raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') - - name = "past_key_values" if direction == "inputs" else "present" - - # If the number of encoder and decoder layers are present in the model configuration, both are considered - num_encoder_layers, num_decoder_layers = self.num_layers - min_num_layers = min(num_encoder_layers, num_decoder_layers) - max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers - remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder" - - encoder_sequence = "past_encoder_sequence" - decoder_sequence = "past_decoder_sequence" if direction == "inputs" else "past_decoder_sequence + sequence" - - for i in range(min_num_layers): - inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch", 2: decoder_sequence} - inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch", 2: decoder_sequence} - inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch", 2: encoder_sequence} - inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch", 2: encoder_sequence} - - for i in range(min_num_layers, max_num_layers): - if remaining_side_name == "encoder": - axes_info = {0: "batch", 2: encoder_sequence} - else: - axes_info = {0: "batch", 2: decoder_sequence} - inputs_or_outputs[f"{name}.{i}.{remaining_side_name}.key"] = axes_info - - def _flatten_past_key_values_(self, flattened_output, name, idx, t): - flattened_output[f"{name}.{idx}.decoder.key"] = t[0] - flattened_output[f"{name}.{idx}.decoder.value"] = t[1] - flattened_output[f"{name}.{idx}.encoder.key"] = t[2] - flattened_output[f"{name}.{idx}.encoder.value"] = t[3] + return super().flatten_output_collection_property(name, field) diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py index 041e21832a06..313a7fd2e621 100644 --- a/src/transformers/onnx/convert.py +++ b/src/transformers/onnx/convert.py @@ -191,7 +191,7 @@ def validate_model_outputs( f"{onnx_outputs_set.difference(ref_outputs_set)}" ) else: - logger.info(f"\t-[✓] ONNX model outputs' name match reference model ({onnx_outputs_set})") + logger.info(f"\t-[✓] ONNX model outputs' name match reference model ({onnx_outputs_set}") # Check the shape and values match for name, ort_value in zip(onnx_named_outputs, onnx_outputs): diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index b44eb46a35e4..d685af4cf771 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -1,7 +1,7 @@ from 
functools import partial, reduce -from typing import Callable, Dict, Optional, Tuple, Type +from typing import Callable, Tuple -from .. import PretrainedConfig, is_torch_available +from .. import is_torch_available from ..models.albert import AlbertOnnxConfig from ..models.bart import BartOnnxConfig from ..models.bert import BertOnnxConfig @@ -15,7 +15,6 @@ from ..models.roberta import RobertaOnnxConfig from ..models.t5 import T5OnnxConfig from ..models.xlm_roberta import XLMRobertaOnnxConfig -from .config import OnnxConfig if is_torch_available(): @@ -23,7 +22,6 @@ from transformers.models.auto import ( AutoModel, AutoModelForCausalLM, - AutoModelForMaskedLM, AutoModelForMultipleChoice, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, @@ -32,19 +30,8 @@ ) -def supported_features_mapping( - *supported_features: str, onnx_config_cls: Type[OnnxConfig] = None -) -> Dict[str, Callable[[PretrainedConfig, str], OnnxConfig]]: - """ - Generate the mapping between supported the features and their corresponding OnnxConfig for a given model. - - Args: - *supported_features: The names of the supported features. - onnx_config_cls: The OnnxConfig class corresponding to the model. - - Returns: - The dictionary mapping a feature to an OnnxConfig constructor. - """ +def supported_features_mapping(*supported_features, onnx_config_cls=None): + """Generates the mapping between supported features and their corresponding OnnxConfig.""" if onnx_config_cls is None: raise ValueError("A OnnxConfig class must be provided") @@ -62,7 +49,6 @@ def supported_features_mapping( class FeaturesManager: _TASKS_TO_AUTOMODELS = { "default": AutoModel, - "masked-lm": AutoModelForMaskedLM, "causal-lm": AutoModelForCausalLM, "seq2seq-lm": AutoModelForSeq2SeqLM, "sequence-classification": AutoModelForSequenceClassification, @@ -72,110 +58,27 @@ class FeaturesManager: } # Set of model topologies we support associated to the features supported by each topology and the factory - _SUPPORTED_MODEL_TYPE = { - "albert": supported_features_mapping( - "default", - "masked-lm", - "sequence-classification", - # "multiple-choice", - "token-classification", - "question-answering", - onnx_config_cls=AlbertOnnxConfig, - ), - "bart": supported_features_mapping( - "default", - "default-with-past", - "causal-lm", - "causal-lm-with-past", - "seq2seq-lm", - "seq2seq-lm-with-past", - "sequence-classification", - "question-answering", - onnx_config_cls=BartOnnxConfig, - ), - "mbart": supported_features_mapping( - "default", - "default-with-past", - "causal-lm", - "causal-lm-with-past", - "seq2seq-lm", - "seq2seq-lm-with-past", - "sequence-classification", - "question-answering", - onnx_config_cls=MBartOnnxConfig, - ), - "bert": supported_features_mapping( - "default", - "masked-lm", - "causal-lm", - "sequence-classification", - # "multiple-choice", - "token-classification", - "question-answering", - onnx_config_cls=BertOnnxConfig, - ), + _SUPPORTED_MODEL_KIND = { + "albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig), + "bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig), + "mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig), + "bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig), "camembert": supported_features_mapping( "default", - "masked-lm", "causal-lm", "sequence-classification", - # "multiple-choice", "token-classification", "question-answering", onnx_config_cls=CamembertOnnxConfig, ), - "distilbert": supported_features_mapping( - 
"default", - "masked-lm", - "sequence-classification", - # "multiple-choice", - "token-classification", - "question-answering", - onnx_config_cls=DistilBertOnnxConfig, - ), - "longformer": supported_features_mapping( - "default", - "masked-lm", - "sequence-classification", - # "multiple-choice", - "token-classification", - "question-answering", - onnx_config_cls=LongformerOnnxConfig, - ), - "roberta": supported_features_mapping( - "default", - "masked-lm", - "causal-lm", - "sequence-classification", - # "multiple-choice", - "token-classification", - "question-answering", - onnx_config_cls=RobertaOnnxConfig, - ), + "distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig), + "gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig), + "longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig), + "roberta": supported_features_mapping("default", onnx_config_cls=RobertaOnnxConfig), "t5": supported_features_mapping( "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig ), - "xlm-roberta": supported_features_mapping( - "default", - "masked-lm", - "causal-lm", - "sequence-classification", - # "multiple-choice", - "token-classification", - "question-answering", - onnx_config_cls=XLMRobertaOnnxConfig, - ), - "gpt2": supported_features_mapping( - "default", - "causal-lm", - "sequence-classification", - "token-classification", - "default-with-past", - "causal-lm-with-past", - "sequence-classification-with-past", - "token-classification-with-past", - onnx_config_cls=GPT2OnnxConfig, - ), + "xlm-roberta": supported_features_mapping("default", onnx_config_cls=XLMRobertaOnnxConfig), "gpt-neo": supported_features_mapping( "default", "causal-lm", @@ -194,46 +97,23 @@ class FeaturesManager: ), } - AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values()))) - - @staticmethod - def get_supported_features_for_model_type( - model_type: str, model_name: Optional[str] = None - ) -> Dict[str, Callable[[PretrainedConfig, str], OnnxConfig]]: - """ - Try to retrieve the feature -> OnnxConfig constructor map from the model type. - - Args: - model_type: The model type to retrieve the supported features for. - model_name: The name attribute of the model object, only used for the exception message. - - Returns: - The dictionary mapping each feature to a corresponding OnnxConfig constructor. - """ - model_type = model_type.lower() - if model_type not in FeaturesManager._SUPPORTED_MODEL_TYPE: - model_type_and_model_name = f"{model_type} ({model_name})" if model_name else model_type - raise KeyError( - f"{model_type_and_model_name} is not supported yet. " - f"Only {list(FeaturesManager._SUPPORTED_MODEL_TYPE.keys())} are supported. " - f"If you want to support {model_type} please propose a PR or open up an issue." - ) - return FeaturesManager._SUPPORTED_MODEL_TYPE[model_type] + AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_KIND.values()))) @staticmethod def feature_to_task(feature: str) -> str: return feature.replace("-with-past", "") @staticmethod - def get_model_class_for_feature(feature: str) -> Type: + def get_model_from_feature(feature: str, model: str): """ - Attempt to retrieve an AutoModel class from a feature name. + Attempt to retrieve a model from a model's name and the feature to be enabled. Args: - feature: The feature required. 
+ feature: The feature required + model: The name of the model to export Returns: - The AutoModel class corresponding to the feature. + """ task = FeaturesManager.feature_to_task(feature) if task not in FeaturesManager._TASKS_TO_AUTOMODELS: @@ -241,43 +121,38 @@ def get_model_class_for_feature(feature: str) -> Type: f"Unknown task: {feature}. " f"Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}" ) - return FeaturesManager._TASKS_TO_AUTOMODELS[task] - def get_model_from_feature(feature: str, model: str) -> PreTrainedModel: - """ - Attempt to retrieve a model from a model's name and the feature to be enabled. - - Args: - feature: The feature required. - model: The name of the model to export. - - Returns: - The instance of the model. - - """ - model_class = FeaturesManager.get_model_class_for_feature(feature) - return model_class.from_pretrained(model) + return FeaturesManager._TASKS_TO_AUTOMODELS[task].from_pretrained(model) @staticmethod def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> Tuple[str, Callable]: """ - Check whether or not the model has the requested features. + Check whether or not the model has the requested features Args: - model: The model to export. - feature: The name of the feature to check if it is available. + model: The model to export + feature: The name of the feature to check if it is available Returns: - (str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties. + (str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties """ model_type = model.config.model_type.replace("_", "-") model_name = getattr(model, "name", "") - model_features = FeaturesManager.get_supported_features_for_model_type(model_type, model_name=model_name) + model_name = f"({model_name})" if model_name else "" + if model_type not in FeaturesManager._SUPPORTED_MODEL_KIND: + raise KeyError( + f"{model.config.model_type} ({model_name}) is not supported yet. " + f"Only {FeaturesManager._SUPPORTED_MODEL_KIND} are supported. " + f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue." + ) + + # Look for the features + model_features = FeaturesManager._SUPPORTED_MODEL_KIND[model_type] if feature not in model_features: raise ValueError( f"{model.config.model_type} doesn't support feature {feature}. 
" - f"Supported values are: {model_features}" + f"Supported values are: {list(model_features.keys())}" ) - return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature] + return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_KIND[model_type][feature] diff --git a/src/transformers/test.onnx/model.onnx b/src/transformers/test.onnx/model.onnx deleted file mode 100644 index ed5f887d376a..000000000000 Binary files a/src/transformers/test.onnx/model.onnx and /dev/null differ diff --git a/tests/test_onnx_v2.py b/tests/test_onnx_v2.py index a62d8386b5dd..861b781442e9 100644 --- a/tests/test_onnx_v2.py +++ b/tests/test_onnx_v2.py @@ -3,8 +3,33 @@ from unittest import TestCase from unittest.mock import patch -from parameterized import parameterized -from transformers import AutoConfig, AutoTokenizer +from transformers import ( # LongformerConfig,; T5Config, + AlbertConfig, + AutoTokenizer, + BartConfig, + DistilBertConfig, + GPT2Config, + GPTNeoConfig, + LayoutLMConfig, + MBartConfig, + RobertaConfig, + XLMRobertaConfig, + is_torch_available, +) +from transformers.models.albert import AlbertOnnxConfig +from transformers.models.bart import BartOnnxConfig +from transformers.models.bert.configuration_bert import BertConfig, BertOnnxConfig +from transformers.models.distilbert import DistilBertOnnxConfig + +# from transformers.models.longformer import LongformerOnnxConfig +from transformers.models.gpt2 import GPT2OnnxConfig +from transformers.models.gpt_neo import GPTNeoOnnxConfig +from transformers.models.layoutlm import LayoutLMOnnxConfig +from transformers.models.mbart import MBartOnnxConfig +from transformers.models.roberta import RobertaOnnxConfig + +# from transformers.models.t5 import T5OnnxConfig +from transformers.models.xlm_roberta import XLMRobertaOnnxConfig from transformers.onnx import ( EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, @@ -12,8 +37,7 @@ export, validate_model_outputs, ) -from transformers.onnx.config import OnnxConfigWithPast -from transformers.onnx.features import FeaturesManager +from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size from transformers.testing_utils import require_onnx, require_torch, slow @@ -115,12 +139,11 @@ class OnnxConfigWithPastTestCaseV2(TestCase): Cover the tests for model which have use_cache feature (i.e. 
"with_past" for ONNX) """ - SUPPORTED_WITH_PAST_CONFIGS = {} - # SUPPORTED_WITH_PAST_CONFIGS = { - # ("BART", BartConfig), - # ("GPT2", GPT2Config), - # # ("T5", T5Config) - # } + SUPPORTED_WITH_PAST_CONFIGS = { + ("BART", BartConfig), + ("GPT2", GPT2Config), + # ("T5", T5Config) + } @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set()) def test_use_past(self): @@ -164,37 +187,40 @@ def test_values_override(self): ) -PYTORCH_EXPORT_MODELS = { - ("albert", "hf-internal-testing/tiny-albert"), - ("bert", "bert-base-cased"), - ("camembert", "camembert-base"), - ("distilbert", "distilbert-base-cased"), - # ("longFormer", "longformer-base-4096"), - ("roberta", "roberta-base"), - ("xlm-roberta", "xlm-roberta-base"), - ("layoutlm", "microsoft/layoutlm-base-uncased"), -} - -PYTORCH_EXPORT_WITH_PAST_MODELS = { - ("gpt2", "gpt2"), - ("gpt-neo", "EleutherAI/gpt-neo-125M"), -} - -PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = { - ("bart", "facebook/bart-base"), - ("mbart", "sshleifer/tiny-mbart"), - ("t5", "t5-small"), -} - - -def _get_models_to_test(export_models_list): - models_to_test = [] - for (name, model) in export_models_list: - for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type( - name - ).items(): - models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor)) - return models_to_test +if is_torch_available(): + from transformers import ( # T5Model, + AlbertModel, + BartModel, + BertModel, + DistilBertModel, + GPT2Model, + GPTNeoModel, + LayoutLMModel, + MBartModel, + RobertaModel, + XLMRobertaModel, + ) + + PYTORCH_EXPORT_DEFAULT_MODELS = { + ("ALBERT", "hf-internal-testing/tiny-albert", AlbertModel, AlbertConfig, AlbertOnnxConfig), + ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig), + ("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig), + ("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig), + ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig), + ("GPT-Neo", "EleutherAI/gpt-neo-125M", GPTNeoModel, GPTNeoConfig, GPTNeoOnnxConfig), + # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig), + ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig), + ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig), + ("LayoutLM", "microsoft/layoutlm-base-uncased", LayoutLMModel, LayoutLMConfig, LayoutLMOnnxConfig), + ("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig), + # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig), + } + + PYTORCH_EXPORT_WITH_PAST_MODELS = { + # ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig), + # ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig), + # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig) + } class OnnxExportTestCaseV2(TestCase): @@ -202,52 +228,52 @@ class OnnxExportTestCaseV2(TestCase): Integration tests ensuring supported models are correctly exported """ - def _pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor): + @slow + @require_torch + def test_pytorch_export_default(self): from transformers.onnx import export - tokenizer = AutoTokenizer.from_pretrained(model_name) - config = AutoConfig.from_pretrained(model_name) - - # Useful for causal lm models that do not use pad tokens. 
-        if not getattr(config, "pad_token_id", None):
-            config.pad_token_id = tokenizer.eos_token_id
-
-        model_class = FeaturesManager.get_model_class_for_feature(feature)
-        model = model_class.from_config(config)
-        onnx_config = onnx_config_class_constructor(model.config)
-
-        with NamedTemporaryFile("w") as output:
-            onnx_inputs, onnx_outputs = export(
-                tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
-            )
-            try:
-                validate_model_outputs(
-                    onnx_config,
-                    tokenizer,
-                    model,
-                    Path(output.name),
-                    onnx_outputs,
-                    onnx_config.atol_for_validation,
-                )
-            except ValueError as ve:
-                self.fail(f"{name}, {feature} -> {ve}")
+        for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
+            with self.subTest(name):
+                self.assertTrue(hasattr(onnx_config_class, "from_model_config"))

-    @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
-    @slow
-    @require_torch
-    def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
-        self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
+                tokenizer = AutoTokenizer.from_pretrained(model)
+                model = model_class(config_class.from_pretrained(model))
+                onnx_config = onnx_config_class.from_model_config(model.config)

-    @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
-    @slow
-    @require_torch
-    def test_pytorch_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
-        self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
+                with NamedTemporaryFile("w") as output:
+                    onnx_inputs, onnx_outputs = export(
+                        tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name)
+                    )
+
+                    try:
+                        validate_model_outputs(onnx_config, tokenizer, model, Path(output.name), onnx_outputs, 1e-5)
+                    except ValueError as ve:
+                        self.fail(f"{name} -> {ve}")

-    @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
     @slow
     @require_torch
-    def test_pytorch_export_seq2seq_with_past(
-        self, test_name, name, model_name, feature, onnx_config_class_constructor
-    ):
-        self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
+    def test_pytorch_export_with_past(self):
+        from transformers.onnx import export
+
+        for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_WITH_PAST_MODELS:
+            with self.subTest(name):
+                self.assertTrue(hasattr(onnx_config_class, "with_past"), "OnnxConfigWithPast should have with_past()")
+
+                tokenizer = AutoTokenizer.from_pretrained(model)
+                model = model_class(config_class())
+                onnx_config = onnx_config_class.with_past(model.config)
+
+                self.assertTrue(hasattr(onnx_config, "use_past"), "OnnxConfigWithPast should have use_past attribute.")
+                self.assertTrue(
+                    onnx_config.use_past, "OnnxConfigWithPast.use_past should be True if called with with_past()"
+                )
+
+                with NamedTemporaryFile("w") as output:
+                    output = Path(output.name)
+                    onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output)
+
+                    try:
+                        validate_model_outputs(onnx_config, tokenizer, model, output, onnx_outputs, 1e-5)
+                    except ValueError as ve:
+                        self.fail(f"{name} -> {ve}")
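
For reference, the default export flow that test_pytorch_export_default exercises reduces to the short standalone sketch below. The BERT checkpoint and output path are illustrative choices only, not part of the patch; the calls mirror the test code above.

# Minimal sketch of the default ONNX export path, assuming the imports added in
# tests/test_onnx_v2.py; "bert-base-cased" and the output file are placeholder choices.
from pathlib import Path

from transformers import AutoTokenizer, BertModel
from transformers.models.bert.configuration_bert import BertConfig, BertOnnxConfig
from transformers.onnx import export, validate_model_outputs
from transformers.onnx.config import DEFAULT_ONNX_OPSET

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = BertModel(BertConfig.from_pretrained("bert-base-cased"))

# Build the ONNX config straight from the model config, as the default test does.
onnx_config = BertOnnxConfig.from_model_config(model.config)

output = Path("bert.onnx")
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output)

# Raises ValueError if the exported graph's outputs drift from the reference PyTorch model.
validate_model_outputs(onnx_config, tokenizer, model, output, onnx_outputs, 1e-5)

Configs that support caching (e.g. BartOnnxConfig, GPT2OnnxConfig per SUPPORTED_WITH_PAST_CONFIGS) would instead be built with with_past(model.config), which sets use_past as asserted in test_pytorch_export_with_past.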