diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
index 149b5d5fffa8..5008c7ed7e9c 100755
--- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py
+++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -186,7 +186,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
             mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
-            attn_weights = torch.where(causal_mask, attn_weights, mask_value)
+            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

         if attention_mask is not None:
             # Apply the attention mask
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index c32b1465a5a7..f8b8290a2048 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -198,7 +198,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
             mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
-            attn_weights = torch.where(causal_mask, attn_weights, mask_value)
+            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

         if attention_mask is not None:
             # Apply the attention mask
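For readers skimming the patch, below is a minimal, self-contained sketch of the causal-masking step both hunks touch. The tensor shapes, the float16 dtype, and the standalone setup are illustrative assumptions, not taken from the models; only the last two statements mirror the surrounding `_attn` context and the patched line, which adds an explicit `.to(attn_weights.dtype)` on the `torch.where` operand.

```python
import torch

# Illustrative shapes/dtype (assumed): [batch, heads, query_len, key_len] attention scores.
attn_weights = torch.randn(1, 1, 4, 4, dtype=torch.float16)
# Lower-triangular causal mask: True where a query may attend to a key.
causal_mask = torch.tril(torch.ones(4, 4)).bool().view(1, 1, 4, 4)

# As in the surrounding _attn code: mask_value must be a tensor with the same
# dtype and device as attn_weights, otherwise torch.where raises a dtype/device error.
mask_value = torch.finfo(attn_weights.dtype).min
mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)

# The patched line: positions outside the causal mask are filled with the most
# negative representable value, and the attn_weights operand's dtype is made explicit.
attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
```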