4 changes: 0 additions & 4 deletions src/transformers/models/albert/configuration_albert.py
@@ -167,7 +167,3 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
("token_type_ids", {0: "batch", 1: "sequence"}),
]
)

-    @property
-    def outputs(self) -> Mapping[str, Mapping[int, str]]:
-        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
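For context (not part of the diff): with the hard-coded `outputs` override gone, `AlbertOnnxConfig` presumably falls back to the task-driven outputs defined on the base `OnnxConfig`. A minimal sketch of how one might check that, assuming the usual `transformers` API and that `AlbertOnnxConfig` is exported from this module:

```python
# Hypothetical check, not part of the PR: inspect the outputs the base OnnxConfig
# provides once the hard-coded override is removed.
from transformers import AlbertConfig
from transformers.models.albert.configuration_albert import AlbertOnnxConfig

onnx_config = AlbertOnnxConfig(AlbertConfig(), task="default")
print(onnx_config.outputs)  # expected to come from the task -> common outputs mapping
```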
184 changes: 165 additions & 19 deletions src/transformers/models/bart/configuration_bart.py
@@ -15,10 +15,12 @@
""" BART model configuration """
import warnings
from collections import OrderedDict
-from typing import Mapping
from typing import Any, Mapping, Optional

from ... import PreTrainedTokenizer
from ...configuration_utils import PretrainedConfig
-from ...onnx import OnnxConfigWithPast
from ...file_utils import TensorType, is_torch_available
from ...onnx import OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from ...utils import logging


@@ -180,30 +182,174 @@ def __init__(
)


-class BartOnnxConfig(OnnxConfigWithPast):
class BartOnnxConfig(OnnxSeq2SeqConfigWithPast):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
-        return OrderedDict(
-            [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
-            ]
-        )
        if self.task in ["default", "seq2seq-lm"]:
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ]
            )

-    @property
-    def outputs(self) -> Mapping[str, Mapping[int, str]]:
-        if self.use_past:
-            return OrderedDict(
            if self.use_past:
                common_inputs["decoder_input_ids"] = {0: "batch"}
                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
            else:
                common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}

if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs")
elif self.task == "causal-lm":
# TODO: figure this case out.
            common_inputs = OrderedDict(
                [
-                    ("last_hidden_state", {0: "batch", 1: "sequence"}),
-                    ("past_keys", {0: "batch", 2: "sequence"}),
-                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ]
            )
            if self.use_past:
                num_encoder_layers, _ = self.num_layers
                for i in range(num_encoder_layers):
                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
        else:
-            return OrderedDict(
            common_inputs = OrderedDict(
                [
-                    ("last_hidden_state", {0: "batch", 1: "sequence"}),
-                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                    ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
                    ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
                ]
            )

return common_inputs

@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task in ["default", "seq2seq-lm"]:
common_outputs = super().outputs
else:
common_outputs = super(OnnxConfigWithPast, self).outputs
if self.use_past:
num_encoder_layers, _ = self.num_layers
for i in range(num_encoder_layers):
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
return common_outputs

def generate_dummy_inputs(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:

if self.task in ["default", "seq2seq-lm"]:
encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
)

# Generate decoder inputs
decoder_seq_length = seq_length if not self.use_past else 1
decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, decoder_seq_length, is_pair, framework
)
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
common_inputs = dict(**encoder_inputs, **decoder_inputs)

if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch, encoder_seq_length = common_inputs["input_ids"].shape
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
encoder_shape = (
batch,
num_encoder_attention_heads,
encoder_seq_length,
self._config.hidden_size // num_encoder_attention_heads,
)
decoder_past_length = decoder_seq_length + 3
decoder_shape = (
batch,
num_decoder_attention_heads,
decoder_past_length,
self._config.hidden_size // num_decoder_attention_heads,
)

common_inputs["decoder_attention_mask"] = torch.cat(
[common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
)

common_inputs["past_key_values"] = []
# If the number of encoder and decoder layers are present in the model configuration, both are considered
num_encoder_layers, num_decoder_layers = self.num_layers
min_num_layers = min(num_encoder_layers, num_decoder_layers)
max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"

for _ in range(min_num_layers):
common_inputs["past_key_values"].append(
(
torch.zeros(decoder_shape),
torch.zeros(decoder_shape),
torch.zeros(encoder_shape),
torch.zeros(encoder_shape),
)
)

# TODO: test this.
shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
for _ in range(min_num_layers, max_num_layers):
common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))

elif self.task == "causal-lm":
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
)

if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch

batch, seqlen = common_inputs["input_ids"].shape
# Not using the same length for past_key_values
past_key_values_length = seqlen + 2
num_encoder_layers, _ = self.num_layers
num_encoder_attention_heads, _ = self.num_attention_heads
past_shape = (
batch,
num_encoder_attention_heads,
past_key_values_length,
self._config.hidden_size // num_encoder_attention_heads,
)

common_inputs["attention_mask"] = torch.cat(
[common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
)
common_inputs["past_key_values"] = [
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
]
else:
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
)

return common_inputs

def _flatten_past_key_values_(self, flattened_output, name, idx, t):
if self.task in ["default", "seq2seq-lm"]:
flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
else:
flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
flattened_output, name, idx, t
)
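A minimal usage sketch (not part of the diff) of the reworked `BartOnnxConfig`; the checkpoint name, the `seq2seq-lm` task and the use of `TensorType.PYTORCH` are illustrative assumptions:

```python
# Sketch: build a seq2seq-with-past ONNX config for BART and generate dummy inputs.
from transformers import AutoTokenizer, BartConfig, TensorType
from transformers.models.bart.configuration_bart import BartOnnxConfig

config = BartConfig.from_pretrained("facebook/bart-base")        # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

# use_past=True exercises the new past_key_values plumbing.
onnx_config = BartOnnxConfig(config, task="seq2seq-lm", use_past=True)

print(list(onnx_config.inputs))   # input_ids, attention_mask, decoder_*, past_key_values.*
print(list(onnx_config.outputs))  # task outputs plus present.* entries

dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
print(dummy_inputs["decoder_input_ids"].shape)  # short decoder sequence when use_past=True
print(len(dummy_inputs["past_key_values"]))     # one entry per decoder layer (plus extras if layer counts differ)
```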
4 changes: 0 additions & 4 deletions src/transformers/models/bert/configuration_bert.py
@@ -169,7 +169,3 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
("token_type_ids", {0: "batch", 1: "sequence"}),
]
)

-    @property
-    def outputs(self) -> Mapping[str, Mapping[int, str]]:
-        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
4 changes: 0 additions & 4 deletions (file path not shown)
@@ -142,7 +142,3 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
("attention_mask", {0: "batch", 1: "sequence"}),
]
)

-    @property
-    def outputs(self) -> Mapping[str, Mapping[int, str]]:
-        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})])
66 changes: 44 additions & 22 deletions src/transformers/models/gpt2/configuration_gpt2.py
@@ -15,12 +15,12 @@
# limitations under the License.
""" OpenAI GPT-2 configuration """
from collections import OrderedDict
-from typing import Any, Mapping, Optional
from typing import Any, List, Mapping, Optional

from transformers import PreTrainedTokenizer, TensorType, is_torch_available

from ...configuration_utils import PretrainedConfig
-from ...onnx import OnnxConfigWithPast
from ...onnx import OnnxConfigWithPast, PatchingSpec
from ...utils import logging


@@ -194,29 +194,36 @@ def __init__(


class GPT2OnnxConfig(OnnxConfigWithPast):
def __init__(
self,
config: PretrainedConfig,
task: str = "default",
patching_specs: List[PatchingSpec] = None,
use_past: bool = False,
):
        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
if not getattr(self._config, "pad_token_id", None):
# TODO: how to do that better?
self._config.pad_token_id = 0

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict({"input_ids": {0: "batch"}})
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
if self.use_past:
-            for i in range(self._config.n_layer * 2):
-                common_inputs[f"past_key_values.{i}"] = {0: "batch", 2: "sequence"}
-
-        common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
            self.fill_with_past_key_values_(common_inputs, direction="inputs")
            common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
        else:
            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}

return common_inputs

@property
-    def outputs(self) -> Mapping[str, Mapping[int, str]]:
-        common_outputs = OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}})
-        if self.use_past:
-            for i in range(self._config.n_layer * 2):
-                common_outputs[f"present.{i}"] = {0: "batch", 2: "sequence"}
-
-            return common_outputs
-
-        return common_outputs
    def num_layers(self) -> int:
        return self._config.n_layer

    @property
    def num_attention_heads(self) -> int:
        return self._config.n_head

def generate_dummy_inputs(
self,
@@ -226,7 +233,9 @@ def generate_dummy_inputs(
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
-        common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
)

# We need to order the input in the way they appears in the forward()
ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
Expand All @@ -238,14 +247,27 @@ def generate_dummy_inputs(
else:
import torch

-                batch = common_inputs["input_ids"].shape[0]
batch, seqlen = common_inputs["input_ids"].shape
# Not using the same length for past_key_values
past_key_values_length = seqlen + 2
past_shape = (
batch,
self.num_attention_heads,
past_key_values_length,
self._config.hidden_size // self.num_attention_heads,
)
ordered_inputs["past_key_values"] = [
-                    (
-                        torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)),
-                        torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)),
-                    )
-                    for _ in range(self._config.n_layer)
                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
]

ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
if self.use_past:
ordered_inputs["attention_mask"] = torch.cat(
[ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
)

return ordered_inputs

@property
def default_onnx_opset(self) -> int:
        # PR review comment (not part of the file): Nice idea to define default operator sets this way :)

return 13
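A similar sketch (again not part of the diff) for the GPT-2 side, showing the new `num_layers`/`num_attention_heads` properties, the past-aware dummy inputs and the pinned opset; the `gpt2` tokenizer checkpoint is an assumption:

```python
# Sketch: GPT-2 ONNX config with past key values enabled.
from transformers import GPT2Config, GPT2Tokenizer, TensorType
from transformers.models.gpt2.configuration_gpt2 import GPT2OnnxConfig

config = GPT2Config()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # assumed checkpoint for the tokenizer

onnx_config = GPT2OnnxConfig(config, task="default", use_past=True)
print(onnx_config.num_layers, onnx_config.num_attention_heads)  # n_layer, n_head from the config
print(onnx_config.default_onnx_opset)                           # 13

inputs = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
print(len(inputs["past_key_values"]))  # one (key, value) pair per layer
print(inputs["attention_mask"].shape)  # extended to cover the dummy past length
```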