diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
index 90c393e5098b..959b9763d0bd 100755
--- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py
+++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -178,7 +178,9 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
         if self.scale_attn_weights:
-            attn_weights = attn_weights / (value.size(-1) ** 0.5)
+            attn_weights = attn_weights / torch.tensor(
+                value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+            )
 
         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index c58b0cda6a79..1c61adb10d9f 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -189,7 +189,9 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
         if self.scale_attn_weights:
-            attn_weights = attn_weights / (value.size(-1) ** 0.5)
+            attn_weights = attn_weights / torch.tensor(
+                value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+            )
 
         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
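The change is identical in both files: the attention-scaling divisor, previously a plain Python float, is now materialized as a tensor matching the dtype and device of attn_weights. The diff itself does not state the motivation; a plausible reading is that keeping the scale as an explicit tensor keeps the whole scaling step in tensor ops with a well-defined dtype/device. The standalone sketch below (shapes and head_dim = 64 are illustrative assumptions, not values taken from the diff) shows that the old and new forms agree numerically:

    import torch

    # Illustrative shapes only (batch, heads, seq, head_dim); not from the diff.
    query = torch.randn(1, 12, 8, 64)
    key = torch.randn(1, 12, 8, 64)
    value = torch.randn(1, 12, 8, 64)

    attn_weights = torch.matmul(query, key.transpose(-1, -2))

    # Old form: divide by a Python float.
    scaled_old = attn_weights / (value.size(-1) ** 0.5)

    # New form (as in the diff): the divisor is a tensor that explicitly matches
    # the dtype and device of attn_weights. In practice attn_weights may be
    # fp16/bf16, and this keeps the constant consistent with it.
    scaled_new = attn_weights / torch.tensor(
        value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
    )

    # For head_dim = 64 the scale (8.0) is exactly representable, so both forms match.
    print(torch.allclose(scaled_old, scaled_new))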