diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py old mode 100755 new mode 100644 index c5e6cc564f09..1051a298d513 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -1012,6 +1012,7 @@ def forward( loss = None if labels is not None: + labels = labels.to(pooled_logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression"