98 changes: 95 additions & 3 deletions src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
@@ -766,6 +766,73 @@ def __call__(self, hidden_states, mask_time_indices=None, deterministic=True, te
return codevectors, perplexity


class FlaxWav2Vec2Adapter(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32

def setup(self):
# hidden_states require down-projection if feature dims don't match
if self.config.output_hidden_size != self.config.hidden_size:
self.proj = nn.Dense(
self.config.output_hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
else:
self.proj = self.proj_layer_norm = None

self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)

def __call__(self, hidden_states, deterministic=True):
Contributor (review comment):
(nit) since dropout is not used here, can remove the deterministic arg
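Sketched, the suggested simplification would just drop the unused flag; the body below stays the same. This is illustrative only (the merged code may well have kept the arg):

```python
# Reviewer's suggestion, sketched: no dropout in this module, so no `deterministic` flag.
def __call__(self, hidden_states):
    ...
```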

# down-project hidden_states if required
if self.proj is not None and self.proj_layer_norm is not None:
hidden_states = self.proj(hidden_states)
hidden_states = self.proj_layer_norm(hidden_states)

hidden_states = self.layers(hidden_states)

return hidden_states
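For context, a minimal usage sketch of the FlaxWav2Vec2Adapter above. It assumes the class is importable from transformers.models.wav2vec2.modeling_flax_wav2vec2 once this diff is in, and uses default config values (hidden_size 768, 3 adapter layers of stride 2), which are assumptions, not part of the diff:

```python
import jax
import jax.numpy as jnp
from transformers import Wav2Vec2Config
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2Adapter

config = Wav2Vec2Config(add_adapter=True)               # defaults: 3 adapter layers, stride 2
adapter = FlaxWav2Vec2Adapter(config)
hidden_states = jnp.ones((1, 100, config.hidden_size))  # (batch, frames, hidden)
params = adapter.init(jax.random.PRNGKey(0), hidden_states)
print(adapter.apply(params, hidden_states).shape)
# (1, 13, 768): each stride-2 layer roughly halves the time axis (100 -> 50 -> 25 -> 13)
```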


class FlaxWav2Vec2AdapterLayer(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32

def setup(self):
self.conv = nn.Conv(
features=2 * self.config.output_hidden_size,
kernel_size=(self.config.adapter_kernel_size,),
strides=(self.config.adapter_stride,),
padding=((1, 1),),
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)

def __call__(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = nn.glu(hidden_states, axis=2)

return hidden_states
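The conv doubles the feature dimension so the GLU can gate it back down: flax's nn.glu splits the given axis in half and returns first_half * sigmoid(second_half). A quick shape check, with the default output_hidden_size of 768 assumed:

```python
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 50, 1536))   # conv output: 2 * output_hidden_size features
y = nn.glu(x, axis=2)         # gate: first half * sigmoid(second half)
print(y.shape)                # (1, 50, 768)
```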


class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32

def setup(self):
self.layers = [
FlaxWav2Vec2AdapterLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_adapter_layers)
]

def __call__(self, hidden_states):
for conv_layer in self.layers:
hidden_states = conv_layer(hidden_states)

return hidden_states
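Note on name=str(i): the explicit names make the layers appear as "0", "1", ... in the parameter tree, presumably so the Flax params index adapter layers the same way the PyTorch checkpoint does. Continuing the sketch above (shapes assume the default config):

```python
params = adapter.init(jax.random.PRNGKey(0), hidden_states)["params"]
print(params["layers"]["0"]["conv"]["kernel"].shape)
# (3, 768, 1536): (kernel_size, in_features, 2 * output_hidden_size)
```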


class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -840,7 +907,9 @@ def __call__(
rngs=rngs,
)

def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
def _get_feat_extract_output_lengths(
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
        return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)


@@ -860,6 +929,8 @@ def setup(self):
else:
raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")

self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None

def __call__(
self,
input_values,
@@ -905,6 +976,9 @@ def __call__(

hidden_states = encoder_outputs[0]

if self.adapter is not None:
hidden_states = self.adapter(hidden_states)

if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]

@@ -915,11 +989,15 @@ def __call__(
attentions=encoder_outputs.attentions,
)

def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
def _get_feat_extract_output_lengths(
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
"""
Computes the output length of the convolutional layers
"""

add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
@@ -928,6 +1006,10 @@ def _conv_out_length(input_length, kernel_size, stride):
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

return input_lengths
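A worked example of the hunk above, in pure Python. The adapter pass uses kernel size 1 because a stride-2 conv with kernel 3 and padding (1, 1) shortens lengths exactly like an unpadded kernel-1, stride-2 conv: (L + 2 - 3) // 2 + 1 == (L - 1) // 2 + 1. The conv widths and strides below are the wav2vec2 defaults, assumed here:

```python
def _conv_out_length(input_length, kernel_size, stride):
    # same floor-division formula as in the diff
    return (input_length - kernel_size) // stride + 1

length = 16000  # one second of 16 kHz audio
for kernel, stride in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
    length = _conv_out_length(length, kernel, stride)
print(length)  # 49 frames from the feature extractor
for _ in range(3):  # add_adapter=True with the default 3 layers of stride 2
    length = _conv_out_length(length, 1, 2)
print(length)  # 7 frames after the adapter (49 -> 25 -> 13 -> 7)
```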


@@ -1021,11 +1103,17 @@ def __call__(

return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
def _get_feat_extract_output_lengths(
self,
input_lengths: Union[jnp.ndarray, int],
add_adapter: Optional[bool] = None,
):
"""
Computes the output length of the convolutional layers
"""

add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
@@ -1034,6 +1122,10 @@ def _conv_out_length(input_length, kernel_size, stride):
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

return input_lengths
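Putting the pieces together: a randomly initialized model with the adapter enabled, just for shape checking. The extra config flags are needed because this file only implements the stable-layer-norm variant (see the NotImplementedError in the setup hunk above); all values here are illustrative:

```python
import jax.numpy as jnp
from transformers import FlaxWav2Vec2Model, Wav2Vec2Config

config = Wav2Vec2Config(add_adapter=True, do_stable_layer_norm=True, feat_extract_norm="layer")
model = FlaxWav2Vec2Model(config)       # random weights, shape check only
input_values = jnp.ones((1, 16000))     # one second of 16 kHz audio
outputs = model(input_values)
print(outputs.last_hidden_state.shape)  # (1, 7, 768): adapter-downsampled frames
```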

