diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py
index ba94baaa7d0c..6d8234f9c3fd 100644
--- a/src/transformers/data/data_collator.py
+++ b/src/transformers/data/data_collator.py
@@ -579,7 +579,7 @@ def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor,
             masked_indices.masked_fill_(padding_mask, value=0.0)
 
         # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
-        non_func_mask = ~(padding_mask & special_tokens_mask)
+        non_func_mask = ~(padding_mask | special_tokens_mask)
 
         inputs[masked_indices] = self.tokenizer.mask_token_id
         labels[~masked_indices] = -100  # We only compute loss on masked tokens
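
For context, a minimal sketch of why the boolean change matters (the toy masks below are illustrative and not part of the patch). By De Morgan's law, `~(A | B) == ~A & ~B`, so the fixed expression excludes any token that is padding *or* special; the old `~(A & B)` only excluded tokens that were *both*, wrongly leaving [CLS]/[SEP] eligible for masking:

```python
import torch

# Toy sequence: [CLS] tok tok [SEP] [PAD]  (positions are illustrative)
special_tokens_mask = torch.tensor([True, False, False, True, True])
padding_mask = torch.tensor([False, False, False, False, True])

# Buggy: ~(A & B) only rules out positions that are BOTH padding and special,
# so [CLS] and [SEP] are incorrectly treated as non-functional (maskable).
buggy = ~(padding_mask & special_tokens_mask)
# tensor([ True,  True,  True,  True, False])

# Fixed: ~(A | B) rules out any padding or special position,
# keeping only the real content tokens.
fixed = ~(padding_mask | special_tokens_mask)
# tensor([False,  True,  True, False, False])
```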