diff --git a/src/transformers/modeling_flax_bert.py b/src/transformers/modeling_flax_bert.py index ab0c0b0d16b9..92ab7dcf9a2f 100644 --- a/src/transformers/modeling_flax_bert.py +++ b/src/transformers/modeling_flax_bert.py @@ -20,7 +20,6 @@ import flax.linen as nn import jax import jax.numpy as jnp -from flax.linen import compact from .configuration_bert import BertConfig from .file_utils import add_start_docstrings @@ -108,13 +107,15 @@ class FlaxBertLayerNorm(nn.Module): """ epsilon: float = 1e-6 - dtype: jnp.dtype = jnp.float32 - bias: bool = True - scale: bool = True + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + bias: bool = True # If True, bias (beta) is added. + scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear + # (also e.g. nn.relu), this can be disabled since the scaling will be + # done by the next layer. bias_init: jnp.ndarray = nn.initializers.zeros scale_init: jnp.ndarray = nn.initializers.ones - @compact + @nn.compact def __call__(self, x): """ Applies layer normalization on the input. It normalizes the activations of the layer for each given example in @@ -123,13 +124,6 @@ def __call__(self, x): Args: x: the inputs - epsilon: A small float added to variance to avoid dividing by zero. - dtype: the dtype of the computation (default: float32). - bias: If True, bias (beta) is added. - scale: If True, multiply by scale (gamma). When the next layer is linear - (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. - bias_init: Initializer for bias, by default, zero. - scale_init: Initializer for scale, by default, one Returns: Normalized inputs (the same shape as inputs). @@ -157,7 +151,7 @@ class FlaxBertEmbedding(nn.Module): hidden_size: int emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1) - @compact + @nn.compact def __call__(self, inputs): embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size)) return jnp.take(embedding, inputs, axis=0) @@ -171,7 +165,7 @@ class FlaxBertEmbeddings(nn.Module): type_vocab_size: int max_length: int - @compact + @nn.compact def __call__(self, input_ids, token_type_ids, position_ids, attention_mask): # Embed @@ -198,7 +192,7 @@ class FlaxBertAttention(nn.Module): num_heads: int head_size: int - @compact + @nn.compact def __call__(self, hidden_state, attention_mask): self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")( hidden_state, attention_mask @@ -211,7 +205,7 @@ def __call__(self, hidden_state, attention_mask): class FlaxBertIntermediate(nn.Module): output_size: int - @compact + @nn.compact def __call__(self, hidden_state): # TODO: Add ACT2FN reference to change activation function dense = nn.Dense(features=self.output_size, name="dense")(hidden_state) @@ -219,7 +213,7 @@ def __call__(self, hidden_state): class FlaxBertOutput(nn.Module): - @compact + @nn.compact def __call__(self, intermediate_output, attention_output): hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output) hidden_state = FlaxBertLayerNorm(name="layer_norm")(hidden_state + attention_output) @@ -231,7 +225,7 @@ class FlaxBertLayer(nn.Module): head_size: int intermediate_size: int - @compact + @nn.compact def __call__(self, hidden_state, attention_mask): attention = FlaxBertAttention(self.num_heads, self.head_size, name="attention")(hidden_state, attention_mask) intermediate = FlaxBertIntermediate(self.intermediate_size, name="intermediate")(attention) @@ -250,7 +244,7 @@ class FlaxBertLayerCollection(nn.Module): head_size: int intermediate_size: int - @compact + @nn.compact def __call__(self, inputs, attention_mask): assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})" @@ -270,7 +264,7 @@ class FlaxBertEncoder(nn.Module): head_size: int intermediate_size: int - @compact + @nn.compact def __call__(self, hidden_state, attention_mask): layer = FlaxBertLayerCollection( self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer" @@ -279,7 +273,7 @@ def __call__(self, hidden_state, attention_mask): class FlaxBertPooler(nn.Module): - @compact + @nn.compact def __call__(self, hidden_state): cls_token = hidden_state[:, 0] out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token) @@ -296,7 +290,7 @@ class FlaxBertModule(nn.Module): head_size: int intermediate_size: int - @compact + @nn.compact def __call__(self, input_ids, token_type_ids, position_ids, attention_mask): # Embedding diff --git a/src/transformers/modeling_flax_roberta.py b/src/transformers/modeling_flax_roberta.py index eea705f3cdd2..551ff8d52561 100644 --- a/src/transformers/modeling_flax_roberta.py +++ b/src/transformers/modeling_flax_roberta.py @@ -19,7 +19,6 @@ import flax.linen as nn import jax import jax.numpy as jnp -from flax.linen import compact from .configuration_roberta import RobertaConfig from .file_utils import add_start_docstrings @@ -108,13 +107,15 @@ class FlaxRobertaLayerNorm(nn.Module): """ epsilon: float = 1e-6 - dtype: jnp.dtype = jnp.float32 - bias: bool = True - scale: bool = True + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + bias: bool = True # If True, bias (beta) is added. + scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear + # (also e.g. nn.relu), this can be disabled since the scaling will be + # done by the next layer. bias_init: jnp.ndarray = nn.initializers.zeros scale_init: jnp.ndarray = nn.initializers.ones - @compact + @nn.compact def __call__(self, x): """ Applies layer normalization on the input. It normalizes the activations of the layer for each given example in @@ -123,13 +124,6 @@ def __call__(self, x): Args: x: the inputs - epsilon: A small float added to variance to avoid dividing by zero. - dtype: the dtype of the computation (default: float32). - bias: If True, bias (beta) is added. - scale: If True, multiply by scale (gamma). When the next layer is linear - (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. - bias_init: Initializer for bias, by default, zero. - scale_init: Initializer for scale, by default, one Returns: Normalized inputs (the same shape as inputs). @@ -158,7 +152,7 @@ class FlaxRobertaEmbedding(nn.Module): hidden_size: int emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1) - @compact + @nn.compact def __call__(self, inputs): embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size)) return jnp.take(embedding, inputs, axis=0) @@ -173,7 +167,7 @@ class FlaxRobertaEmbeddings(nn.Module): type_vocab_size: int max_length: int - @compact + @nn.compact def __call__(self, input_ids, token_type_ids, position_ids, attention_mask): # Embed @@ -201,7 +195,7 @@ class FlaxRobertaAttention(nn.Module): num_heads: int head_size: int - @compact + @nn.compact def __call__(self, hidden_state, attention_mask): self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")( hidden_state, attention_mask @@ -215,7 +209,7 @@ def __call__(self, hidden_state, attention_mask): class FlaxRobertaIntermediate(nn.Module): output_size: int - @compact + @nn.compact def __call__(self, hidden_state): # TODO: Add ACT2FN reference to change activation function dense = nn.Dense(features=self.output_size, name="dense")(hidden_state) @@ -224,7 +218,7 @@ def __call__(self, hidden_state): # Copied from transformers.modeling_flax_bert.FlaxBertOutput with Bert->Roberta class FlaxRobertaOutput(nn.Module): - @compact + @nn.compact def __call__(self, intermediate_output, attention_output): hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output) hidden_state = FlaxRobertaLayerNorm(name="layer_norm")(hidden_state + attention_output) @@ -236,7 +230,7 @@ class FlaxRobertaLayer(nn.Module): head_size: int intermediate_size: int - @compact + @nn.compact def __call__(self, hidden_state, attention_mask): attention = FlaxRobertaAttention(self.num_heads, self.head_size, name="attention")( hidden_state, attention_mask @@ -258,7 +252,7 @@ class FlaxRobertaLayerCollection(nn.Module): head_size: int intermediate_size: int - @compact + @nn.compact def __call__(self, inputs, attention_mask): assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})" @@ -279,7 +273,7 @@ class FlaxRobertaEncoder(nn.Module): head_size: int intermediate_size: int - @compact + @nn.compact def __call__(self, hidden_state, attention_mask): layer = FlaxRobertaLayerCollection( self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer" @@ -289,7 +283,7 @@ def __call__(self, hidden_state, attention_mask): # Copied from transformers.modeling_flax_bert.FlaxBertPooler with Bert->Roberta class FlaxRobertaPooler(nn.Module): - @compact + @nn.compact def __call__(self, hidden_state): cls_token = hidden_state[:, 0] out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token) @@ -307,7 +301,7 @@ class FlaxRobertaModule(nn.Module): head_size: int intermediate_size: int - @compact + @nn.compact def __call__(self, input_ids, token_type_ids, position_ids, attention_mask): # Embedding