diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 6385a33af3ad..35bf54a9a612 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1522,6 +1522,7 @@ def forward( loss = None # we compute the loss here since we need to take into account the sequence length of the query embeds if labels is not None: + labels = labels.to(logits.device) logits = logits[:, -labels.size(1) :, :] # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() @@ -1757,6 +1758,7 @@ def forward( loss = None # we compute the loss here since we need to take into account the sequence length of the query embeds if labels is not None: + labels = labels.to(logits.device) logits = logits[:, -labels.size(1) :, :] # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c3f5285441bc..1ae05846922c 100755 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -850,6 +850,7 @@ def forward( loss = None if labels is not None: + labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression"