30 changes: 21 additions & 9 deletions src/transformers/generation/candidate_generator.py
@@ -619,17 +619,31 @@ def _process_assistant_outputs(

 class AssistantToTargetTranslator:
     """
-    Translate the assistant into the target universe.
+    Translates token ids and logits between the assistant and target model vocabularies. This class handles
+    vocabulary mismatches when speculative decoding uses different tokenizers for the assistant and target models,
+    as introduced in the paper "Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies"
+    (https://www.arxiv.org/abs/2502.05202).
+    It maintains mappings between the two vocabularies and handles token/logit conversion.
+
+    Args:
+        target_tokenizer (`PreTrainedTokenizerBase`):
+            The tokenizer used by the target (main) model.
+        assistant_tokenizer (`PreTrainedTokenizerBase`):
+            The tokenizer used by the assistant model.
+        assistant_model_device (`str`, defaults to `"cpu"`):
+            The device where the assistant model is located. Used for placing tensors.
+        target_vocab_size (`int`, *optional*):
+            The size of the target model's vocabulary. If not provided, it is inferred from the target tokenizer.
     """
+
+    FILTER_VALUE: float = -float("Inf")  # The value used to filter out unmapped tokens in the logits.
+    SUPPRESS_TOKEN_ID: int = -1  # The ID used to mark suppressed tokens in the mapping.

     def __init__(
         self,
         target_tokenizer: "PreTrainedTokenizerBase",
         assistant_tokenizer: "PreTrainedTokenizerBase",
         assistant_model_device: str = "cpu",
         target_vocab_size: Optional[int] = None,
-        filter_value: float = -float("Inf"),
-        suppress_tokens_id: int = -1,
     ):
         self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
         self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
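For reviewers who want to poke at the new surface: a minimal sketch of constructing the translator under the new signature. The checkpoint names are illustrative, and the import path assumes the class lives in `transformers.generation.candidate_generator` as shown in this diff. Note that `filter_value` and `suppress_tokens_id` are no longer constructor arguments; they are the class constants `FILTER_VALUE` and `SUPPRESS_TOKEN_ID`.

```python
from transformers import AutoTokenizer
from transformers.generation.candidate_generator import AssistantToTargetTranslator

# Illustrative checkpoints; any target/assistant pair with different tokenizers works.
target_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
assistant_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

translator = AssistantToTargetTranslator(
    target_tokenizer=target_tokenizer,
    assistant_tokenizer=assistant_tokenizer,
    assistant_model_device="cpu",
)
# The former constructor arguments are now class-level constants.
print(translator.FILTER_VALUE)       # -inf
print(translator.SUPPRESS_TOKEN_ID)  # -1
```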
@@ -638,8 +652,6 @@ def __init__(
             self.target_vocab_size: int = len(self._target_tokenizer.get_vocab())
         else:
             self.target_vocab_size: int = target_vocab_size
-        self.filter_value: float = filter_value
-        self.suppress_tokens_id: int = suppress_tokens_id
         self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
             self._get_assistant_to_target_input_ids()
         )
@@ -677,7 +689,7 @@ def _get_assistant_to_target_input_ids(self):
         }

         max_assistant_index = max(assistant_vocab.values())
-        assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.suppress_tokens_id, dtype=int)
+        assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=int)
         target_to_assistant_input_ids: Dict[int, int] = {}
         for tok, assistant_id in assistant_vocab.items():
             target_id = target_vocab.get(tok)
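To make the mapping construction concrete, here is a self-contained re-implementation on toy vocabularies. The vocabs are invented, and the loop body below `target_vocab.get(tok)` is an assumption filled in from the behavior the updated tests expect:

```python
import torch

# Toy stand-ins for tokenizer.get_vocab() output.
assistant_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3, "baz": 4}
target_vocab = {"hello": 0, "world": 1, "foo": 2, "qux": 3}
SUPPRESS_TOKEN_ID = -1

max_assistant_index = max(assistant_vocab.values())
# Every assistant id starts out suppressed; shared tokens are filled in below.
assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), SUPPRESS_TOKEN_ID, dtype=int)
target_to_assistant_input_ids: dict[int, int] = {}
for tok, assistant_id in assistant_vocab.items():
    target_id = target_vocab.get(tok)
    if target_id is not None:  # token exists in both vocabs
        assistant_to_target_input_ids[assistant_id] = target_id
        target_to_assistant_input_ids[target_id] = assistant_id

print(assistant_to_target_input_ids.tolist())  # [0, 1, 2, -1, -1]
```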
@@ -690,7 +702,7 @@ def _get_suppress_input_ids(self) -> list[int]:
         """
         Get the input ids that are in the assistant vocab but not in the target vocab.
         """
-        return torch.where(self._assistant_to_target_input_ids == self.suppress_tokens_id)[0]
+        return torch.where(self._assistant_to_target_input_ids == self.SUPPRESS_TOKEN_ID)[0]

     def get_target_ids(
         self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor
@@ -714,9 +726,9 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
         """

         target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
-        target_logits: torch.FloatTensor = torch.full(target_shape, self.filter_value).to(self._assistant_model_device)
+        target_logits: torch.FloatTensor = torch.full(target_shape, self.FILTER_VALUE).to(self._assistant_model_device)
         # Mask for valid indices
-        assistant_indices_mask = self._assistant_to_target_input_ids != self.suppress_tokens_id
+        assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID
         # Exclude invalid indices
         target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]
         valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
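The logits conversion above can likewise be reproduced standalone. This sketch reuses the toy mapping from the previous example; the final scatter assignment is an assumption about the elided tail of `get_target_logits`, chosen to match what the updated tests assert:

```python
import torch

FILTER_VALUE = -float("Inf")
SUPPRESS_TOKEN_ID = -1
assistant_to_target_input_ids = torch.tensor([0, 1, 2, -1, -1])
target_vocab_size = 4

assistant_logits = torch.tensor([[[0.1, 0.2, 0.3, 0.4, 0.5]]])  # (batch, seq, assistant_vocab)
target_shape = (*assistant_logits.shape[:-1], target_vocab_size)
# Unmapped target positions stay at FILTER_VALUE (-inf), so they can never be sampled.
target_logits = torch.full(target_shape, FILTER_VALUE)
# Mask of assistant ids that have a target counterpart
assistant_indices_mask = assistant_to_target_input_ids != SUPPRESS_TOKEN_ID
target_logits_supported_indices = assistant_to_target_input_ids[assistant_indices_mask]
valid_assistant_logits = assistant_logits[..., : assistant_to_target_input_ids.shape[0]]
# Scatter mapped assistant logits into their target positions (assumed final step).
target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
print(target_logits)  # tensor([[[0.1000, 0.2000, 0.3000, -inf]]])
```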
53 changes: 26 additions & 27 deletions src/transformers/generation/utils.py
@@ -855,33 +855,32 @@ def _get_candidate_generator(
                 max_length=generation_config.max_length,
             )
         elif different_tokenizers:
-            match generation_config.do_sample:
-                case True:
-                    candidate_generator = UniversalSpeculativeDecodingGenerator(
-                        input_ids=input_ids,
-                        assistant_model=assistant_model,
-                        generation_config=generation_config,
-                        model_kwargs=model_kwargs,
-                        inputs_tensor=inputs_tensor,
-                        logits_processor=logits_processor,
-                        target_tokenizer=target_tokenizer,
-                        assistant_tokenizer=assistant_tokenizer,
-                        # required in the case that self.config.vocab_size is different from the length of target_tokenizer.get_vocab()
-                        target_vocab_size=self.config.vocab_size,
-                    )
-                case False:
-                    candidate_generator = AssistedCandidateGeneratorDifferentTokenizers(
-                        input_ids=input_ids,
-                        assistant_model=assistant_model,
-                        generation_config=generation_config,
-                        model_kwargs=model_kwargs,
-                        inputs_tensor=inputs_tensor,
-                        logits_processor=logits_processor,
-                        target_tokenizer=target_tokenizer,
-                        assistant_tokenizer=assistant_tokenizer,
-                    )
-                case _:
-                    raise ValueError(f"Invalid value for `do_sample`: {generation_config.do_sample}")
+            if generation_config.do_sample is True:
+                candidate_generator = UniversalSpeculativeDecodingGenerator(
+                    input_ids=input_ids,
+                    assistant_model=assistant_model,
+                    generation_config=generation_config,
+                    model_kwargs=model_kwargs,
+                    inputs_tensor=inputs_tensor,
+                    logits_processor=logits_processor,
+                    target_tokenizer=target_tokenizer,
+                    assistant_tokenizer=assistant_tokenizer,
+                    # required in the case that self.config.vocab_size is different from the length of target_tokenizer.get_vocab()
+                    target_vocab_size=self.config.vocab_size,
+                )
+            elif generation_config.do_sample is False:
+                candidate_generator = AssistedCandidateGeneratorDifferentTokenizers(
+                    input_ids=input_ids,
+                    assistant_model=assistant_model,
+                    generation_config=generation_config,
+                    model_kwargs=model_kwargs,
+                    inputs_tensor=inputs_tensor,
+                    logits_processor=logits_processor,
+                    target_tokenizer=target_tokenizer,
+                    assistant_tokenizer=assistant_tokenizer,
+                )
+            else:
+                raise ValueError(f"Invalid value for `do_sample`: expected a boolean, got {type(generation_config.do_sample).__name__}")
         else:
             candidate_generator = AssistedCandidateGenerator(
                 input_ids=input_ids,
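For context, this branch is reached when `generate` runs assisted generation with mismatched tokenizers. A typical call looks like the sketch below (checkpoint names are illustrative): `do_sample=True` routes to `UniversalSpeculativeDecodingGenerator`, `do_sample=False` to `AssistedCandidateGeneratorDifferentTokenizers`.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoints; the point is that the two models use different tokenizers.
target = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
assistant = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
target_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
assistant_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

inputs = target_tok("The capital of France is", return_tensors="pt")
outputs = target.generate(
    **inputs,
    assistant_model=assistant,
    tokenizer=target_tok,               # target tokenizer
    assistant_tokenizer=assistant_tok,  # different tokenizer triggers this code path
    do_sample=True,                     # True -> UniversalSpeculativeDecodingGenerator
    max_new_tokens=20,
)
print(target_tok.decode(outputs[0], skip_special_tokens=True))
```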
8 changes: 4 additions & 4 deletions tests/generation/test_candidate_generator.py
@@ -39,7 +39,7 @@ def setUp(self):

     def test_get_assistant_to_target_input_ids(self):
         """Test the mapping from assistant tokens to target tokens."""
-        expected_mapping = [0, 1, 2, self.translator.suppress_tokens_id, self.translator.suppress_tokens_id]
+        expected_mapping = [0, 1, 2, self.translator.SUPPRESS_TOKEN_ID, self.translator.SUPPRESS_TOKEN_ID]
         actual_mapping = self.translator._assistant_to_target_input_ids.tolist()
         self.assertEqual(actual_mapping, expected_mapping)

@@ -56,7 +56,7 @@ def test_get_target_ids(self):
         assistant_candidate_ids = torch.LongTensor([[0, 1, 2, 4]])  # 'hello world foo baz' in assistant tokenizer

         expected_target_ids = torch.LongTensor(
-            [[0, 1, 2, self.translator.suppress_tokens_id]]
+            [[0, 1, 2, self.translator.SUPPRESS_TOKEN_ID]]
         )  # 'hello world foo baz' in target tokenizer (baz is mapped to self.translator.suppress_tokens_id since it does not exist in target vocab)

         actual_target_ids = self.translator.get_target_ids(
@@ -67,10 +67,10 @@ def test_get_target_logits(self):
     def test_get_target_logits(self):
         """Test the conversion of assistant logits to target logits."""
         # Assistant logits for IDs 0, 1, 2
-        assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, self.translator.filter_value]]])  # Shape (1, 1, 5)
+        assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, self.translator.FILTER_VALUE]]])  # Shape (1, 1, 5)

         # Expected target logits (target_vocab_size = 4)
-        expected_target_logits = torch.full((1, 1, self.target_vocab_size), self.translator.filter_value)
+        expected_target_logits = torch.full((1, 1, self.target_vocab_size), self.translator.FILTER_VALUE)
         expected_target_logits[0, 0, 0] = 0.1  # 'hello'
         expected_target_logits[0, 0, 1] = 0.2  # 'world'
         expected_target_logits[0, 0, 2] = 0.3  # 'foo'
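The suppression semantics these tests pin down reduce to a plain tensor lookup. Here is a small sketch of what `get_target_ids` is expected to return for the candidate ids used above (treating the method as an index into the mapping, which is an assumption about its internals):

```python
import torch

SUPPRESS_TOKEN_ID = -1
assistant_to_target_input_ids = torch.tensor([0, 1, 2, -1, -1])

# 'hello world foo baz' in the assistant tokenizer; 'baz' (id 4) has no target id.
assistant_candidate_ids = torch.LongTensor([[0, 1, 2, 4]])
target_ids = assistant_to_target_input_ids[assistant_candidate_ids]
print(target_ids.tolist())  # [[0, 1, 2, -1]]; 'baz' is suppressed
```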