diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index 3706c6c0fe06..8bfa2ee740f5 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -87,18 +87,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
     return shifted_input_ids
 
 
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -920,8 +922,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index ad22199e66e6..6c6bfd431528 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -78,18 +78,21 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
     return shifted_input_ids
 
 
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -2101,8 +2104,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py
index 5096336e39bc..178e67f8e1bc 100755
--- a/src/transformers/models/biogpt/modeling_biogpt.py
+++ b/src/transformers/models/biogpt/modeling_biogpt.py
@@ -43,18 +43,20 @@
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -470,8 +472,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py
index b65a74983685..dad8b7e38929 100755
--- a/src/transformers/models/blenderbot/modeling_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -77,18 +77,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -855,8 +857,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
index 69e1b069e77a..aae00db06917 100755
--- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -73,18 +73,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -850,8 +852,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py
index 10524c70bc42..1d0451add50d 100644
--- a/src/transformers/models/informer/modeling_informer.py
+++ b/src/transformers/models/informer/modeling_informer.py
@@ -230,18 +230,20 @@ def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -1281,8 +1283,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index cc9e01e4aa5a..6d79536627fc 100755
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -39,18 +39,20 @@
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -481,8 +483,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index 4ed8d2e5c167..f02cd209de43 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -73,18 +73,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -1002,8 +1004,11 @@ def forward(
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None and combined_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py
index c2befbc1819a..23075a055f6c 100755
--- a/src/transformers/models/marian/modeling_marian.py
+++ b/src/transformers/models/marian/modeling_marian.py
@@ -75,18 +75,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -861,8 +863,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index 0bf3fb62f769..fe12c82dcb7b 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -80,18 +80,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -907,8 +909,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py
index 0d787eda2980..6b2bdcb93d5a 100644
--- a/src/transformers/models/mvp/modeling_mvp.py
+++ b/src/transformers/models/mvp/modeling_mvp.py
@@ -91,18 +91,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -1044,8 +1046,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py
index 5bfe44e900f1..0d2572d9d4e7 100644
--- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py
+++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py
@@ -78,18 +78,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -1349,8 +1351,11 @@ def forward(
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None and combined_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index 110b81721657..6ecedc2173e7 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -64,18 +64,21 @@
 ]
 
 
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -533,8 +536,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py
index e39ed97daa17..bfbe0b2aba0d 100755
--- a/src/transformers/models/pegasus/modeling_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_pegasus.py
@@ -74,18 +74,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -881,8 +883,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py
index 44f3818a2478..f5375b5079b3 100755
--- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py
+++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py
@@ -92,18 +92,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -1147,8 +1149,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index 60122c5d0f14..f00cc8caa19d 100644
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -79,18 +79,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -883,8 +885,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py
index d08863f8353f..b5037faf4c9f 100755
--- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py
@@ -64,18 +64,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -893,8 +895,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
index 319589eab144..cfb86143b881 100755
--- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
+++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
@@ -44,18 +44,20 @@
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -501,8 +503,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py
index 01ac79af472e..c977414a5263 100644
--- a/src/transformers/models/speecht5/modeling_speecht5.py
+++ b/src/transformers/models/speecht5/modeling_speecht5.py
@@ -82,18 +82,20 @@ def shift_spectrograms_right(input_values: torch.Tensor):
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -1552,8 +1554,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
index 286315a19615..812a8d6b4d01 100644
--- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
+++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
@@ -225,18 +225,20 @@ def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor]
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -1007,8 +1009,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py
index e6853d0c5a8e..7e2e11c00e22 100644
--- a/src/transformers/models/trocr/modeling_trocr.py
+++ b/src/transformers/models/trocr/modeling_trocr.py
@@ -44,18 +44,20 @@
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -524,8 +526,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index d6ec39168cd1..cdf769871542 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -69,18 +69,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -917,8 +919,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
 
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py
index 8a19557932f9..64cec28edf2c 100755
--- a/src/transformers/models/xglm/modeling_xglm.py
+++ b/src/transformers/models/xglm/modeling_xglm.py
@@ -114,18 +114,20 @@
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
     """
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
-    mask_cond = torch.arange(mask.size(-1))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
@@ -578,8 +580,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(inputs_embeds.device)
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]