diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
index a8ea74dd5cdb..df27b34ddcfa 100644
--- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
@@ -766,6 +766,73 @@ def __call__(self, hidden_states, mask_time_indices=None, deterministic=True, te
         return codevectors, perplexity
 
 
+class FlaxWav2Vec2Adapter(nn.Module):
+    config: Wav2Vec2Config
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        # hidden_states require down-projection if feature dims don't match
+        if self.config.output_hidden_size != self.config.hidden_size:
+            self.proj = nn.Dense(
+                self.config.output_hidden_size,
+                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+                dtype=self.dtype,
+            )
+            self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+        else:
+            self.proj = self.proj_layer_norm = None
+
+        self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)
+
+    def __call__(self, hidden_states, deterministic=True):
+        # down-project hidden_states if required
+        if self.proj is not None and self.proj_layer_norm is not None:
+            hidden_states = self.proj(hidden_states)
+            hidden_states = self.proj_layer_norm(hidden_states)
+
+        hidden_states = self.layers(hidden_states)
+
+        return hidden_states
+
+
+class FlaxWav2Vec2AdapterLayer(nn.Module):
+    config: Wav2Vec2Config
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.conv = nn.Conv(
+            features=2 * self.config.output_hidden_size,
+            kernel_size=(self.config.adapter_kernel_size,),
+            strides=(self.config.adapter_stride,),
+            padding=((1, 1),),
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            dtype=self.dtype,
+        )
+
+    def __call__(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
+        hidden_states = nn.glu(hidden_states, axis=2)
+
+        return hidden_states
+
+
+class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
+    config: Wav2Vec2Config
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.layers = [
+            FlaxWav2Vec2AdapterLayer(self.config, name=str(i), dtype=self.dtype)
+            for i in range(self.config.num_adapter_layers)
+        ]
+
+    def __call__(self, hidden_states):
+        for conv_layer in self.layers:
+            hidden_states = conv_layer(hidden_states)
+
+        return hidden_states
+
+
 class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -840,7 +907,9 @@ def __call__(
             rngs=rngs,
         )
 
-    def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
-        return self.module._get_feat_extract_output_lengths(input_lengths)
+    def _get_feat_extract_output_lengths(
+        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
+    ):
+        return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
 
 
@@ -860,6 +929,8 @@ def setup(self):
         else:
             raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")
 
+        self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None
+
     def __call__(
         self,
         input_values,
@@ -905,6 +976,9 @@ def __call__(
 
         hidden_states = encoder_outputs[0]
 
+        if self.adapter is not None:
+            hidden_states = self.adapter(hidden_states)
+
         if not return_dict:
             return (hidden_states, extract_features) + encoder_outputs[1:]
 
@@ -915,11 +989,15 @@ def __call__(
             attentions=encoder_outputs.attentions,
         )
 
-    def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
+    def _get_feat_extract_output_lengths(
+        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
+    ):
         """
         Computes the output length of the convolutional layers
         """
 
+        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
         def _conv_out_length(input_length, kernel_size, stride):
             # 1D convolutional layer output length formula taken
             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
@@ -928,6 +1006,10 @@ def _conv_out_length(input_length, kernel_size, stride):
         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
 
+        if add_adapter:
+            for _ in range(self.config.num_adapter_layers):
+                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
         return input_lengths
 
 
@@ -1021,11 +1103,17 @@ def __call__(
 
         return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
 
-    def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
+    def _get_feat_extract_output_lengths(
+        self,
+        input_lengths: Union[jnp.ndarray, int],
+        add_adapter: Optional[bool] = None,
+    ):
         """
         Computes the output length of the convolutional layers
         """
 
+        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
        def _conv_out_length(input_length, kernel_size, stride):
             # 1D convolutional layer output length formula taken
             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
@@ -1034,6 +1122,10 @@ def _conv_out_length(input_length, kernel_size, stride):
         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
 
+        if add_adapter:
+            for _ in range(self.config.num_adapter_layers):
+                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
         return input_lengths
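For reference, here is a minimal, self-contained sketch (not part of the patch) of the length arithmetic that the new `add_adapter` branch in `_get_feat_extract_output_lengths` performs. The helper name `feat_extract_output_length` is invented for illustration, and the hard-coded tuples below are assumed to be the `Wav2Vec2Config` defaults:

```python
def _conv_out_length(input_length, kernel_size, stride):
    # 1D convolution output length, same formula the patch uses
    # (https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html)
    return (input_length - kernel_size) // stride + 1


def feat_extract_output_length(input_length, add_adapter=False):
    # Assumed Wav2Vec2Config defaults, for illustration only
    conv_kernel = (10, 3, 3, 3, 3, 2, 2)
    conv_stride = (5, 2, 2, 2, 2, 2, 2)
    num_adapter_layers = 3
    adapter_stride = 2

    # Feature-extractor convolutions, as in the existing loop
    for kernel_size, stride in zip(conv_kernel, conv_stride):
        input_length = _conv_out_length(input_length, kernel_size, stride)

    # New branch: each adapter layer is a strided conv that shrinks
    # the sequence further
    if add_adapter:
        for _ in range(num_adapter_layers):
            input_length = _conv_out_length(input_length, 1, adapter_stride)

    return input_length


print(feat_extract_output_length(16000))                    # 1 s of 16 kHz audio -> 49 frames
print(feat_extract_output_length(16000, add_adapter=True))  # -> 7 frames after three stride-2 layers
```

Passing `1` as the kernel size for the adapter layers matches their actual geometry: with `adapter_kernel_size=3` and `padding=((1, 1),)`, the output length is `(L + 2 - 3) // stride + 1`, identical to an unpadded kernel-1 convolution, so only `adapter_stride` affects the length estimate.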