22 changes: 19 additions & 3 deletions src/transformers/models/parakeet/modeling_parakeet.py
@@ -37,6 +37,16 @@
from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig


@dataclass
@auto_docstring(
custom_intro="""
Extends [`~modeling_outputs.BaseModelOutput`] to include the output attention mask, since the sequence length is not preserved by the model's forward pass.
"""
)
class ParakeetEncoderModelOutput(BaseModelOutput):
attention_mask: Optional[torch.Tensor] = None


class ParakeetEncoderRelPositionalEncoding(nn.Module):
"""Relative positional encoding for Parakeet."""

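For readers of the diff, here is a minimal, self-contained sketch of how the new output type is meant to be consumed. The stand-in class below only mirrors the dataclass added above (the real one lives in `modeling_parakeet.py` and inherits from `BaseModelOutput`), and all tensor shapes are illustrative assumptions, not values from an actual checkpoint.

```python
import torch
from dataclasses import dataclass
from typing import Optional

# Stand-in mirroring the output type added above; the real class inherits from
# BaseModelOutput in transformers.models.parakeet.modeling_parakeet.
@dataclass
class ParakeetEncoderModelOutput:
    last_hidden_state: Optional[torch.Tensor] = None
    attention_mask: Optional[torch.Tensor] = None

# Illustrative shapes: the encoder subsamples the time axis, so the returned mask
# is aligned with the downsampled hidden states rather than the input features.
hidden_states = torch.randn(2, 5, 8)  # (batch, downsampled_len, hidden_size)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=torch.int32)

outputs = ParakeetEncoderModelOutput(last_hidden_state=hidden_states, attention_mask=mask)
assert outputs.attention_mask.shape[-1] == outputs.last_hidden_state.shape[1]
```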
@@ -513,9 +523,13 @@ def forward(
self,
input_features: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attention_mask: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutput:
r"""
output_attention_mask (`bool`, *optional*):
Whether to return the output attention mask.

Example:

```python
@@ -546,8 +560,8 @@
)

if attention_mask is not None:
attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
attention_mask = attention_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1)
output_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
attention_mask = output_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1)
attention_mask = attention_mask & attention_mask.transpose(1, 2)
attention_mask = attention_mask.unsqueeze(1)

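As a side note, the pairwise mask construction above can be hard to read in diff form; the sketch below reproduces the same unsqueeze/expand/transpose steps on a toy downsampled padding mask (shapes and values are made up for illustration).

```python
import torch

# Toy downsampled 2D padding mask: (batch, seq_len), True on valid frames.
output_mask = torch.tensor([[True, True, True, False],
                            [True, True, True, True]])
seq_len = output_mask.shape[1]

# Broadcast to pairwise (query, key) validity, then add a singleton head dim,
# mirroring the steps in the forward pass above.
pairwise = output_mask.unsqueeze(1).expand(-1, seq_len, -1)
pairwise = pairwise & pairwise.transpose(1, 2)
pairwise = pairwise.unsqueeze(1)  # (batch, 1, seq_len, seq_len)

print(pairwise.shape)  # torch.Size([2, 1, 4, 4])
print(pairwise[0, 0])  # last row/column are False: the padded frame neither attends nor is attended to
```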
@@ -567,7 +581,9 @@ def forward(
**kwargs,
)

return BaseModelOutput(last_hidden_state=hidden_states)
return ParakeetEncoderModelOutput(
last_hidden_state=hidden_states, attention_mask=output_mask.int() if output_attention_mask else None
)

Contributor: can we return lengths directly instead, or along with attention_mask?

Contributor Author: I'd rather keep an explicit approach here for potential future usage of the same with left padding; you can retrieve lengths by doing attention_mask.sum(-1).

Contributor: Sounds good.
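To make the point above concrete, here is a tiny sketch of recovering per-sample lengths from the returned mask; the tensor values are made up, and `output_attention_mask=True` refers to the flag added in this diff.

```python
import torch

# Pretend this is `outputs.attention_mask` from a call to the encoder with
# output_attention_mask=True: 1s mark valid (non-padded) downsampled frames.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]], dtype=torch.int32)

lengths = attention_mask.sum(-1)
print(lengths)  # tensor([3, 5]): per-sample number of valid frames
```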


@dataclass
22 changes: 19 additions & 3 deletions src/transformers/models/parakeet/modular_parakeet.py
@@ -34,6 +34,16 @@
from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig


@dataclass
@auto_docstring(
custom_intro="""
Extends [`~modeling_outputs.BaseModelOutput`] to include the output attention mask, since the sequence length is not preserved by the model's forward pass.
"""
)
class ParakeetEncoderModelOutput(BaseModelOutput):
attention_mask: Optional[torch.Tensor] = None


class ParakeetEncoderRelPositionalEncoding(nn.Module):
"""Relative positional encoding for Parakeet."""

@@ -399,9 +409,13 @@ def forward(
self,
input_features: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attention_mask: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutput:
r"""
output_attention_mask (`bool`, *optional*):
Whether to return the output attention mask.

Example:

```python
@@ -432,8 +446,8 @@
)

if attention_mask is not None:
attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
attention_mask = attention_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1)
output_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
attention_mask = output_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1)
attention_mask = attention_mask & attention_mask.transpose(1, 2)
attention_mask = attention_mask.unsqueeze(1)

@@ -453,7 +467,9 @@ def forward(
**kwargs,
)

return BaseModelOutput(last_hidden_state=hidden_states)
return ParakeetEncoderModelOutput(
last_hidden_state=hidden_states, attention_mask=output_mask.int() if output_attention_mask else None
)


@dataclass