diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx
index b4e8d5154a12..7efc908ae0ea 100644
--- a/docs/source/en/index.mdx
+++ b/docs/source/en/index.mdx
@@ -239,7 +239,7 @@ Flax), PyTorch, and/or TensorFlow.
| Nystromformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| OpenAI GPT | ✅ | ✅ | ✅ | ✅ | ❌ |
| OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | ✅ |
-| OPT | ❌ | ❌ | ✅ | ❌ | ❌ |
+| OPT | ❌ | ❌ | ✅ | ✅ | ✅ |
| Pegasus | ✅ | ✅ | ✅ | ✅ | ✅ |
| Perceiver | ✅ | ❌ | ✅ | ❌ | ❌ |
| PLBart | ✅ | ❌ | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/opt.mdx b/docs/source/en/model_doc/opt.mdx
index 5ce9a58c00a0..04344df56dc7 100644
--- a/docs/source/en/model_doc/opt.mdx
+++ b/docs/source/en/model_doc/opt.mdx
@@ -39,9 +39,28 @@ The original code can be found [here](https://github.com/facebookresearch/metase
[[autodoc]] OPTModel
- forward
-
## OPTForCausalLM
[[autodoc]] OPTForCausalLM
- forward
+## TFOPTModel
+
+[[autodoc]] TFOPTModel
+ - call
+
+## TFOPTForCausalLM
+
+[[autodoc]] TFOPTForCausalLM
+ - call
+
+## FlaxOPTModel
+
+[[autodoc]] FlaxOPTModel
+ - __call__
+
+
+## FlaxOPTForCausalLM
+
+[[autodoc]] FlaxOPTForCausalLM
+ - __call__
\ No newline at end of file
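A minimal usage sketch for the TF and Flax heads documented above. This is not part of the patch; it assumes the `facebook/opt-350m` checkpoint is loaded with the library's standard `from_pt=True` conversion (native TF/Flax weights may not be hosted) and uses the generic `generate` utilities from the respective framework mixins:

```python
# Sketch only: exercises TFOPTForCausalLM and FlaxOPTForCausalLM end to end.
from transformers import FlaxOPTForCausalLM, GPT2Tokenizer, TFOPTForCausalLM

tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")

# TensorFlow head
tf_inputs = tokenizer("Hello, my dog is", return_tensors="tf")
tf_model = TFOPTForCausalLM.from_pretrained("facebook/opt-350m", from_pt=True)
tf_ids = tf_model.generate(tf_inputs["input_ids"], max_length=16)
print(tokenizer.decode(tf_ids[0], skip_special_tokens=True))

# Flax head
np_inputs = tokenizer("Hello, my dog is", return_tensors="np")
flax_model = FlaxOPTForCausalLM.from_pretrained("facebook/opt-350m", from_pt=True)
flax_ids = flax_model.generate(np_inputs["input_ids"], max_length=16).sequences
print(tokenizer.decode(flax_ids[0], skip_special_tokens=True))
```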
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 0afe8588d658..a17253b5e017 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -2213,6 +2213,13 @@
"TFOpenAIGPTPreTrainedModel",
]
)
+ _import_structure["models.opt"].extend(
+ [
+ "TFOPTForCausalLM",
+ "TFOPTModel",
+ "TFOPTPreTrainedModel",
+ ]
+ )
_import_structure["models.pegasus"].extend(
["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"]
)
@@ -2560,6 +2567,13 @@
]
)
_import_structure["models.mt5"].extend(["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
+ _import_structure["models.opt"].extend(
+ [
+ "FlaxOPTForCausalLM",
+ "FlaxOPTModel",
+ "FlaxOPTPreTrainedModel",
+ ]
+ )
_import_structure["models.pegasus"].extend(
[
"FlaxPegasusForConditionalGeneration",
@@ -4448,6 +4462,7 @@
TFOpenAIGPTModel,
TFOpenAIGPTPreTrainedModel,
)
+ from .models.opt import TFOPTForCausalLM, TFOPTModel, TFOPTPreTrainedModel
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .models.rembert import (
@@ -4717,6 +4732,7 @@
FlaxMBartPreTrainedModel,
)
from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
+ from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel
from .models.roberta import (
FlaxRobertaForCausalLM,
diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py
index 230ffcbce349..5eda5a74162c 100644
--- a/src/transformers/models/auto/modeling_flax_auto.py
+++ b/src/transformers/models/auto/modeling_flax_auto.py
@@ -44,6 +44,7 @@
("marian", "FlaxMarianModel"),
("mbart", "FlaxMBartModel"),
("mt5", "FlaxMT5Model"),
+ ("opt", "FlaxOPTModel"),
("pegasus", "FlaxPegasusModel"),
("roberta", "FlaxRobertaModel"),
("roformer", "FlaxRoFormerModel"),
@@ -129,6 +130,7 @@
("gpt2", "FlaxGPT2LMHeadModel"),
("gpt_neo", "FlaxGPTNeoForCausalLM"),
("gptj", "FlaxGPTJForCausalLM"),
+ ("opt", "FlaxOPTForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("xglm", "FlaxXGLMForCausalLM"),
]
diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py
index 9346252f8ec0..716fd2575bfe 100644
--- a/src/transformers/models/auto/modeling_tf_auto.py
+++ b/src/transformers/models/auto/modeling_tf_auto.py
@@ -60,6 +60,7 @@
("mpnet", "TFMPNetModel"),
("mt5", "TFMT5Model"),
("openai-gpt", "TFOpenAIGPTModel"),
+ ("opt", "TFOPTModel"),
("pegasus", "TFPegasusModel"),
("rembert", "TFRemBertModel"),
("roberta", "TFRobertaModel"),
@@ -151,6 +152,7 @@
("gpt2", "TFGPT2LMHeadModel"),
("gptj", "TFGPTJForCausalLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
+ ("opt", "TFOPTForCausalLM"),
("rembert", "TFRemBertForCausalLM"),
("roberta", "TFRobertaForCausalLM"),
("roformer", "TFRoFormerForCausalLM"),
diff --git a/src/transformers/models/opt/__init__.py b/src/transformers/models/opt/__init__.py
index 303c6fd4536b..e35d07d1b012 100644
--- a/src/transformers/models/opt/__init__.py
+++ b/src/transformers/models/opt/__init__.py
@@ -17,13 +17,24 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_flax_available,
+ is_tf_available,
+ is_tokenizers_available,
+ is_torch_available,
+)
_import_structure = {"configuration_opt": ["OPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "OPTConfig"]}
-
-if is_torch_available():
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
_import_structure["modeling_opt"] = [
"OPT_PRETRAINED_MODEL_ARCHIVE_LIST",
"OPTForCausalLM",
@@ -31,13 +42,54 @@
"OPTPreTrainedModel",
]
+try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_tf_opt"] = ["TFOPTForCausalLM", "TFOPTModel", "TFOPTPreTrainedModel"]
+
+try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_flax_opt"] = [
+ "FlaxOPTForCausalLM",
+ "FlaxOPTModel",
+ "FlaxOPTPreTrainedModel",
+ ]
+
if TYPE_CHECKING:
from .configuration_opt import OPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OPTConfig
- if is_torch_available():
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
from .modeling_opt import OPT_PRETRAINED_MODEL_ARCHIVE_LIST, OPTForCausalLM, OPTModel, OPTPreTrainedModel
+ try:
+ if not is_tf_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_tf_opt import TFOPTForCausalLM, TFOPTModel, TFOPTPreTrainedModel
+
+ try:
+ if not is_flax_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_flax_opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
+
else:
import sys
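The `__init__.py` changes follow the library's lazy-import pattern: the new class names are registered in `_import_structure`, and the `OptionalDependencyNotAvailable` guards simply skip registration when TensorFlow or Flax is not installed, so `import transformers` keeps working with any subset of backends. A hedged sketch of how calling code can mirror the same availability checks using the public helpers:

```python
# Sketch: feature-detect optional backends the same way the guards above do.
from transformers.utils import is_flax_available, is_tf_available

if is_tf_available():
    from transformers import TFOPTForCausalLM  # exposed via _import_structure
if is_flax_available():
    from transformers import FlaxOPTForCausalLM

print("TF backend:", is_tf_available(), "| Flax backend:", is_flax_available())
```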
diff --git a/src/transformers/models/opt/modeling_flax_opt.py b/src/transformers/models/opt/modeling_flax_opt.py
new file mode 100644
index 000000000000..f84d56b0d8b1
--- /dev/null
+++ b/src/transformers/models/opt/modeling_flax_opt.py
@@ -0,0 +1,795 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Flax OPT model."""
+
+from functools import partial
+from typing import Optional, Tuple
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+from jax.random import PRNGKey
+
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxMaskedLMOutput
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
+from ...utils import add_start_docstrings, logging
+from .configuration_opt import OPTConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
+_CONFIG_FOR_DOC = "OPTConfig"
+_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
+
+
+OPT_START_DOCSTRING = r"""
+    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+    etc.)
+
+    This model is also a Flax Linen
+    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+    regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
+
+ Finally, this model supports inherent JAX features such as:
+
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ config ([`OPTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+ specified all the computation will be performed with the given `dtype`.
+
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+ parameters.**
+
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+ [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+OPT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->OPT
+class FlaxOPTAttention(nn.Module):
+ config: OPTConfig
+ embed_dim: int
+ num_heads: int
+ dropout: float = 0.0
+ causal: bool = False
+ bias: bool = True
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self) -> None:
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ dense = partial(
+ nn.Dense,
+ self.embed_dim,
+ use_bias=self.bias,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
+ self.out_proj = dense()
+
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
+
+ if self.causal:
+ self.causal_mask = make_causal_mask(
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
+ )
+
+ def _split_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
+
+ def _merge_heads(self, hidden_states):
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
+
+ @nn.compact
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
+ """
+ This function takes projected key, value states from a single input token and concatenates the states to cached
+        states from previous steps. This function is slightly adapted from the official Flax repository:
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
+ """
+ # detect if we're initializing by absence of existing cache data.
+ is_initialized = self.has_variable("cache", "cached_key")
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
+
+ if is_initialized:
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
+ # update key, value caches with our new 1d spatial slices
+ cur_index = cache_index.value
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
+ cached_key.value = key
+ cached_value.value = value
+ num_updated_cache_vectors = query.shape[1]
+ cache_index.value = cache_index.value + num_updated_cache_vectors
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
+ pad_mask = jnp.broadcast_to(
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
+ )
+ attention_mask = combine_masks(pad_mask, attention_mask)
+ return key, value, attention_mask
+
+ def __call__(
+ self,
+ hidden_states: jnp.ndarray,
+ key_value_states: Optional[jnp.ndarray] = None,
+ attention_mask: Optional[jnp.ndarray] = None,
+ init_cache: bool = False,
+ deterministic: bool = True,
+ ) -> Tuple[jnp.ndarray]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ batch_size = hidden_states.shape[0]
+
+ # get query proj
+ query_states = self.q_proj(hidden_states)
+ # get key, value proj
+ if is_cross_attention:
+ # cross_attentions
+ key_states = self.k_proj(key_value_states)
+ value_states = self.v_proj(key_value_states)
+ else:
+ # self_attention
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = self._split_heads(query_states)
+ key_states = self._split_heads(key_states)
+ value_states = self._split_heads(value_states)
+
+ # handle cache prepare causal attention mask
+ if self.causal:
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
+ if self.has_variable("cache", "cached_key"):
+ mask_shift = self.variables["cache"]["cache_index"]
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
+ causal_mask = lax.dynamic_slice(
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
+ )
+ else:
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+
+ # combine masks if needed
+ if attention_mask is not None and self.causal:
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
+ attention_mask = combine_masks(attention_mask, causal_mask)
+ elif self.causal:
+ attention_mask = causal_mask
+ elif attention_mask is not None:
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
+ key_states, value_states, query_states, attention_mask
+ )
+
+ # Convert the boolean attention mask to an attention bias.
+ if attention_mask is not None:
+ # attention mask in the form of attention bias
+ attention_bias = lax.select(
+ attention_mask > 0,
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+ jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+ )
+ else:
+ attention_bias = None
+
+ dropout_rng = None
+ if not deterministic and self.dropout > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ attn_weights = dot_product_attention_weights(
+ query_states,
+ key_states,
+ bias=attention_bias,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.dropout,
+ broadcast_dropout=True,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ precision=None,
+ )
+
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+ attn_output = self._merge_heads(attn_output)
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class FlaxOPTDecoderLayer(nn.Module):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self) -> None:
+ self.embed_dim = self.config.hidden_size
+ self.self_attn = FlaxOPTAttention(
+ config=self.config,
+ embed_dim=self.embed_dim,
+ num_heads=self.config.num_attention_heads,
+ dropout=self.config.attention_dropout,
+ causal=True,
+ dtype=self.dtype,
+ )
+ self.do_layer_norm_before = self.config.do_layer_norm_before
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+ self.activation_fn = ACT2FN[self.config.activation_function]
+
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+ self.fc1 = nn.Dense(
+ self.config.ffn_dim,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+ self.fc2 = nn.Dense(
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
+ )
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
+
+ def __call__(
+ self,
+ hidden_states: jnp.ndarray,
+ attention_mask: jnp.ndarray,
+ init_cache: bool = False,
+ output_attentions: bool = True,
+ deterministic: bool = True,
+ ) -> Tuple[jnp.ndarray]:
+
+ residual = hidden_states
+
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ init_cache=init_cache,
+ deterministic=deterministic,
+ )
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+ hidden_states = residual + hidden_states
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Fully Connected
+ hidden_states_shape = hidden_states.shape
+ hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
+ residual = hidden_states
+
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+
+ hidden_states = (residual + hidden_states).reshape(hidden_states_shape)
+
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+class FlaxOPTDecoderLayerCollection(nn.Module):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.layers = [
+ FlaxOPTDecoderLayer(self.config, name=str(i), dtype=self.dtype)
+ for i in range(self.config.num_hidden_layers)
+ ]
+ self.layerdrop = self.config.layerdrop
+
+ def __call__(
+ self,
+ hidden_states,
+ attention_mask,
+ deterministic: bool = True,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ ):
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ deterministic=deterministic,
+ )
+
+ hidden_states = layer_outputs[0]
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ outputs = [hidden_states, all_hidden_states, all_self_attns]
+ return outputs
+
+
+class FlaxOPTLearnedPositionalEmbedding(nn.Embed):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def setup(self):
+ self.offset = 2
+ self.embedding = self.param(
+ "embedding", self.embedding_init, (self.num_embeddings + self.offset, self.features), self.param_dtype
+ )
+
+ def __call__(self, positions):
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
+
+ return super().__call__(positions + self.offset)
+
+
+class FlaxOPTDecoder(nn.Module):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ offset: int = 2
+
+ def setup(self):
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
+
+ embed_dim = self.config.hidden_size
+ self.padding_idx = self.config.pad_token_id
+ self.max_target_positions = self.config.max_position_embeddings
+
+ self.embed_tokens = nn.Embed(
+ self.config.vocab_size,
+ self.config.word_embed_proj_dim,
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+
+ self.embed_positions = FlaxOPTLearnedPositionalEmbedding(
+ self.config.max_position_embeddings,
+ embed_dim,
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+
+ if self.config.word_embed_proj_dim != self.config.hidden_size:
+ self.project_in = nn.Dense(self.config.hidden_size, use_bias=False)
+ self.project_out = nn.Dense(self.config.word_embed_proj_dim, use_bias=False)
+
+ else:
+ self.project_in = None
+ self.project_out = None
+
+ self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype)
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ ):
+ input_shape = input_ids.shape
+ input_ids = input_ids.reshape(-1, input_shape[-1])
+
+ inputs_embeds = self.embed_tokens(input_ids)
+ if self.project_in is not None:
+ inputs_embeds = self.project_in(inputs_embeds)
+
+ positions = self.embed_positions(position_ids)
+
+ hidden_states = inputs_embeds + positions
+
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
+
+ hidden_state, all_hidden_states, attentions = self.layers(
+ hidden_states,
+ attention_mask,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ if self.project_out is not None:
+ hidden_state = self.project_out(hidden_state)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_state,)
+
+ outputs = [hidden_state, all_hidden_states, attentions]
+
+ if not return_dict:
+ return tuple(v for v in outputs if v is not None)
+
+ return FlaxBaseModelOutput(
+ last_hidden_state=hidden_state,
+ hidden_states=all_hidden_states,
+ attentions=attentions,
+ )
+
+
+class FlaxOPTPreTrainedModel(FlaxPreTrainedModel):
+ config_class = OPTConfig
+ base_model_prefix: str = "model"
+ module_class: nn.Module = None
+
+ def __init__(
+ self,
+ config: OPTConfig,
+ input_shape: Tuple[int] = (1, 1),
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs
+ ):
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
+ # init input tensors
+ input_ids = jnp.zeros(input_shape, dtype="i4")
+ attention_mask = jnp.ones_like(input_ids)
+
+ batch_size, sequence_length = input_ids.shape
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ module_init_outputs = self.module.init(
+ rngs,
+ input_ids,
+ attention_mask,
+ position_ids,
+ return_dict=False,
+ )
+
+ random_params = module_init_outputs["params"]
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ def init_cache(self, batch_size, max_length):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ """
+ # init input variables to retrieve cache
+ input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+ attention_mask = jnp.ones_like(input_ids, dtype="i4")
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
+ )
+ return unfreeze(init_variables["cache"])
+
+ def __call__(
+ self,
+ input_ids: jnp.ndarray,
+ attention_mask: Optional[jnp.ndarray] = None,
+ position_ids: Optional[jnp.ndarray] = None,
+ params: dict = None,
+ past_key_values: dict = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ dropout_rng: PRNGKey = None,
+ deterministic: bool = True,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+
+ if position_ids is None:
+ position_ids = (attention_mask.cumsum(axis=1) * attention_mask) - 1
+
+ # Handle any PRNG if needed
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+ inputs = {"params": params or self.params}
+
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
+ # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
+ # changed by FlaxOPTAttention module
+ if past_key_values:
+ inputs["cache"] = past_key_values
+ mutable = ["cache"]
+ else:
+ mutable = False
+
+ outputs = self.module.apply(
+ inputs,
+ input_ids=jnp.array(input_ids, dtype="i4"),
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ rngs=rngs,
+ mutable=mutable,
+ )
+
+ # add updated cache to model output
+ if past_key_values is not None and return_dict:
+ outputs, past_key_values = outputs
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
+ return outputs
+ elif past_key_values is not None and not return_dict:
+ outputs, past_key_values = outputs
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
+
+ return outputs
+
+
+class FlaxOPTModule(nn.Module):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.decoder = FlaxOPTDecoder(self.config, dtype=self.dtype)
+
+ def _get_decoder_module(self):
+ return self.decoder
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ init_cache=False,
+ ):
+
+ decoder_outputs = self.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ init_cache=init_cache,
+ )
+
+ if not return_dict:
+ return decoder_outputs
+
+ return FlaxBaseModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ hidden_states=decoder_outputs.hidden_states,
+ attentions=decoder_outputs.attentions,
+ )
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModel with Bart->OPT
+class FlaxOPTModel(FlaxOPTPreTrainedModel):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+ module_class = FlaxOPTModule
+
+
+append_call_sample_docstring(
+ FlaxOPTModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC
+)
+
+
+@add_start_docstrings(
+ "The bare OPT Model transformer outputting raw hidden-states without any specific head on top.",
+ OPT_START_DOCSTRING,
+)
+class FlaxOPTForCausalLMModule(nn.Module):
+ config: OPTConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ self.model = FlaxOPTModule(config=self.config, dtype=self.dtype)
+ self.lm_head = nn.Dense(
+ self.config.vocab_size,
+ use_bias=False,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
+ )
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ position_ids,
+ init_cache: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ ):
+
+ outputs = self.model(
+ input_ids,
+ attention_mask,
+ position_ids,
+ init_cache=init_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ hidden_states = outputs[0]
+
+ if self.config.tie_word_embeddings:
+ shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+ else:
+ lm_logits = self.lm_head(hidden_states)
+
+ if not return_dict:
+ return (lm_logits,) + outputs[1:]
+
+ return FlaxMaskedLMOutput(
+ logits=lm_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ OPT Model with a language modeling head on top (linear layer with weights tied to the input embeddings) e.g for
+ autoregressive tasks.
+ """,
+ OPT_START_DOCSTRING,
+)
+class FlaxOPTForCausalLM(FlaxOPTPreTrainedModel):
+ module_class = FlaxOPTForCausalLMModule
+
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+ # initializing the cache
+ batch_size, seq_length = input_ids.shape
+
+ past_key_values = self.init_cache(batch_size, max_length)
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+ # But since the decoder uses a causal mask, those positions are masked anyway.
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+
+ if attention_mask is not None:
+ position_ids = attention_mask.cumsum(axis=1) - 1
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
+ else:
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+ return {
+ "past_key_values": past_key_values,
+ "attention_mask": extended_attention_mask,
+ "position_ids": position_ids,
+ }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
+ return model_kwargs
+
+
+append_call_sample_docstring(
+ FlaxOPTForCausalLM,
+ _TOKENIZER_FOR_DOC,
+ _CHECKPOINT_FOR_DOC,
+ FlaxBaseModelOutput,
+ _CONFIG_FOR_DOC,
+)
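`FlaxOPTForCausalLM` plugs into the Flax generation utilities through `init_cache`, `prepare_inputs_for_generation`, and `update_inputs_for_generation` above: the key/value cache is pre-allocated to `max_length`, a static attention mask is built once, and `position_ids` advance by one per step. A hedged sketch of a single manual decoding step that exercises this path directly (in practice `model.generate`, from the library's generation mixin, drives it; `from_pt=True` is assumed for weight conversion):

```python
# Sketch: one incremental decoding step using the cache machinery defined above.
import jax.numpy as jnp
from transformers import FlaxOPTForCausalLM, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
model = FlaxOPTForCausalLM.from_pretrained("facebook/opt-350m", from_pt=True)

inputs = tokenizer("Hello, my dog is", return_tensors="np")
input_ids = jnp.asarray(inputs["input_ids"], dtype="i4")
attention_mask = jnp.asarray(inputs["attention_mask"], dtype="i4")
max_length = input_ids.shape[-1] + 8  # leave room for a few generated tokens

# Pre-allocate the cache and the static attention mask, as generate() would.
model_kwargs = model.prepare_inputs_for_generation(input_ids, max_length, attention_mask)
outputs = model(input_ids, **model_kwargs)

next_token = jnp.argmax(outputs.logits[:, -1, :], axis=-1)
print(tokenizer.decode(next_token))
# update_inputs_for_generation would then shift position_ids and reuse the cache.
```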
diff --git a/src/transformers/models/opt/modeling_tf_opt.py b/src/transformers/models/opt/modeling_tf_opt.py
new file mode 100644
index 000000000000..dda422a4362a
--- /dev/null
+++ b/src/transformers/models/opt/modeling_tf_opt.py
@@ -0,0 +1,1048 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" TF 2.0 OPT model."""
+
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast
+
+# Public API
+from ...modeling_tf_utils import (
+ DUMMY_INPUTS,
+ TFCausalLanguageModelingLoss,
+ TFModelInputType,
+ TFPreTrainedModel,
+ TFSharedEmbeddings,
+ TFWrappedEmbeddings,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import shape_list, stable_softmax
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_opt import OPTConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
+_CONFIG_FOR_DOC = "OPTConfig"
+_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
+
+# Base model docstring
+_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
+
+LARGE_NEGATIVE = -1e8
+
+
+def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
+ """
+    Make causal mask used for uni-directional (decoder) self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
+ mask_cond = tf.range(shape_list(mask)[-1])
+
+ mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
+
+ if past_key_values_length > 0:
+ mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
+
+ return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
+
+
+def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ src_len = shape_list(mask)[1]
+ tgt_len = tgt_len if tgt_len is not None else src_len
+ one_cst = tf.constant(1.0)
+ mask = tf.cast(mask, dtype=one_cst.dtype)
+ expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
+
+ return (one_cst - expanded_mask) * LARGE_NEGATIVE
+
+
+class TFOPTLearnedPositionalEmbedding(TFSharedEmbeddings):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
+ # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
+ # and adjust num_embeddings appropriately. Other models don't have this hack
+ self.offset = 2
+ super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs)
+
+ def call(self, attention_mask, past_key_values_length: int = 0):
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
+ attention_mask = tf.cast(attention_mask, tf.int64)
+
+ # create positions depending on attention_mask
+ positions = tf.math.cumsum(attention_mask, axis=1) * attention_mask - 1
+
+ # cut positions if `past_key_values_length` is > 0
+ positions = positions[:, past_key_values_length:]
+
+ return super().call(positions + self.offset)
+
+
+# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->OPT
+class TFOPTAttention(tf.keras.layers.Layer):
+ """Multi-headed attention from "Attention Is All You Need"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = tf.keras.layers.Dropout(dropout)
+ self.head_dim = embed_dim // num_heads
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+
+ self.k_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
+ self.q_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
+ self.v_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
+ self.out_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
+
+ def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
+ return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ key_value_states: Optional[tf.Tensor] = None,
+ past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
+ attention_mask: Optional[tf.Tensor] = None,
+ layer_head_mask: Optional[tf.Tensor] = None,
+ training: Optional[bool] = False,
+ ) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ bsz, tgt_len, embed_dim = shape_list(hidden_states)
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = tf.concat([past_key_value[0], key_states], axis=2)
+ value_states = tf.concat([past_key_value[1], value_states], axis=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
+ key_states = tf.reshape(key_states, proj_shape)
+ value_states = tf.reshape(value_states, proj_shape)
+
+ src_len = shape_list(key_states)[1]
+ attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
+
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+ if tf.executing_eagerly():
+ tf.debugging.assert_equal(
+ shape_list(attn_weights),
+ [bsz * self.num_heads, tgt_len, src_len],
+ message=(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {shape_list(attn_weights)}"
+ ),
+ )
+
+ if attention_mask is not None:
+            # The tf.debugging asserts are not compliant with XLA, so they
+            # have to be disabled in modes other than eager.
+ if tf.executing_eagerly():
+ tf.debugging.assert_equal(
+ shape_list(attention_mask),
+ [bsz, 1, tgt_len, src_len],
+ message=(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {shape_list(attention_mask)}"
+ ),
+ )
+
+ attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
+ attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
+ attn_weights = stable_softmax(attn_weights, axis=-1)
+
+ if layer_head_mask is not None:
+            # The tf.debugging asserts are not compliant with XLA, so they
+            # have to be disabled in modes other than eager.
+ if tf.executing_eagerly():
+ tf.debugging.assert_equal(
+ shape_list(layer_head_mask),
+ [self.num_heads],
+ message=(
+ f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
+ f" {shape_list(layer_head_mask)}"
+ ),
+ )
+
+ attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
+ attn_weights, (bsz, self.num_heads, tgt_len, src_len)
+ )
+ attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
+
+ attn_probs = self.dropout(attn_weights, training=training)
+ attn_output = tf.matmul(attn_probs, value_states)
+
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+ if tf.executing_eagerly():
+ tf.debugging.assert_equal(
+ shape_list(attn_output),
+ [bsz * self.num_heads, tgt_len, self.head_dim],
+ message=(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {shape_list(attn_output)}"
+ ),
+ )
+
+ attn_output = tf.transpose(
+ tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
+ )
+ attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))
+
+ attn_output = self.out_proj(attn_output)
+ attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
+
+ return attn_output, attn_weights, past_key_value
+
+
+class TFOPTDecoderLayer(tf.keras.layers.Layer):
+ def __init__(self, config: OPTConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.do_layer_norm_before = config.do_layer_norm_before
+ self.embed_dim = config.hidden_size
+ self.self_attn = TFOPTAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_dropout,
+ name="self_attn",
+ is_decoder=True,
+ )
+ self.dropout = tf.keras.layers.Dropout(config.dropout)
+ self.activation_fn = get_tf_activation(config.activation_function)
+
+ self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
+ self.fc1 = tf.keras.layers.Dense(config.ffn_dim, name="fc1")
+ self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
+ self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ layer_head_mask: Optional[tf.Tensor] = None,
+ past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ training: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
+ """
+ Args:
+            hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`tf.Tensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`tf.Tensor`, *optional*): mask for attention heads in a given layer of size
+ `(decoder_attention_heads,)`
+ past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+ """
+ residual = hidden_states
+
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=self_attn_past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ )
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = residual + hidden_states
+
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Fully Connected
+ residual = hidden_states
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout(hidden_states, training=training)
+ hidden_states = residual + hidden_states
+
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ return (hidden_states, self_attn_weights, present_key_value)
+
+
+OPT_START_DOCSTRING = r"""
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+    etc.)
+
+    This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+    behavior.
+
+
+
+    TF 2.0 models accept two formats as inputs:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
+
+    This second option is useful when using the [`tf.keras.Model.fit`] method, which currently requires having all the
+    tensors in the first argument of the model call function: `model(inputs)`.
+
+    If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the
+    first positional argument:
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+
+
+ Args:
+ config ([`OPTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare OPT Model outputting raw hidden-states without any specific head on top.",
+ OPT_START_DOCSTRING,
+)
+class TFOPTPreTrainedModel(TFPreTrainedModel):
+ """
+    TFOPT Pretrained Model that inherits from transformers.TFPreTrainedModel
+
+ Args:
+ config: OPTConfig
+ """
+
+ config_class = OPTConfig
+ base_model_prefix = "model"
+
+ @property
+ def dummy_inputs(self):
+ pad_token = 1
+ input_ids = tf.cast(tf.convert_to_tensor(DUMMY_INPUTS), tf.int32)
+ dummy_inputs = {
+ "attention_mask": tf.math.not_equal(input_ids, pad_token),
+ "input_ids": input_ids,
+ }
+ return dummy_inputs
+
+ @tf.function(
+ input_signature=[
+ {
+ "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
+ "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
+ }
+ ]
+ )
+ def serving(self, inputs):
+ output = self.call(inputs)
+
+ return self.serving_output(output)
+
+
+OPT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`tf.Tensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`). Set to `False` during training, `True` during generation.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@keras_serializable
+class TFOPTDecoder(tf.keras.layers.Layer):
+ config_class = OPTConfig
+
+ def __init__(self, config: OPTConfig, load_weight_prefix=None, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.padding_idx = config.pad_token_id
+ self.layerdrop = config.layerdrop
+ num_embeddings = config.max_position_embeddings
+
+ self.shared = TFSharedEmbeddings(
+ config.vocab_size, config.word_embed_proj_dim, config.pad_token_id, name="model.decoder.embed_tokens"
+ )
+
+ self.embed_positions = TFOPTLearnedPositionalEmbedding(
+ num_embeddings,
+ config.hidden_size,
+ name="embed_positions",
+ )
+
+ # set tf scope correctly
+ if load_weight_prefix is None:
+ load_weight_prefix = "decoder.embed_tokens"
+
+ with tf.compat.v1.variable_scope(load_weight_prefix) as shared_abs_scope_name:
+ pass
+
+ # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
+ embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
+ embed_tokens.vocab_size = self.shared.vocab_size
+ embed_tokens.hidden_size = self.shared.hidden_size
+
+ self.embed_tokens = embed_tokens
+
+ if config.word_embed_proj_dim != config.hidden_size:
+ self.project_out = tf.keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False)
+ self.project_in = tf.keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False)
+
+ else:
+ self.project_in = None
+ self.project_out = None
+
+ self.layers = [TFOPTDecoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)]
+ self.dropout = tf.keras.layers.Dropout(config.dropout)
+
+ def get_embed_tokens(self):
+ return self.embed_tokens
+
+ def set_embed_tokens(self, embed_tokens):
+ self.embed_tokens = embed_tokens
+
+ def set_input_embeddings(self, new_embeddings):
+ self.shared.weight = new_embeddings
+ self.shared.vocab_size = self.shared.weight.shape[0]
+ # retrieve correct absolute scope for embed token wrapper
+ with tf.compat.v1.variable_scope("decoder.embed_tokens") as shared_abs_scope_name:
+ pass
+ # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
+ embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
+ self.set_embed_tokens(embed_tokens)
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length):
+        # create causal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+ else:
+ combined_attention_mask = _expand_mask(
+ tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
+ )
+
+ if attention_mask is not None:
+ combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
+
+ return combined_attention_mask
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: Optional[TFModelInputType] = None,
+ inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]:
+ r"""
+ Args:
+ input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+                Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+            past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
+ decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
+                don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
+                representation. This is useful if you want more control over how to convert `input_ids` indices
+                into associated vectors than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
+ else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if attention_mask is None:
+ attention_mask = tf.ones(inputs_embeds.shape[:2], dtype=tf.bool)
+
+ pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
+
+ attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length)
+
+ if self.project_in is not None:
+ inputs_embeds = self.project_in(inputs_embeds)
+
+ hidden_states = inputs_embeds + pos_embeds
+ hidden_states = self.dropout(hidden_states, training=training)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ present_key_values = () if use_cache else None
+
+        # check if head_mask has a correct number of layers specified if desired
+        # The tf.debugging asserts are not compliant with XLA, so they are only run in eager mode.
+ for attn_mask_name, attn_mask in [("head_mask", head_mask)]:
+ if attn_mask is not None and tf.executing_eagerly():
+ tf.debugging.assert_equal(
+ shape_list(attn_mask)[0],
+ len(self.layers),
+ message=(
+ f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for"
+ f" {shape_list(attn_mask)[0]}."
+ ),
+ )
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ hidden_states, layer_self_attn, present_key_value = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=head_mask[idx] if head_mask is not None else None,
+ past_key_value=past_key_value,
+ )
+
+ if use_cache:
+ present_key_values += (present_key_value,)
+
+ if output_attentions:
+ all_self_attns += (layer_self_attn,)
+
+ if self.project_out is not None:
+ hidden_states = self.project_out(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns] if v is not None
+ )
+
+ else:
+ return TFBaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=present_key_values,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+@keras_serializable
+class TFOPTMainLayer(tf.keras.layers.Layer):
+ config_class = OPTConfig
+
+ def __init__(self, config: OPTConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.config = config
+ self.decoder = TFOPTDecoder(config, name="decoder")
+
+ def get_input_embeddings(self):
+ return self.decoder.shared
+
+ def set_input_embeddings(self, new_embeddings):
+ self.decoder.set_input_embeddings(new_embeddings)
+
+ @unpack_inputs
+ def call(
+ self,
+ input_ids: Optional[TFModelInputType] = None,
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ **kwargs
+ ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]:
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.decoder(
+ input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ if not return_dict:
+ return outputs
+
+ return TFBaseModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ "The bare TF OPT Model outputting raw hidden-states without any specific head on top.",
+ OPT_START_DOCSTRING,
+)
+@keras_serializable
+class TFOPTModel(TFOPTPreTrainedModel):
+ config_class = OPTConfig
+
+ def __init__(self, config: OPTConfig, **kwargs):
+ super().__init__(config, **kwargs)
+ self.config = config
+ self.model = TFOPTMainLayer(config, name="model")
+
+ def get_input_embeddings(self):
+ return self.model.decoder.shared
+
+ def set_input_embeddings(self, new_embeddings):
+ self.model.set_input_embeddings(new_embeddings)
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutputWithPast,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def call(
+ self,
+ input_ids: Optional[TFModelInputType] = None,
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ **kwargs
+ ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]:
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ if not return_dict:
+ return outputs
+
+ return TFBaseModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def serving_output(self, output):
+ pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
+ hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+ attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+
+ return TFBaseModelOutputWithPast(
+ last_hidden_state=output.last_hidden_state,
+ past_key_values=pkv,
+ hidden_states=hs,
+ attentions=attns,
+ )
+
+
+@add_start_docstrings(
+ """
+ The OPT Model transformer with a language modeling head on top.
+ """,
+ OPT_START_DOCSTRING,
+)
+@keras_serializable
+class TFOPTForCausalLM(TFOPTPreTrainedModel, TFCausalLanguageModelingLoss):
+ config_class = OPTConfig
+
+ def __init__(self, config: OPTConfig, **kwargs):
+ super().__init__(config, **kwargs)
+ self.config = config
+ self.model = TFOPTMainLayer(config, name="model")
+
+ def get_output_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
+ attention_mask = kwargs.get("attention_mask", None)
+
+        # only use the last token for input_ids if past is defined in kwargs
+ if past:
+ inputs = tf.expand_dims(inputs[:, -1], -1)
+
+ return {
+ "input_ids": inputs,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "use_cache": use_cache,
+ }
+
+ @unpack_inputs
+ @replace_return_docstrings(output_type=TFCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ input_ids: Optional[TFModelInputType] = None,
+ past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+ attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: Optional[bool] = False,
+ **kwargs
+ ) -> Union[TFCausalLMOutputWithPast, Tuple[tf.Tensor]]:
+ r"""
+ Args:
+            input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+                Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+            head_mask (`tf.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+            past_key_values (`Tuple[Tuple[tf.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `Tuple[tf.Tensor]` of length `config.n_layers`, with each tuple having 2 tensors of shape
+                `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used
+                (see the `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
+                don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+            labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
+                ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import GPT2Tokenizer, TFOPTForCausalLM
+
+ >>> model = TFOPTForCausalLM.from_pretrained("facebook/opt-350m")
+ >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
+
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="tf")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ logits = self.model.decoder.shared(outputs[0], mode="linear")
+
+ loss = None
+ if labels is not None:
+            # shift labels one position to the left and drop the last logit so that tokens < n predict token n
+ shifted_logits = logits[:, :-1]
+ labels = labels[:, 1:]
+ loss = self.hf_compute_loss(labels, shifted_logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def serving_output(self, output):
+ pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
+ hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
+ attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+
+ return TFCausalLMOutputWithPast(
+ past_key_values=pkv,
+ hidden_states=hs,
+ attentions=attns,
+ loss=output.loss,
+ logits=output.logits,
+ )
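
The docstrings above note more than once that, once `past_key_values` are returned, only the newest token needs to be fed on the next step; `prepare_inputs_for_generation` applies the same rule for `generate`. A minimal sketch of that incremental-decoding loop with the new TF classes, assuming the `facebook/opt-125m` checkpoint and `GPT2Tokenizer` pairing used in the tests (if the Hub only hosts PyTorch weights for this checkpoint, `from_pt=True` would be needed):

```python
import tensorflow as tf
from transformers import GPT2Tokenizer, TFOPTForCausalLM

# Assumption: TF weights are available for this checkpoint; otherwise pass from_pt=True.
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-125m")
model = TFOPTForCausalLM.from_pretrained("facebook/opt-125m")

inputs = tokenizer("Paris is the capital of", return_tensors="tf")
input_ids = inputs.input_ids
attention_mask = inputs.attention_mask

# First pass: run the full prompt once and keep the cache.
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
past_key_values = outputs.past_key_values

for _ in range(5):
    # Greedy choice of the next token from the logits of the last position.
    next_token = tf.argmax(outputs.logits[:, -1, :], axis=-1, output_type=tf.int32)[:, None]
    input_ids = tf.concat([input_ids, next_token], axis=-1)
    attention_mask = tf.concat([attention_mask, tf.ones_like(next_token)], axis=-1)
    # Subsequent passes: only the new token, plus the cached keys/values and the full mask.
    outputs = model(
        next_token,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        use_cache=True,
    )
    past_key_values = outputs.past_key_values

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
```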
diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py
index a6c6e7926da1..bc0d43b01b9f 100644
--- a/src/transformers/utils/dummy_flax_objects.py
+++ b/src/transformers/utils/dummy_flax_objects.py
@@ -795,6 +795,27 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
+class FlaxOPTForCausalLM(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
+class FlaxOPTModel(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
+class FlaxOPTPreTrainedModel(metaclass=DummyObject):
+ _backends = ["flax"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["flax"])
+
+
class FlaxPegasusForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]
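
For context on the dummy-object entries added above: they keep `from transformers import FlaxOPTModel` working even when the Flax backend is missing, and defer the error until the class is actually used. A simplified, stand-alone illustration of the pattern; the `backend_available` helper is invented for this sketch and the real library implementation differs in its details:

```python
import importlib.util


def backend_available(name):
    # Hypothetical helper for this sketch only: "is the package importable?"
    return importlib.util.find_spec(name) is not None


def requires_backends(obj, backends):
    missing = [b for b in backends if not backend_available(b)]
    if missing:
        name = obj if isinstance(obj, str) else getattr(obj, "__name__", type(obj).__name__)
        raise ImportError(f"{name} requires the following backend(s) to be installed: {missing}")


class DummyObject(type):
    # Class-level access (e.g. FlaxOPTModel.from_pretrained) fails with the same message.
    def __getattr__(cls, key):
        if key.startswith("_"):
            raise AttributeError(key)
        requires_backends(cls, cls._backends)


class FlaxOPTModel(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


# Importing the dummy class is harmless; instantiating it raises a clear ImportError
# when flax is not installed.
try:
    FlaxOPTModel()
except ImportError as err:
    print(err)
```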
diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py
index d2e8d5c0cddb..00965e2f0a17 100644
--- a/src/transformers/utils/dummy_tf_objects.py
+++ b/src/transformers/utils/dummy_tf_objects.py
@@ -1619,6 +1619,27 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
+class TFOPTForCausalLM(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFOPTModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFOPTPreTrainedModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
class TFPegasusForConditionalGeneration(metaclass=DummyObject):
_backends = ["tf"]
diff --git a/tests/models/opt/test_modeling_flax_opt.py b/tests/models/opt/test_modeling_flax_opt.py
new file mode 100644
index 000000000000..8b4c1333dfd5
--- /dev/null
+++ b/tests/models/opt/test_modeling_flax_opt.py
@@ -0,0 +1,405 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import numpy as np
+import timeout_decorator # noqa
+
+from transformers import OPTConfig, is_flax_available
+from transformers.testing_utils import require_flax, require_sentencepiece, slow
+
+from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
+from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
+
+
+if is_flax_available():
+ import os
+
+    # The slow tests often fail with an OOM error on GPU. Setting the allocator to "platform" makes JAX allocate
+    # exactly what is needed on demand and free memory that is no longer needed, at the cost of speed, as explained
+    # here: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
+ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
+
+ import jax
+ import jax.numpy as jnp
+ from transformers import FlaxOPTForCausalLM, FlaxOPTModel, GPT2Tokenizer
+
+
+def prepare_opt_inputs_dict(config, input_ids, attention_mask=None, head_mask=None):
+ if attention_mask is None:
+ attention_mask = np.where(input_ids != config.pad_token_id, 1, 0)
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+
+
+@require_flax
+class FlaxOPTModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_labels=False,
+ vocab_size=99,
+ hidden_size=16,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ intermediate_size=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=20,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ embed_dim=16,
+ word_embed_proj_dim=16,
+ initializer_range=0.02,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.embed_dim = embed_dim
+ self.word_embed_proj_dim = word_embed_proj_dim
+ self.initializer_range = initializer_range
+ self.is_encoder_decoder = False
+
+ def prepare_config_and_inputs(self):
+ input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size), 3, self.vocab_size)
+ input_ids = np.concatenate((input_ids, 2 * np.ones((self.batch_size, 1), dtype=np.int64)), -1)
+
+ config = OPTConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ ffn_dim=self.intermediate_size,
+ dropout=self.hidden_dropout_prob,
+ attention_dropout=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.bos_token_id,
+ pad_token_id=self.pad_token_id,
+ embed_dim=self.embed_dim,
+ is_encoder_decoder=False,
+ word_embed_proj_dim=self.word_embed_proj_dim,
+ initializer_range=self.initializer_range,
+ use_cache=False,
+ )
+ inputs_dict = prepare_opt_inputs_dict(config, input_ids)
+ return config, inputs_dict
+
+ def prepare_config_and_inputs_for_common(self):
+ config, inputs_dict = self.prepare_config_and_inputs()
+ return config, inputs_dict
+
+ def check_use_cache_forward(self, model_class_name, config, inputs_dict):
+ max_length = 20
+ model = model_class_name(config)
+
+ input_ids = inputs_dict["input_ids"]
+ attention_mask = inputs_dict["attention_mask"]
+
+ past_key_values = model.init_cache(input_ids.shape[0], max_length)
+ attention_mask = jnp.ones((input_ids.shape[0], max_length), dtype="i4")
+
+ position_ids = jnp.broadcast_to(
+ jnp.arange(input_ids.shape[-1] - 1)[None, :],
+ (input_ids.shape[0], input_ids.shape[-1] - 1),
+ )
+ outputs_cache = model(
+ input_ids[:, :-1],
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
+ outputs_cache_next = model(
+ input_ids[:, -1:],
+ attention_mask=attention_mask,
+ past_key_values=outputs_cache.past_key_values,
+ position_ids=position_ids,
+ )
+
+ outputs = model(input_ids)
+
+ diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
+ self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
+
+ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
+ max_length = 20
+ model = model_class_name(config)
+
+ input_ids, attention_mask = (
+ inputs_dict["input_ids"],
+ inputs_dict["attention_mask"],
+ )
+
+ attention_mask_cache = jnp.concatenate(
+ [
+ attention_mask,
+ jnp.zeros((attention_mask.shape[0], max_length - attention_mask.shape[1])),
+ ],
+ axis=-1,
+ )
+
+ past_key_values = model.init_cache(input_ids.shape[0], max_length)
+ position_ids = jnp.broadcast_to(
+ jnp.arange(input_ids.shape[-1] - 1)[None, :],
+ (input_ids.shape[0], input_ids.shape[-1] - 1),
+ )
+
+ outputs_cache = model(
+ input_ids[:, :-1],
+ attention_mask=attention_mask_cache,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+ position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
+ outputs_cache_next = model(
+ input_ids[:, -1:],
+ past_key_values=outputs_cache.past_key_values,
+ attention_mask=attention_mask_cache,
+ position_ids=position_ids,
+ )
+
+ outputs = model(input_ids, attention_mask=attention_mask)
+
+ diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
+ self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
+
+
+@require_flax
+class FlaxOPTModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
+ all_model_classes = (FlaxOPTModel, FlaxOPTForCausalLM) if is_flax_available() else ()
+ all_generative_model_classes = () if is_flax_available() else ()
+
+ def setUp(self):
+ self.model_tester = FlaxOPTModelTester(self)
+
+ def test_use_cache_forward(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+ for model_class in self.all_model_classes:
+ self.model_tester.check_use_cache_forward(model_class, config, inputs_dict)
+
+ def test_use_cache_forward_with_attn_mask(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+ for model_class in self.all_model_classes:
+ self.model_tester.check_use_cache_forward_with_attn_mask(model_class, config, inputs_dict)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_class_name in self.all_model_classes:
+ model = model_class_name.from_pretrained("facebook/opt-125m")
+ input_ids = np.ones((1, 1)) * model.config.eos_token_id
+ outputs = model(input_ids)
+ self.assertIsNotNone(outputs)
+
+
+@require_sentencepiece
+@require_flax
+class FlaxOPTModelIntegrationTests(unittest.TestCase):
+ @slow
+ def test_inference_no_head(self):
+ model = FlaxOPTModel.from_pretrained("facebook/opt-350m")
+ input_ids = jnp.array([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+ output = model(input_ids=input_ids).last_hidden_state
+ expected_shape = (1, 11, 512)
+ self.assertEqual(output.shape, expected_shape)
+ expected_slice = jnp.array(
+ [[-0.2867, -1.9256, -0.3062], [-1.2711, -0.1337, -0.1897], [0.4109, 0.1187, -1.3142]]
+ )
+ self.assertTrue(jnp.allclose(output[:, :3, :3], expected_slice, atol=4e-2))
+
+
+@require_flax
+@slow
+class FlaxOPTEmbeddingsTest(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ self.path_model = "facebook/opt-350m"
+
+ def test_logits(self):
+ model = FlaxOPTForCausalLM.from_pretrained(self.path_model)
+ tokenizer = GPT2Tokenizer.from_pretrained(self.path_model)
+
+ prompts = [
+ "Today is a beautiful day and I want to",
+ "In the city of",
+ "Paris is the capital of France and",
+ "Computers and mobile phones have taken",
+ ]
+        # Metaseq does not prepend a BOS token to the prompt, so use add_special_tokens=False to match its logits
+ inputs = tokenizer(prompts, return_tensors="jax", padding=True, add_special_tokens=False)
+ logits = model(inputs.input_ids, attention_mask=inputs.attention_mask)[0].mean(axis=-1)
+ logits_meta = jnp.array(
+ [
+ [1.3851, -13.8923, -10.5229, -10.7533, -0.2309, -10.2384, -0.5365, -9.0947, -5.1670],
+ [-4.7073, -10.6276, -3.9415, -21.5242, -0.2822, -0.2822, -0.2822, -0.2822, -0.2822],
+ [0.6247, -3.4229, -8.9179, -1.4297, -14.1650, 1.4146, -9.0218, -0.2703, -0.2703],
+ [6.4783, -1.9913, -10.7926, -2.3336, 1.5092, -0.9974, -6.8213, 1.3477, 1.3477],
+ ]
+ )
+ self.assertTrue(jnp.allclose(logits, logits_meta, atol=1e-4))
+
+ model = jax.jit(model)
+ logits = model(inputs.input_ids, attention_mask=inputs.attention_mask)[0].mean(axis=-1)
+ self.assertTrue(jnp.allclose(logits, logits_meta, atol=1e-4))
+
+
+@slow
+class FlaxOPTGenerationTest(unittest.TestCase):
+ @property
+ def prompts(self):
+ return [
+ "Today is a beautiful day and I want",
+ "In the city of",
+ "Paris is the capital of France and",
+ "Computers and mobile phones have taken",
+ ]
+
+ def test_generation_pre_attn_layer_norm(self):
+ model_id = "facebook/opt-125m"
+
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want everyone",
+ "In the city of Rome Canaver Canaver Canaver Canaver",
+ "Paris is the capital of France and Parisdylib",
+ "Computers and mobile phones have taken precedence over",
+ ]
+
+ predicted_outputs = []
+
+ model = FlaxOPTForCausalLM.from_pretrained(model_id)
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+
+ for prompt in self.prompts:
+ input_ids = tokenizer(prompt, return_tensors="jax").input_ids
+
+ generated_ids = model.generate(input_ids, max_length=10)
+ generated_ids = generated_ids[0]
+
+ generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ predicted_outputs += generated_string
+
+ self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
+
+ def test_generation_post_attn_layer_norm(self):
+ model_id = "facebook/opt-350m"
+
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want to",
+ "In the city of San Francisco, the city",
+ "Paris is the capital of France and the capital",
+ "Computers and mobile phones have taken over the",
+ ]
+
+ predicted_outputs = []
+ model = FlaxOPTForCausalLM.from_pretrained(model_id)
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+
+ for prompt in self.prompts:
+ input_ids = tokenizer(prompt, return_tensors="jax").input_ids
+
+ generated_ids = model.generate(input_ids, max_length=10)
+ generated_ids = generated_ids[0]
+
+ generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ predicted_outputs += generated_string
+
+ self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
+
+ def test_jitted_batch_generation(self):
+ model_id = "facebook/opt-125m"
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want to thank",
+ "In the city of Rome Canaver Canaver Canaver Canaver",
+ ]
+ model = FlaxOPTForCausalLM.from_pretrained(model_id)
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ inputs = tokenizer(
+ [
+ "Today is a beautiful day and I want to",
+ "In the city of",
+ ],
+ return_tensors="jax",
+ padding=True,
+ )
+
+ jit_generate = jax.jit(model.generate)
+
+ output_sequences = jit_generate(inputs["input_ids"], attention_mask=inputs["attention_mask"]).sequences
+
+ output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
+
+ self.assertIsNotNone(output_string, EXPECTED_OUTPUTS)
+
+ # TODO fix in the following PR
+ # def test_batch_generation(self):
+ # model_id = "facebook/opt-350m"
+
+ # tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ # model = FlaxOPTForCausalLM.from_pretrained(model_id)
+
+ # tokenizer.padding_side = "left"
+
+ # # use different length sentences to test batching
+ # sentences = [
+ # "Hello, my dog is a little",
+ # "Today, I",
+ # ]
+
+ # inputs = tokenizer(sentences, return_tensors="jax", padding=True)
+ # input_ids = inputs["input_ids"]
+
+ # outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"], trace=False)
+
+ # inputs_non_padded = tokenizer(sentences[0], return_tensors="jax").input_ids
+ # output_non_padded = model.generate(input_ids=inputs_non_padded)
+
+ # num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].sum()
+ # inputs_padded = tokenizer(sentences[1], return_tensors="jax").input_ids
+ # output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+ # batch_out_sentence = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
+ # non_padded_sentence = tokenizer.decode(output_non_padded[0][0], skip_special_tokens=True)
+ # padded_sentence = tokenizer.decode(output_padded[0][0], skip_special_tokens=True)
+
+ # expected_output_sentence = [
+ # "Hello, my dog is a little bit of a dork.\nI'm a little bit",
+ # "Today, I"
+ # # TODO fix this test in next PR
+ # # "Today, I was in the middle of a conversation with a friend about the",
+ # ]
+ # self.assertListEqual(expected_output_sentence, batch_out_sentence)
+ # # TODO outputs will be similar, fix in next PR
+ # self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])
diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py
index 9a6487876550..7b8ff6468909 100644
--- a/tests/models/opt/test_modeling_opt.py
+++ b/tests/models/opt/test_modeling_opt.py
@@ -333,7 +333,7 @@ class OPTGenerationTest(unittest.TestCase):
@property
def prompts(self):
return [
- "Today is a beautiful day and I want to",
+ "Today is a beautiful day and I want",
"In the city of",
"Paris is the capital of France and",
"Computers and mobile phones have taken",
@@ -343,7 +343,7 @@ def test_generation_pre_attn_layer_norm(self):
model_id = "facebook/opt-125m"
EXPECTED_OUTPUTS = [
- "Today is a beautiful day and I want to thank",
+ "Today is a beautiful day and I want everyone",
"In the city of Rome Canaver Canaver Canaver Canaver",
"Paris is the capital of France and Parisdylib",
"Computers and mobile phones have taken precedence over",
@@ -408,7 +408,7 @@ def test_generation_post_attn_layer_norm(self):
model_id = "facebook/opt-350m"
EXPECTED_OUTPUTS = [
- "Today is a beautiful day and I want to share",
+ "Today is a beautiful day and I want to",
"In the city of San Francisco, the city",
"Paris is the capital of France and the capital",
"Computers and mobile phones have taken over the",
diff --git a/tests/models/opt/test_modeling_tf_opt.py b/tests/models/opt/test_modeling_tf_opt.py
new file mode 100644
index 000000000000..d34d4f0fc8e6
--- /dev/null
+++ b/tests/models/opt/test_modeling_tf_opt.py
@@ -0,0 +1,414 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+from transformers import OPTConfig, is_tf_available
+from transformers.testing_utils import require_sentencepiece, require_tf, slow
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+
+
+if is_tf_available():
+ import tensorflow as tf
+
+ from transformers import GPT2Tokenizer, TFOPTForCausalLM, TFOPTModel
+
+
+def prepare_opt_inputs_dict(config, input_ids, attention_mask=None, head_mask=None):
+ if attention_mask is None:
+ attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
+@require_tf
+class TFOPTModelTester:
+ config_cls = OPTConfig
+ config_updates = {}
+ hidden_act = "gelu"
+
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ is_training=True,
+ use_labels=False,
+ vocab_size=99,
+ hidden_size=16,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ intermediate_size=4,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=20,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ embed_dim=16,
+ word_embed_proj_dim=16,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.bos_token_id = bos_token_id
+ self.embed_dim = embed_dim
+ self.word_embed_proj_dim = word_embed_proj_dim
+ self.is_encoder_decoder = False
+
+ def prepare_config_and_inputs_for_common(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
+ eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
+ input_ids = tf.concat([input_ids, eos_tensor], axis=1)
+
+ config = self.config_cls(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ ffn_dim=self.intermediate_size,
+ dropout=self.hidden_dropout_prob,
+ attention_dropout=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ eos_token_id=self.eos_token_id,
+ bos_token_id=self.bos_token_id,
+ pad_token_id=self.pad_token_id,
+ embed_dim=self.embed_dim,
+ word_embed_proj_dim=self.word_embed_proj_dim,
+ is_encoder_decoder=False,
+ **self.config_updates,
+ )
+ inputs_dict = prepare_opt_inputs_dict(config, input_ids)
+ return config, inputs_dict
+
+ def check_decoder_model_past_large_inputs(self, config, inputs_dict):
+ model = TFOPTModel(config=config)
+ input_ids = inputs_dict["input_ids"]
+
+ input_ids = input_ids[:1, :]
+ attention_mask = inputs_dict["attention_mask"][:1, :]
+ self.batch_size = 1
+
+ # first forward pass
+ outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
+
+ output, past_key_values = outputs.to_tuple()
+
+        # create hypothetical next tokens and extend next_input_ids accordingly
+ next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+ next_attn_mask = tf.cast(ids_tensor((self.batch_size, 3), 2), tf.int8)
+
+        # append the new tokens to input_ids and the new mask to attention_mask
+ next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
+ next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)
+
+ output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
+ output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0]
+
+ self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
+
+ # select random slice
+ random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
+ output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
+ output_from_past_slice = output_from_past[:, :, random_slice_idx]
+
+ # test that outputs are equal for slice
+ tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
+
+
+@require_tf
+class TFOPTModelTest(TFModelTesterMixin, unittest.TestCase):
+ all_model_classes = (TFOPTModel, TFOPTForCausalLM) if is_tf_available() else ()
+ all_generative_model_classes = (TFOPTForCausalLM,) if is_tf_available() else ()
+ is_encoder_decoder = False
+ test_pruning = False
+ test_onnx = False
+ onnx_min_opset = 10
+
+ def setUp(self):
+ self.model_tester = TFOPTModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=OPTConfig)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_decoder_model_past_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
+ self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
+
+ def test_model_common_attributes(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+
+ if model_class in self.all_generative_model_classes:
+ x = model.get_output_embeddings()
+ assert isinstance(x, tf.keras.layers.Layer)
+ else:
+ x = model.get_output_embeddings()
+ assert x is None
+
+ def test_resize_token_embeddings(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ def _get_word_embedding_weight(model, embedding_layer):
+ if hasattr(embedding_layer, "weight"):
+ return embedding_layer.weight
+ else:
+                # Build the word embedding weights if they don't exist yet,
+                # then retry fetching the attribute once the model is built.
+ model(model.dummy_inputs)
+ if hasattr(embedding_layer, "weight"):
+ return embedding_layer.weight
+ else:
+ return None
+
+ for model_class in self.all_model_classes:
+ for size in [config.vocab_size - 10, config.vocab_size + 10]:
+ # build the embeddings
+ model = model_class(config=config)
+ old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
+ old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
+
+ # reshape the embeddings
+ model.resize_token_embeddings(size)
+ new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
+ new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
+
+ # check that the resized embeddings size matches the desired size.
+ assert_size = size if size is not None else config.vocab_size
+
+ self.assertEqual(new_input_embeddings.shape[0], assert_size)
+
+ # check that weights remain the same after resizing
+ models_equal = True
+ for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
+ if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+ models_equal = False
+ self.assertTrue(models_equal)
+
+ if old_output_embeddings is not None and new_output_embeddings is not None:
+ self.assertEqual(new_output_embeddings.shape[0], assert_size)
+
+ models_equal = True
+ for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
+ if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
+ models_equal = False
+ self.assertTrue(models_equal)
+
+ def test_saved_model_creation(self):
+        # This test is too long (>30sec) and makes the CI fail
+ pass
+
+
+def _long_tensor(tok_lst):
+ return tf.constant(tok_lst, dtype=tf.int32)
+
+
+@require_tf
+class TFOPTHeadTests(unittest.TestCase):
+ vocab_size = 99
+
+ def _get_config_and_data(self):
+ eos_column_vector = tf.ones((4, 1), dtype=tf.int32) * 2
+ input_ids = tf.concat([ids_tensor((4, 6), self.vocab_size - 3) + 3, eos_column_vector], axis=1)
+ batch_size = input_ids.shape[0]
+ config = OPTConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=24,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ ffn_dim=32,
+ max_position_embeddings=48,
+ eos_token_id=2,
+ pad_token_id=1,
+ bos_token_id=0,
+ )
+ return config, input_ids, batch_size
+
+
+@require_sentencepiece
+@require_tf
+class OPTModelIntegrationTests(unittest.TestCase):
+ @slow
+ def test_inference_no_head(self):
+ model = TFOPTModel.from_pretrained("facebook/opt-350m")
+ input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+ attention_mask = tf.not_equal(input_ids, model.config.pad_token_id)
+ with tf.GradientTape():
+ output = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
+ expected_shape = (1, 11, 512)
+ self.assertEqual(output.shape, expected_shape)
+ expected_slice = tf.constant(
+ [[-0.2873, -1.9218, -0.3033], [-1.2710, -0.1338, -0.1902], [0.4095, 0.1214, -1.3121]]
+ )
+ self.assertTrue(np.allclose(output[:, :3, :3], expected_slice, atol=4e-3))
+
+ xla_generate = tf.function(model, jit_compile=True)
+ output = xla_generate(input_ids, attention_mask)[0]
+ self.assertTrue(np.allclose(output[:, :3, :3], expected_slice, atol=4e-2))
+
+
+@require_tf
+@slow
+class TFOPTEmbeddingsTest(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ self.path_model = "facebook/opt-350m"
+
+ def test_logits(self):
+ model = TFOPTForCausalLM.from_pretrained(self.path_model)
+ tokenizer = GPT2Tokenizer.from_pretrained(self.path_model)
+
+ prompts = [
+ "Today is a beautiful day and I want to",
+ "In the city of",
+ "Paris is the capital of France and",
+ "Computers and mobile phones have taken",
+ ]
+        # Metaseq does not prepend a BOS token to the prompt, so use add_special_tokens=False to match its logits
+ inputs = tokenizer(prompts, return_tensors="tf", padding=True, add_special_tokens=False)
+ logits = tf.math.reduce_mean(model(inputs.input_ids, attention_mask=inputs.attention_mask)[0], axis=-1)
+ logits_meta = tf.constant(
+ [
+ [1.3851, -13.8923, -10.5229, -10.7533, -0.2309, -10.2384, -0.5365, -9.0947, -5.1670],
+ [-4.7073, -10.6276, -3.9415, -21.5242, -0.2822, -0.2822, -0.2822, -0.2822, -0.2822],
+ [0.6247, -3.4229, -8.9179, -1.4297, -14.1650, 1.4146, -9.0218, -0.2703, -0.2703],
+ [6.4783, -1.9913, -10.7926, -2.3336, 1.5092, -0.9974, -6.8213, 1.3477, 1.3477],
+ ]
+ )
+ self.assertTrue(np.allclose(logits, logits_meta, atol=1e-4))
+
+ xla_generate = tf.function(model, jit_compile=True)
+ logits = tf.math.reduce_mean(xla_generate(inputs.input_ids, attention_mask=inputs.attention_mask)[0], axis=-1)
+ self.assertTrue(np.allclose(logits, logits_meta, atol=1e-4))
+
+
+@slow
+class TFOPTGenerationTest(unittest.TestCase):
+ @property
+ def prompts(self):
+ return [
+ "Today is a beautiful day and I want",
+ "In the city of",
+ "Paris is the capital of France and",
+ "Computers and mobile phones have taken",
+ ]
+
+ def test_generation_pre_attn_layer_norm(self):
+ model_id = "facebook/opt-125m"
+
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want everyone",
+ "In the city of Rome Canaver Canaver Canaver Canaver",
+ "Paris is the capital of France and Parisdylib",
+ "Computers and mobile phones have taken precedence over",
+ ]
+
+ predicted_outputs = []
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ model = TFOPTForCausalLM.from_pretrained(model_id)
+
+ for prompt in self.prompts:
+ input_ids = tokenizer(prompt, return_tensors="tf").input_ids
+
+ generated_ids = model.generate(input_ids, max_length=10)
+
+ generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ predicted_outputs += generated_string
+
+ self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
+
+ def test_batch_generation(self):
+ model_id = "facebook/opt-350m"
+
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ model = TFOPTForCausalLM.from_pretrained(model_id)
+
+ tokenizer.padding_side = "left"
+
+ # use different length sentences to test batching
+ sentences = [
+ "Hello, my dog is a little",
+ "Today, I",
+ ]
+
+ inputs = tokenizer(sentences, return_tensors="tf", padding=True)
+ input_ids = inputs["input_ids"]
+
+ outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"])
+
+ inputs_non_padded = tokenizer(sentences[0], return_tensors="tf").input_ids
+ output_non_padded = model.generate(input_ids=inputs_non_padded)
+
+ num_paddings = inputs_non_padded.shape[-1] - tf.math.reduce_sum(
+ tf.cast(inputs["attention_mask"][-1], tf.int64)
+ )
+ inputs_padded = tokenizer(sentences[1], return_tensors="tf").input_ids
+ output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+ batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
+ padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+
+ expected_output_sentence = [
+ "Hello, my dog is a little bit of a dork.\nI'm a little bit",
+ "Today, I was in the middle of a conversation with a friend about the",
+ ]
+ self.assertListEqual(expected_output_sentence, batch_out_sentence)
+ self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])
+
+ def test_generation_post_attn_layer_norm(self):
+ model_id = "facebook/opt-350m"
+
+ EXPECTED_OUTPUTS = [
+ "Today is a beautiful day and I want to",
+ "In the city of San Francisco, the city",
+ "Paris is the capital of France and the capital",
+ "Computers and mobile phones have taken over the",
+ ]
+
+ predicted_outputs = []
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+ model = TFOPTForCausalLM.from_pretrained(model_id)
+
+ for prompt in self.prompts:
+ input_ids = tokenizer(prompt, return_tensors="tf").input_ids
+
+ generated_ids = model.generate(input_ids, max_length=10)
+
+ generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ predicted_outputs += generated_string
+
+ self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
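
`test_batch_generation` above switches the tokenizer to left padding before batching prompts of different lengths: for a decoder-only model, generation continues from the last position of each row, so the pads must sit on the left, and the `max_length - num_paddings` adjustment keeps the number of freshly generated tokens comparable across runs. A small sketch of the padding difference, assuming the same `facebook/opt-350m` tokenizer used in the test:

```python
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
sentences = ["Hello, my dog is a little", "Today, I"]

tokenizer.padding_side = "right"
right = tokenizer(sentences, padding=True)
# The short prompt ends in pad tokens, so the "next token" slot sits after padding.
print(right["attention_mask"][1])  # trailing zeros

tokenizer.padding_side = "left"
left = tokenizer(sentences, padding=True)
# The short prompt now ends on its last real token, so generation continues it directly.
print(left["attention_mask"][1])  # leading zeros
```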
diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt
index 3e20dc09e22c..1738d2e92dfc 100644
--- a/utils/documentation_tests.txt
+++ b/utils/documentation_tests.txt
@@ -41,6 +41,8 @@ src/transformers/models/mbart/modeling_mbart.py
src/transformers/models/mobilebert/modeling_mobilebert.py
src/transformers/models/mobilebert/modeling_tf_mobilebert.py
src/transformers/models/opt/modeling_opt.py
+src/transformers/models/opt/modeling_tf_opt.py
+src/transformers/models/opt/modeling_flax_opt.py
src/transformers/models/pegasus/modeling_pegasus.py
src/transformers/models/plbart/modeling_plbart.py
src/transformers/models/poolformer/modeling_poolformer.py