4 changes: 4 additions & 0 deletions src/transformers/models/albert/configuration_albert.py
@@ -167,3 +167,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
("token_type_ids", {0: "batch", 1: "sequence"}),
]
)

@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
184 changes: 19 additions & 165 deletions src/transformers/models/bart/configuration_bart.py
@@ -15,12 +15,10 @@
""" BART model configuration """
import warnings
from collections import OrderedDict
from typing import Any, Mapping, Optional
from typing import Mapping

from ... import PreTrainedTokenizer
from ...configuration_utils import PretrainedConfig
from ...file_utils import TensorType, is_torch_available
from ...onnx import OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from ...onnx import OnnxConfigWithPast
from ...utils import logging


@@ -182,174 +180,30 @@ def __init__(
)


class BartOnnxConfig(OnnxSeq2SeqConfigWithPast):
class BartOnnxConfig(OnnxConfigWithPast):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task in ["default", "seq2seq-lm"]:
common_inputs = OrderedDict(
[
("input_ids", {0: "batch", 1: "encoder_sequence"}),
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
]
)

if self.use_past:
common_inputs["decoder_input_ids"] = {0: "batch"}
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
else:
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
]
)

if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs")
elif self.task == "causal-lm":
# TODO: figure this case out.
common_inputs = OrderedDict(
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
if self.use_past:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "encoder_sequence"}),
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
("last_hidden_state", {0: "batch", 1: "sequence"}),
("past_keys", {0: "batch", 2: "sequence"}),
("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
]
)
if self.use_past:
num_encoder_layers, _ = self.num_layers
for i in range(num_encoder_layers):
common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
else:
common_inputs = OrderedDict(
return OrderedDict(
[
("input_ids", {0: "batch", 1: "encoder_sequence"}),
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
("last_hidden_state", {0: "batch", 1: "sequence"}),
("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
]
)

return common_inputs

@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task in ["default", "seq2seq-lm"]:
common_outputs = super().outputs
else:
common_outputs = super(OnnxConfigWithPast, self).outputs
if self.use_past:
num_encoder_layers, _ = self.num_layers
for i in range(num_encoder_layers):
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
return common_outputs

def generate_dummy_inputs(
self,
tokenizer: PreTrainedTokenizer,
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:

if self.task in ["default", "seq2seq-lm"]:
encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
)

# Generate decoder inputs
decoder_seq_length = seq_length if not self.use_past else 1
decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, decoder_seq_length, is_pair, framework
)
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
common_inputs = dict(**encoder_inputs, **decoder_inputs)

if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch, encoder_seq_length = common_inputs["input_ids"].shape
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
encoder_shape = (
batch,
num_encoder_attention_heads,
encoder_seq_length,
self._config.hidden_size // num_encoder_attention_heads,
)
decoder_past_length = decoder_seq_length + 3
decoder_shape = (
batch,
num_decoder_attention_heads,
decoder_past_length,
self._config.hidden_size // num_decoder_attention_heads,
)

common_inputs["decoder_attention_mask"] = torch.cat(
[common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
)

common_inputs["past_key_values"] = []
# If the number of encoder and decoder layers are present in the model configuration, both are considered
num_encoder_layers, num_decoder_layers = self.num_layers
min_num_layers = min(num_encoder_layers, num_decoder_layers)
max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"

for _ in range(min_num_layers):
common_inputs["past_key_values"].append(
(
torch.zeros(decoder_shape),
torch.zeros(decoder_shape),
torch.zeros(encoder_shape),
torch.zeros(encoder_shape),
)
)

# TODO: test this.
shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
for _ in range(min_num_layers, max_num_layers):
common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))

elif self.task == "causal-lm":
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
)

if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch

batch, seqlen = common_inputs["input_ids"].shape
# Not using the same length for past_key_values
past_key_values_length = seqlen + 2
num_encoder_layers, _ = self.num_layers
num_encoder_attention_heads, _ = self.num_attention_heads
past_shape = (
batch,
num_encoder_attention_heads,
past_key_values_length,
self._config.hidden_size // num_encoder_attention_heads,
)

common_inputs["attention_mask"] = torch.cat(
[common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
)
common_inputs["past_key_values"] = [
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
]
else:
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
)

return common_inputs

def _flatten_past_key_values_(self, flattened_output, name, idx, t):
if self.task in ["default", "seq2seq-lm"]:
flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
else:
flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
flattened_output, name, idx, t
)
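Taken together, this hunk swaps the `OnnxSeq2SeqConfigWithPast`-based `BartOnnxConfig` (task-aware inputs/outputs and dummy `past_key_values` generation) for the plain `OnnxConfigWithPast` subclass. The part of the removed `generate_dummy_inputs` that is easiest to lose track of is the per-layer 4-tuple of past states: decoder self-attention key/value with decoder-side shapes, plus cross-attention key/value with encoder-side shapes. A standalone sketch of just that shape logic; every dimension below is an illustrative assumption, not a value read from a real `BartConfig`.

```python
# Sketch of the per-layer past_key_values layout built by the removed seq2seq dummy-input code.
# All dimensions are illustrative assumptions, not values taken from an actual checkpoint.
import torch

batch, num_heads, hidden_size, num_layers = 2, 16, 1024, 12
encoder_seq_length = 8                         # length of the encoder input
decoder_seq_length = 1                         # with use_past, the decoder only sees the new token
decoder_past_length = decoder_seq_length + 3   # the removed code deliberately offsets the past length

head_dim = hidden_size // num_heads
decoder_shape = (batch, num_heads, decoder_past_length, head_dim)  # self-attention cache
encoder_shape = (batch, num_heads, encoder_seq_length, head_dim)   # cross-attention cache

past_key_values = [
    (
        torch.zeros(decoder_shape),  # decoder self-attention key
        torch.zeros(decoder_shape),  # decoder self-attention value
        torch.zeros(encoder_shape),  # cross-attention key
        torch.zeros(encoder_shape),  # cross-attention value
    )
    for _ in range(num_layers)
]

# The decoder attention mask must cover both the cached and the newly fed positions.
decoder_attention_mask = torch.ones(batch, decoder_seq_length + decoder_past_length)

print(len(past_key_values), past_key_values[0][0].shape, decoder_attention_mask.shape)
# 12 torch.Size([2, 16, 4, 64]) torch.Size([2, 5])
```

At export time, `fill_with_past_key_values_` flattens such tuples into individually named inputs of the form `past_key_values.{i}.key`/`.value`, which is why the dynamic-axes mappings above point at axis 2 for the `past_sequence + sequence` dimension.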
4 changes: 4 additions & 0 deletions src/transformers/models/bert/configuration_bert.py
@@ -169,3 +169,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
("token_type_ids", {0: "batch", 1: "sequence"}),
]
)

@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
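Declaring `pooler_output` here means the exported graph exposes a second named output that can be fetched directly at inference time. A small `onnxruntime` sketch; the `bert.onnx` path and the random token IDs are placeholders, and it assumes the file was exported with these input/output names (as in the ALBERT sketch above).

```python
# Sketch: request the declared outputs by name from an exported model.
# Assumes "bert.onnx" was exported with output_names=["last_hidden_state", "pooler_output"].
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("bert.onnx", providers=["CPUExecutionProvider"])

batch, seq = 2, 16
feed = {
    "input_ids": np.random.randint(0, 1000, size=(batch, seq), dtype=np.int64),
    "attention_mask": np.ones((batch, seq), dtype=np.int64),
    "token_type_ids": np.zeros((batch, seq), dtype=np.int64),
}

last_hidden_state, pooler_output = session.run(["last_hidden_state", "pooler_output"], feed)
print(last_hidden_state.shape)  # (batch, seq, hidden_size) -- both declared axes are dynamic
print(pooler_output.shape)      # (batch, hidden_size)      -- only the batch axis is dynamic
```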
@@ -142,3 +142,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
("attention_mask", {0: "batch", 1: "sequence"}),
]
)

@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})])
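For a config whose `outputs` lists only `last_hidden_state`, the quickest sanity check after export is to read the output names and symbolic axes back out of the graph itself. A minimal sketch with the `onnx` package; `model.onnx` is a placeholder path standing in for whichever model this hunk belongs to.

```python
# Sketch: inspect the output names and dynamic axes recorded in an exported graph.
# "model.onnx" is a placeholder path, not a file produced by this PR.
import onnx

model = onnx.load("model.onnx")
print([output.name for output in model.graph.output])  # e.g. ["last_hidden_state"]

for output in model.graph.output:
    # dim_param holds the symbolic name ("batch", "sequence"); dim_value holds fixed sizes
    dims = [d.dim_param or d.dim_value for d in output.type.tensor_type.shape.dim]
    print(output.name, dims)
```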
66 changes: 22 additions & 44 deletions src/transformers/models/gpt2/configuration_gpt2.py
@@ -15,12 +15,12 @@
# limitations under the License.
""" OpenAI GPT-2 configuration """
from collections import OrderedDict
from typing import Any, List, Mapping, Optional
from typing import Any, Mapping, Optional

from transformers import PreTrainedTokenizer, TensorType, is_torch_available

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast, PatchingSpec
from ...onnx import OnnxConfigWithPast
from ...utils import logging


@@ -194,36 +194,29 @@ def __init__(


class GPT2OnnxConfig(OnnxConfigWithPast):
def __init__(
self,
config: PretrainedConfig,
task: str = "default",
patching_specs: List[PatchingSpec] = None,
use_past: bool = False,
):
super().__init__(config, task=task, patching_specs=patching_specs)
if not getattr(self._config, "pad_token_id", None):
# TODO: how to do that better?
self._config.pad_token_id = 0

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
common_inputs = OrderedDict({"input_ids": {0: "batch"}})
if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs")
common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
for i in range(self._config.n_layer * 2):
common_inputs[f"past_key_values.{i}"] = {0: "batch", 2: "sequence"}

common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
else:
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}

return common_inputs

@property
def num_layers(self) -> int:
return self._config.n_layer
def outputs(self) -> Mapping[str, Mapping[int, str]]:
common_outputs = OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}})
if self.use_past:
for i in range(self._config.n_layer * 2):
common_outputs[f"present.{i}"] = {0: "batch", 2: "sequence"}

@property
def num_attention_heads(self) -> int:
return self._config.n_head
return common_outputs

return common_outputs

def generate_dummy_inputs(
self,
@@ -233,9 +226,7 @@ def generate_dummy_inputs(
is_pair: bool = False,
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
)
common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)

# We need to order the input in the way they appears in the forward()
ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
@@ -247,27 +238,14 @@ def generate_dummy_inputs(
else:
import torch

batch, seqlen = common_inputs["input_ids"].shape
# Not using the same length for past_key_values
past_key_values_length = seqlen + 2
past_shape = (
batch,
self.num_attention_heads,
past_key_values_length,
self._config.hidden_size // self.num_attention_heads,
)
batch = common_inputs["input_ids"].shape[0]
ordered_inputs["past_key_values"] = [
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
(
torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)),
torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)),
)
for _ in range(self._config.n_layer)
]

ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
if self.use_past:
ordered_inputs["attention_mask"] = torch.cat(
[ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
)

return ordered_inputs

@property
def default_onnx_opset(self) -> int:
return 13
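On the GPT-2 side, the hunk drops the task-aware helpers (the `__init__` that patched `pad_token_id`, plus the `num_layers`/`num_attention_heads` properties) and goes back to reading `n_layer`/`n_head` straight off the config, with the decoder-only past laid out as one `(key, value)` pair per layer of shape `(batch, n_head, past_length, hidden_size // n_head)`. A standalone sketch of that layout; the GPT-2-small sized numbers are assumptions, not values read from an actual `GPT2Config`.

```python
# Sketch of GPT-2-style dummy past_key_values; dimensions are illustrative (GPT-2 small sized),
# not read from an actual GPT2Config.
import torch

batch, seq_len = 2, 8
n_layer, n_head, hidden_size = 12, 12, 768
past_length = seq_len + 2          # the removed variant deliberately uses a different length than seq_len
head_dim = hidden_size // n_head

past_shape = (batch, n_head, past_length, head_dim)
past_key_values = [(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(n_layer)]

# The attention mask must span past + current positions so cached tokens stay visible.
attention_mask = torch.ones(batch, seq_len)
attention_mask = torch.cat([attention_mask, torch.ones(batch, past_length)], dim=1)

print(len(past_key_values), past_key_values[0][0].shape, attention_mask.shape)
# 12 torch.Size([2, 12, 10, 64]) torch.Size([2, 18])
```

The removed variant's choice of `seq_len + 2` for the past length, flagged by the "Not using the same length for past_key_values" comment, presumably makes a mixed-up axis surface as a shape mismatch during export rather than silently matching.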