diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index ba548dc3d8ea..a9599f8ef306 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -50,7 +50,7 @@ def _compute_mask_indices(
     mask_length: int,
     attention_mask: Optional[torch.Tensor] = None,
     min_masks: int = 0,
-) -> np.ndarray:
+) -> torch.Tensor:
     """
     Computes random mask spans for a given shape

@@ -68,12 +68,12 @@
     `__.
     """
     bsz, all_sz = shape
-    mask = np.full((bsz, all_sz), False)
+    mask = torch.zeros((bsz, all_sz), dtype=torch.bool)

     all_num_mask = int(
         # add a random number for probabilistic rounding
         mask_prob * all_sz / float(mask_length)
-        + np.random.rand()
+        + torch.rand(1).item()
     )

     all_num_mask = max(min_masks, all_num_mask)
@@ -86,14 +86,14 @@
             num_mask = int(
                 # add a random number for probabilistic rounding
                 mask_prob * sz / float(mask_length)
-                + np.random.rand()
+                + torch.rand(1).item()
             )
             num_mask = max(min_masks, num_mask)
         else:
             sz = all_sz
             num_mask = all_num_mask

-        lengths = np.full(num_mask, mask_length)
+        lengths = torch.full((num_mask,), mask_length, dtype=torch.long)

         if sum(lengths) == 0:
             lengths[0] = min(mask_length, sz - 1)
@@ -102,14 +102,15 @@
         if sz - min_len <= num_mask:
             min_len = sz - num_mask - 1

-        mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
-        mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
-        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+        mask_idc = torch.randperm(sz - min_len)[:num_mask]
+        mask_idc = torch.tensor([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])

-    min_len = min([len(m) for m in mask_idcs])
+        mask_idcs.append(torch.unique(mask_idc[mask_idc < sz]))
+
+    min_len = min([len(m) for m in mask_idcs])
     for i, mask_idc in enumerate(mask_idcs):
         if len(mask_idc) > min_len:
-            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+            mask_idc = mask_idc[torch.randperm(len(mask_idc))[:min_len]]
         mask[i, mask_idc] = True

     return mask
@@ -274,12 +275,7 @@ class Wav2Vec2Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(
-        self,
-        embed_dim: int,
-        num_heads: int,
-        dropout: float = 0.0,
-        is_decoder: bool = False,
-        bias: bool = True,
+        self, embed_dim: int, num_heads: int, dropout: float = 0.0, is_decoder: bool = False, bias: bool = True,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -563,9 +559,7 @@ def custom_forward(*inputs):
                     return custom_forward

                 layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer),
-                    hidden_states,
-                    attention_mask,
+                    create_custom_forward(layer), hidden_states, attention_mask,
                 )
             else:
                 layer_outputs = layer(
@@ -582,9 +576,7 @@ def custom_forward(*inputs):

         if not return_dict:
             return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
         return BaseModelOutput(
-            last_hidden_state=hidden_states,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
+            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions,
         )

@@ -642,9 +634,7 @@ def custom_forward(*inputs):
                     return custom_forward

                 layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer),
-                    hidden_states,
-                    attention_mask,
+                    create_custom_forward(layer), hidden_states, attention_mask,
                 )
             else:
                 layer_outputs = layer(
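The `_compute_mask_indices` hunks above port the function from numpy to torch one call at a time: `torch.zeros(..., dtype=torch.bool)` for `np.full(..., False)`, `torch.rand(1).item()` for `np.random.rand()`, and `torch.randperm(n)[:k]` for `np.random.choice(n, k, replace=False)`. As a sanity check on those substitutions, here is a minimal, self-contained sketch of the same span-masking idea. It is illustrative only: it keeps the diff's parameter names but drops the per-utterance `attention_mask` branch and the span-length bookkeeping, so it is a simplification of, not a replacement for, the patched function.

```python
from typing import Tuple

import torch


def compute_mask_indices_sketch(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    min_masks: int = 0,
) -> torch.Tensor:
    """Return a boolean (bsz, all_sz) mask with randomly placed spans."""
    bsz, all_sz = shape
    mask = torch.zeros((bsz, all_sz), dtype=torch.bool)

    # torch.rand(1).item() stands in for np.random.rand(): a Python float in
    # [0, 1) used for probabilistic rounding of the span count.
    num_mask = int(mask_prob * all_sz / float(mask_length) + torch.rand(1).item())
    num_mask = max(min_masks, num_mask)

    for i in range(bsz):
        # torch.randperm(n)[:k] plays the role of
        # np.random.choice(n, k, replace=False): k distinct span starts,
        # each expanded into mask_length consecutive positions.
        starts = torch.randperm(all_sz - mask_length)[:num_mask]
        idc = (starts.unsqueeze(1) + torch.arange(mask_length)).flatten()
        mask[i, idc] = True
    return mask


if __name__ == "__main__":
    torch.manual_seed(0)
    m = compute_mask_indices_sketch((2, 100), mask_prob=0.5, mask_length=10)
    # Roughly mask_prob * all_sz positions per row; fewer where spans overlap.
    print(m.shape, m.sum(dim=-1))
```

The substitution that is easiest to get wrong is sampling *from an existing index tensor*: `torch.randperm` only accepts an integer length, so the torch spelling of `np.random.choice(arr, k, replace=False)` is `arr[torch.randperm(len(arr))[:k]]`, as in the final `_compute_mask_indices` hunk above.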
@@ -663,9 +653,7 @@ def custom_forward(*inputs):

         if not return_dict:
             return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
         return BaseModelOutput(
-            last_hidden_state=hidden_states,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
+            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions,
         )

@@ -788,12 +776,7 @@ def __init__(self, config):

     @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
-        self,
-        input_values,
-        attention_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        self, input_values, attention_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None,
     ):
         """
@@ -862,9 +845,7 @@ def forward(
         # apply SpecAugment along feature axis
         if self.config.mask_feature_prob > 0:
             mask_feature_indices = _compute_mask_indices(
-                (batch_size, hidden_size),
-                self.config.mask_feature_prob,
-                self.config.mask_feature_length,
+                (batch_size, hidden_size), self.config.mask_feature_prob, self.config.mask_feature_length,
             )
-            mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
+            mask_feature_indices = mask_feature_indices.to(hidden_states.device)
             hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
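The last hunk builds a `(batch_size, hidden_size)` boolean mask and broadcasts it across the time axis; since `_compute_mask_indices` now returns a `torch.Tensor`, the old `torch.from_numpy(...)` wrapper is dropped and only `.to(hidden_states.device)` remains. The sketch below replays that indexing pattern on toy shapes; the dimensions and the masked channel range are illustrative assumptions, not values from the model or its config.

```python
import torch

batch_size, sequence_length, hidden_size = 2, 50, 16
hidden_states = torch.randn(batch_size, sequence_length, hidden_size)

# Pretend _compute_mask_indices marked feature channels 3-5 in every row.
mask_feature_indices = torch.zeros((batch_size, hidden_size), dtype=torch.bool)
mask_feature_indices[:, 3:6] = True

# Broadcast the (batch, hidden_size) mask across time, exactly as the diff's
# mask_feature_indices[:, None].expand(-1, sequence_length, -1) does, then
# zero out every masked (time, feature) position in one assignment.
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0

assert (hidden_states[:, :, 3:6] == 0).all()
print(hidden_states.abs().sum(dim=(0, 1))[:8])  # channels 3-5 sum to 0
```

Indexing with a boolean tensor of the same shape as `hidden_states` zeroes all masked positions at once, which is what SpecAugment's feature masking calls for and why the mask must be boolean rather than float.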