diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 8bfa2ee740f5..c1da4eb288e8 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1398,6 +1398,7 @@ def forward( masked_lm_loss = None if labels is not None: + labels = labels.to(lm_logits.device) loss_fct = CrossEntropyLoss() masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) @@ -1553,6 +1554,7 @@ def forward( loss = None if labels is not None: + labels = labels.to(logits.device) if self.config.problem_type is None: if self.config.num_labels == 1: self.config.problem_type = "regression" @@ -1896,6 +1898,7 @@ def forward( loss = None if labels is not None: + labels = labels.to(logits.device) loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 6c6bfd431528..859e019ac953 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2581,6 +2581,7 @@ def forward( masked_lm_loss = None if labels is not None: + labels = labels.to(lm_logits.device) loss_fct = CrossEntropyLoss() masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) @@ -2735,6 +2736,7 @@ def forward( loss = None if labels is not None: + labels = labels.to(logits.device) if self.config.problem_type is None: if self.config.num_labels == 1: self.config.problem_type = "regression" diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index dad8b7e38929..6ab2de61e9f9 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -1596,6 +1596,7 @@ def forward( loss = None if labels is not None: + labels = labels.to(logits.device) loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index aae00db06917..b123e4e0d5d3 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -1563,6 +1563,7 @@ def forward( loss = None if labels is not None: + labels = labels.to(logits.device) loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index bf6651431480..3f1806fb9a8c 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1098,6 +1098,8 @@ def forward( loss = None if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) # Shift so that tokens < n predict n shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -1318,6 +1320,7 @@ def forward( mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) lm_loss = None if labels is not None: + labels = labels.to(lm_logits.device) shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = CrossEntropyLoss() @@ -1569,6 +1572,7 @@ def forward( loss = None if labels is not None: + labels = labels.to(logits.device) loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 23075a055f6c..37ba6825bb2e 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1715,6 +1715,7 @@ def forward( loss = None if labels is not None: + labels = labels.to(logits.device) loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 778d132ac3f0..ceb6aa67a40a 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1528,6 +1528,7 @@ def forward( loss = None if labels is not None: + labels = labels.to(logits.device) if self.config.problem_type is None: if self.config.num_labels == 1: self.config.problem_type = "regression" @@ -1866,6 +1867,7 @@ def forward( loss = None if labels is not None: + labels = labels.to(logits.device) loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index bfbe0b2aba0d..7d0652ff617b 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -1694,6 +1694,7 @@ def forward( loss = None if labels is not None: + labels = labels.to(logits.device) loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index f00cc8caa19d..9ec9bc71433a 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -1499,6 +1499,7 @@ def forward( loss = None if labels is not None: + labels = labels.to(logits.device) if self.config.problem_type is None: if self.config.num_labels == 1: self.config.problem_type = "regression" @@ -1713,6 +1714,7 @@ def forward( loss = None if labels is not None: + labels = labels.to(logits.device) loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))