diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 89c57cb913fe..fe634141eca0 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4260,9 +4260,10 @@ def _assisted_decoding( while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): cur_len = input_ids.shape[-1] - # 1. Fetch candidate sequences from a `CandidateGenerator` + # 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) + candidate_input_ids = candidate_input_ids.to(self.device) if candidate_logits is not None: candidate_logits = candidate_logits.to(self.device)