diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx
index 098d9356b0c5..1bf73af7d325 100644
--- a/docs/source/en/serialization.mdx
+++ b/docs/source/en/serialization.mdx
@@ -53,6 +53,7 @@ Ready-made configurations include the following architectures:
 - BigBird-Pegasus
 - Blenderbot
 - BlenderbotSmall
+- BLOOM
 - CamemBERT
 - CodeGen
 - ConvBERT
diff --git a/src/transformers/models/bloom/__init__.py b/src/transformers/models/bloom/__init__.py
index bc509181e0ef..9aea71885883 100644
--- a/src/transformers/models/bloom/__init__.py
+++ b/src/transformers/models/bloom/__init__.py
@@ -22,10 +22,7 @@
 
 
 _import_structure = {
-    "configuration_bloom": [
-        "BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP",
-        "BloomConfig",
-    ],
+    "configuration_bloom": ["BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP", "BloomConfig", "BloomOnnxConfig"],
 }
 try:
     if not is_tokenizers_available():
@@ -51,7 +48,7 @@
     ]
 
 if TYPE_CHECKING:
-    from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig
+    from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig, BloomOnnxConfig
 
     try:
         if not is_tokenizers_available():
diff --git a/src/transformers/models/bloom/configuration_bloom.py b/src/transformers/models/bloom/configuration_bloom.py
index f841d6669965..f7929dee8afa 100644
--- a/src/transformers/models/bloom/configuration_bloom.py
+++ b/src/transformers/models/bloom/configuration_bloom.py
@@ -13,7 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Bloom configuration"""
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, List, Mapping, Optional
+
+from transformers import is_torch_available
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer, TensorType
+
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfigWithPast, PatchingSpec
 from ...utils import logging
 
 
@@ -153,3 +163,88 @@ def __init__(
         self.slow_but_exact = slow_but_exact
 
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+
+class BloomOnnxConfig(OnnxConfigWithPast):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        task: str = "default",
+        patching_specs: List[PatchingSpec] = None,
+        use_past: bool = False,
+    ):
+        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
+        if not getattr(self._config, "pad_token_id", None):
+            # TODO: how to do that better?
+            self._config.pad_token_id = 0
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
+        if self.use_past:
+            self.fill_with_past_key_values_(common_inputs, direction="inputs")
+            common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
+        else:
+            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
+
+        return common_inputs
+
+    @property
+    def num_layers(self) -> int:
+        return self._config.n_layer
+
+    @property
+    def num_attention_heads(self) -> int:
+        return self._config.n_head
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-3
+
+    def generate_dummy_inputs(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional["TensorType"] = None,
+    ) -> Mapping[str, Any]:
+        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
+            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+        )
+
+        # We need to order the inputs in the way they appear in the forward()
+        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
+
+        # Need to add the past_keys
+        if self.use_past:
+            if not is_torch_available():
+                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+            else:
+                import torch
+
+                batch, seqlen = common_inputs["input_ids"].shape
+                # Not using the same length for past_key_values
+                past_key_values_length = seqlen + 2
+                past_shape = (
+                    batch,
+                    past_key_values_length,
+                    self.num_attention_heads,
+                    self._config.hidden_size // self.num_attention_heads,
+                )
+                ordered_inputs["past_key_values"] = [
+                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+                ]
+
+        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
+        if self.use_past:
+            mask_dtype = ordered_inputs["attention_mask"].dtype
+            ordered_inputs["attention_mask"] = torch.cat(
+                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
+            )
+
+        return ordered_inputs
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 13
diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index ba8edde1493f..9f520ae473bf 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -78,17 +78,14 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=
 
 
 def attention_mask_func(attention_scores, attention_mask, causal_mask):
-    if attention_mask.dtype == torch.bool:
-        attention_mask_bool = ~attention_mask
-    else:
-        attention_mask_bool = (1 - attention_mask).bool()
+    attention_mask_bool = ~attention_mask.bool()
 
     query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
-    padded_causal_mask = (
-        attention_mask_bool[:, None, key_length - query_length : key_length, None]
-        + ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
-    ).bool()
-    padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
+    padded_causal_mask = torch.logical_or(
+        attention_mask_bool[:, None, key_length - query_length : key_length, None],
+        ~causal_mask[:, :, key_length - query_length : key_length, :key_length].bool(),
+    )
+    padded_causal_mask = torch.logical_or(padded_causal_mask, attention_mask_bool[:, None, None, :key_length])
     # Make use of floats
     return (
         attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
@@ -296,11 +293,8 @@ def forward(self, input, mask, max_positions):
             mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
         mask = mask.to(input.device)
 
-        causal_mask = (
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
-            .view(1, 1, max_positions, max_positions)
-            .to(input.device)
-        )
+        seq_ids = torch.arange(max_positions, device=input.device)
+        causal_mask = (seq_ids[None, :] <= seq_ids[:, None]).view(1, 1, max_positions, max_positions).to(input.device)
         mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
         probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
 
diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py
index 5f9eba893857..e63668ead9bb 100644
--- a/src/transformers/onnx/features.py
+++ b/src/transformers/onnx/features.py
@@ -182,6 +182,15 @@ class FeaturesManager:
             "seq2seq-lm-with-past",
             onnx_config_cls="models.blenderbot_small.BlenderbotSmallOnnxConfig",
         ),
+        "bloom": supported_features_mapping(
+            "default",
+            "default-with-past",
+            "causal-lm",
+            "causal-lm-with-past",
+            "sequence-classification",
+            "token-classification",
+            onnx_config_cls="models.bloom.BloomOnnxConfig",
+        ),
         "camembert": supported_features_mapping(
             "default",
             "masked-lm",
diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py
index 50601598f5aa..4836be5ad11e 100644
--- a/tests/onnx/test_onnx_v2.py
+++ b/tests/onnx/test_onnx_v2.py
@@ -204,6 +204,7 @@ def test_values_override(self):
 }
 
 PYTORCH_EXPORT_WITH_PAST_MODELS = {
+    ("bloom", "bigscience/bloom-350m"),
     ("gpt2", "gpt2"),
     ("gpt-neo", "EleutherAI/gpt-neo-125M"),
 }
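With the `FeaturesManager` entry and `BloomOnnxConfig` in place, BLOOM follows the same export path as the other decoder-only models. Below is a minimal sketch of inspecting the new ONNX config from Python; it assumes a transformers build that includes this patch and reuses the `bigscience/bloom-350m` checkpoint from the test above.

```python
from transformers import AutoConfig, AutoTokenizer, TensorType
from transformers.models.bloom import BloomOnnxConfig

checkpoint = "bigscience/bloom-350m"  # same checkpoint as in PYTORCH_EXPORT_WITH_PAST_MODELS

config = AutoConfig.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# task="causal-lm" with use_past=True corresponds to the "causal-lm-with-past" feature.
onnx_config = BloomOnnxConfig(config, task="causal-lm", use_past=True)

print(onnx_config.inputs)              # input_ids, past_key_values.*, attention_mask with their dynamic axes
print(onnx_config.num_layers)          # read from config.n_layer
print(onnx_config.default_onnx_opset)  # 13

# Dummy inputs in the order BloomModel.forward() expects, including zero-filled past keys/values.
dummy_inputs = onnx_config.generate_dummy_inputs(
    tokenizer, batch_size=2, seq_length=8, framework=TensorType.PYTORCH
)
print(list(dummy_inputs.keys()))
```

The export itself can then go through the usual entry point, e.g. `python -m transformers.onnx --model=bigscience/bloom-350m --feature=causal-lm-with-past onnx/bloom/`.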
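One note on the `modeling_bloom.py` changes: the comparison-based causal mask is functionally identical to the previous `torch.tril` construction; it avoids materializing a dense matrix of ones first, which is presumably friendlier to ONNX tracing. A quick standalone check of that equivalence (illustrative only, not part of the patch):

```python
import torch

max_positions = 6

# Previous construction: lower-triangular matrix of ones, cast to bool.
old_mask = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
    1, 1, max_positions, max_positions
)

# New construction: pairwise comparison of position ids, as in the patched forward().
seq_ids = torch.arange(max_positions)
new_mask = (seq_ids[None, :] <= seq_ids[:, None]).view(1, 1, max_positions, max_positions)

# Both masks let position i attend to positions j <= i.
assert torch.equal(old_mask, new_mask)
```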