diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 07c743c6b79f..90c23fe7f08d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4280,7 +4280,8 @@ def _assisted_decoding( # 1. Fetch candidate sequences from a `CandidateGenerator` candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) - + candidate_input_ids = candidate_input_ids.to(self.device) + if candidate_logits is not None: candidate_logits = candidate_logits.to(self.device)