Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/en/serialization.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Ready-made configurations include the following architectures:
- BigBird-Pegasus
- Blenderbot
- BlenderbotSmall
- BLOOM
- CamemBERT
- CodeGen
- ConvBERT
Expand Down
7 changes: 2 additions & 5 deletions src/transformers/models/bloom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@


_import_structure = {
"configuration_bloom": [
"BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP",
"BloomConfig",
],
"configuration_bloom": ["BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP", "BloomConfig", "BloomOnnxConfig"],
}
try:
if not is_tokenizers_available():
Expand All @@ -51,7 +48,7 @@
]

if TYPE_CHECKING:
from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig
from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig, BloomOnnxConfig

try:
if not is_tokenizers_available():
Expand Down
95 changes: 95 additions & 0 deletions src/transformers/models/bloom/configuration_bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Bloom configuration"""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, List, Mapping, Optional

from transformers import is_torch_available


if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, TensorType

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast, PatchingSpec
from ...utils import logging


Expand Down Expand Up @@ -153,3 +163,88 @@ def __init__(
self.slow_but_exact = slow_but_exact

super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)


class BloomOnnxConfig(OnnxConfigWithPast):
def __init__(
self,
config: PretrainedConfig,
task: str = "default",
patching_specs: List[PatchingSpec] = None,
use_past: bool = False,
):
super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
if not getattr(self._config, "pad_token_id", None):
# TODO: how to do that better?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hehe @michaelbenayoun we should fix this sometime :)

self._config.pad_token_id = 0

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs")
common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
else:
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}

return common_inputs

@property
def num_layers(self) -> int:
return self._config.n_layer

@property
def num_attention_heads(self) -> int:
return self._config.n_head

@property
def atol_for_validation(self) -> float:
return 1e-3

def generate_dummy_inputs(
self,
tokenizer: "PreTrainedTokenizer",
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional["TensorType"] = None,
) -> Mapping[str, Any]:
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)

# We need to order the input in the way they appears in the forward()
ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})

# Need to add the past_keys
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch

batch, seqlen = common_inputs["input_ids"].shape
# Not using the same length for past_key_values
past_key_values_length = seqlen + 2
past_shape = (
batch,
past_key_values_length,
self.num_attention_heads,
self._config.hidden_size // self.num_attention_heads,
)
ordered_inputs["past_key_values"] = [
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
]

ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
if self.use_past:
mask_dtype = ordered_inputs["attention_mask"].dtype
ordered_inputs["attention_mask"] = torch.cat(
[ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
)

return ordered_inputs

@property
def default_onnx_opset(self) -> int:
return 13
22 changes: 8 additions & 14 deletions src/transformers/models/bloom/modeling_bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,14 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=


def attention_mask_func(attention_scores, attention_mask, causal_mask):
if attention_mask.dtype == torch.bool:
attention_mask_bool = ~attention_mask
else:
attention_mask_bool = (1 - attention_mask).bool()
attention_mask_bool = ~attention_mask.bool()

query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
padded_causal_mask = (
attention_mask_bool[:, None, key_length - query_length : key_length, None]
+ ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
).bool()
padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
padded_causal_mask = torch.logical_or(
attention_mask_bool[:, None, key_length - query_length : key_length, None],
~causal_mask[:, :, key_length - query_length : key_length, :key_length].bool(),
)
padded_causal_mask = torch.logical_or(padded_causal_mask, attention_mask_bool[:, None, None, :key_length])
# Make use of floats
return (
attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
Expand Down Expand Up @@ -296,11 +293,8 @@ def forward(self, input, mask, max_positions):
mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)

mask = mask.to(input.device)
causal_mask = (
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
.view(1, 1, max_positions, max_positions)
.to(input.device)
)
seq_ids = torch.arange(max_positions, device=input.device)
causal_mask = (seq_ids[None, :] <= seq_ids[:, None]).view(1, 1, max_positions, max_positions).to(input.device)
mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)

Expand Down
9 changes: 9 additions & 0 deletions src/transformers/onnx/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,15 @@ class FeaturesManager:
"seq2seq-lm-with-past",
onnx_config_cls="models.blenderbot_small.BlenderbotSmallOnnxConfig",
),
"bloom": supported_features_mapping(
"default",
"default-with-past",
"causal-lm",
"causal-lm-with-past",
"sequence-classification",
"token-classification",
onnx_config_cls="models.bloom.BloomOnnxConfig",
),
"camembert": supported_features_mapping(
"default",
"masked-lm",
Expand Down
1 change: 1 addition & 0 deletions tests/onnx/test_onnx_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def test_values_override(self):
}

PYTORCH_EXPORT_WITH_PAST_MODELS = {
("bloom", "bigscience/bloom-350m"),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you checked that the slow tests pass for this checkpoint? You can run:

RUN_SLOW=1 pytest tests/onnx/test_onnx_v2.py -k "bloom"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for reminding me. All tests are passing now 🙂

("gpt2", "gpt2"),
("gpt-neo", "EleutherAI/gpt-neo-125M"),
}
Expand Down