diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 954076058e04..086a87750cfb 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -635,10 +635,10 @@ def __init__(
         self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
         self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
         self._assistant_model_device: str = assistant_model_device
-        if target_vocab_size:
-            self.target_vocab_size: int = target_vocab_size
-        else:
+        if target_vocab_size is None:
             self.target_vocab_size: int = len(self._target_tokenizer.get_vocab())
+        else:
+            self.target_vocab_size: int = target_vocab_size
         self.filter_value: float = filter_value
         self.suppress_tokens_id: int = suppress_tokens_id
         self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index e818b266cd7b..39a38f9139ec 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -1860,15 +1860,14 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
     ```
     """
 
-    def __init__(self, suppress_tokens, device: str = "cpu", filter_value: float = -float("Inf")):
+    def __init__(self, suppress_tokens, device: str = "cpu"):
         self.suppress_tokens = torch.tensor(list(suppress_tokens), device=device)
-        self.filter_value = filter_value
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
         suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens)
-        scores = torch.where(suppress_token_mask, self.filter_value, scores)
+        scores = torch.where(suppress_token_mask, -float("inf"), scores)
         return scores
 
 
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index b6221ebe38d5..b8d5b64ced6c 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4282,8 +4282,6 @@ def _assisted_decoding(
 
         # 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
         candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
-        candidate_input_ids = candidate_input_ids.to(self.device)
-        candidate_input_ids = candidate_input_ids.to(self.device)
         if candidate_logits is not None:
             candidate_logits = candidate_logits.to(self.device)
 
diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py
index 7c65f3697425..7e631ade8a02 100644
--- a/tests/generation/test_candidate_generator.py
+++ b/tests/generation/test_candidate_generator.py
@@ -16,44 +16,6 @@
 )
 
 
-class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
-    def test_no_intersection(self):
-        prompt = np.array([[1, 2, 3]])
-        prompt_plus_new_tokens = np.array([[4, 5, 6]])
-        result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens)
-        self.assertEqual(result, (None, None, None))
-
-    def test_complete_overlap(self):
-        prompt = np.array([[1, 2, 3]])
-        prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]])
-        discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
-            prompt, prompt_plus_new_tokens
-        )
-        self.assertEqual(discrep_length, 0)
-        np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
-        np.testing.assert_array_equal(discrep_only, np.array([[]]))
-
-    def test_partial_overlap(self):
-        prompt = np.array([[1, 2, 3]])
-        prompt_plus_new_tokens = np.array([[2, 3, 4, 5]])
-        discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
-            prompt, prompt_plus_new_tokens
-        )
-        self.assertEqual(discrep_length, 0)
-        np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
-        np.testing.assert_array_equal(discrep_only, np.array([[]]))
-
-    def test_no_new_tokens(self):
-        prompt = np.array([[1, 2, 3]])
-        prompt_plus_new_tokens = np.array([[1, 2, 3]])
-        discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
-            prompt, prompt_plus_new_tokens
-        )
-        self.assertEqual(discrep_length, 0)
-        np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
-        np.testing.assert_array_equal(discrep_only, np.array([[]]))
-
-
 class TestAssistantToTargetTranslator(unittest.TestCase):
     def setUp(self):
         # Create mock tokenizers with predefined vocabularies
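A quick sanity check of the two behavioral changes above (a minimal sketch, not part of the patch: `resolve_vocab_size` is a hypothetical helper that only mirrors the patched branch logic, and the toy vocabulary sizes are illustrative):

import torch

from transformers.generation.logits_process import SuppressTokensLogitsProcessor

# The processor now always masks with -inf; the removed `filter_value`
# argument is no longer accepted.
processor = SuppressTokensLogitsProcessor(suppress_tokens=[0, 2], device="cpu")
scores = torch.zeros(1, 5)  # (batch_size, vocab_size) with a toy vocab of 5
masked = processor(input_ids=torch.tensor([[1]]), scores=scores)
assert masked[0, 0] == -float("inf") and masked[0, 2] == -float("inf")
assert masked[0, 1] == 0.0  # unsuppressed tokens are untouched

# The candidate-generator change swaps a truthiness test for an explicit
# `is None` check, so a caller-supplied `target_vocab_size=0` is no longer
# silently replaced by the tokenizer's vocab length. Hypothetical helper
# mirroring the patched branches:
def resolve_vocab_size(target_vocab_size, tokenizer_vocab_len):
    if target_vocab_size is None:
        return tokenizer_vocab_len
    return target_vocab_size

assert resolve_vocab_size(None, 32000) == 32000
assert resolve_vocab_size(0, 32000) == 0  # the old `if target_vocab_size:` gave 32000

Hardcoding `-inf` also removes the possibility that a finite `filter_value` leaves suppressed tokens with a small but nonzero sampling probability.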