Skip to content
Merged
196 changes: 96 additions & 100 deletions src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,56 +208,15 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
vocabulary_size)` containing the logits associated to each candidate.
"""
input_ids = input_ids.to(self.assistant_model.device)

# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
new_cur_len = input_ids.shape[-1]
max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
# Calculate new tokens to generate
min_new_tokens, max_new_tokens = self._calculate_new_tokens(input_ids)
if max_new_tokens == 0:
return input_ids, None

# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round)
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
if has_past_key_values:
new_cache_size = new_cur_len - 1
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
) # the assistant does not have the token after the last match, hence the -1

self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
)
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

# 2. Forecast next N tokens using the assistant model.
assistant_generation_kwargs = {
self.input_ids_key: input_ids,
"min_new_tokens": min_new_tokens,
"max_new_tokens": max_new_tokens,
"generation_config": self.generation_config,
"logits_processor": self.logits_processor,
}

assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)

# 3. Update variables for the next round of candidate generation
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

if (
is_sklearn_available()
and self.assistant_model.generation_config.assistant_confidence_threshold
and type(self) is AssistedCandidateGenerator
):
scores_tensor = torch.cat(assistant_output.scores, dim=0)
scores_softmax = torch.softmax(scores_tensor, dim=-1)
ids = assistant_output.sequences[-1, -len(assistant_output.scores) :]
p = scores_softmax[range(len(ids)), ids]
self.probs.extend(p.tolist())

# 4. Prepare variables for output
candidate_logits = torch.stack(assistant_output.scores, dim=1)
candidate_ids = assistant_output.sequences
# Update past key values and masks
self._update_past_and_masks(input_ids)
# Generate candidates
generation_args = self._prepare_generation_args(input_ids, min_new_tokens, max_new_tokens)
candidate_ids, candidate_logits = self._generate_candidates(generation_args)
return candidate_ids, candidate_logits

def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
Expand Down Expand Up @@ -318,6 +277,55 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F

self.assistant_model.generation_config.assistant_confidence_threshold = best_threshold

def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]:
"""Calculate the minimum and maximum number of new tokens to generate."""
new_cur_len = input_ids.shape[-1]
max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
return min_new_tokens, max_new_tokens

def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool:
"""Update past key values and attention masks for subsequent generation rounds."""
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
if has_past_key_values:
new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
)
self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
)
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
return has_past_key_values

def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
"""Prepare arguments for the generation call."""
return {
self.input_ids_key: input_ids,
"min_new_tokens": min_new_tokens,
"max_new_tokens": max_new_tokens,
"generation_config": self.generation_config,
"logits_processor": self.logits_processor,
}

def _generate_candidates(self, generation_args: Dict) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""Generate candidate sequences using the assistant model."""
assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
if (
is_sklearn_available()
and self.assistant_model.generation_config.assistant_confidence_threshold
and type(self) is AssistedCandidateGenerator
):
scores_tensor = torch.cat(assistant_output.scores, dim=0)
scores_softmax = torch.softmax(scores_tensor, dim=-1)
ids = assistant_output.sequences[-1, -len(assistant_output.scores) :]
p = scores_softmax[range(len(ids)), ids]
self.probs.extend(p.tolist())
candidate_logits = torch.stack(assistant_output.scores, dim=1)
candidate_ids = assistant_output.sequences
return candidate_ids, candidate_logits


class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
"""
Expand Down Expand Up @@ -367,6 +375,7 @@ def __init__(

self.target_tokenizer = target_tokenizer
self.assistant_tokenizer = assistant_tokenizer
self.prev_target_ids_len: Optional[int] = None
self.prev_assistant_ids = None
self.target_lookbehind = assistant_model.generation_config.target_lookbehind
self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind
Expand Down Expand Up @@ -497,27 +506,50 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
return input_ids, None

input_ids = input_ids.to(self.assistant_model.device)
remove_from_pkv = 0

assistant_input_ids, remove_from_pkv = self._prepare_assistant_input_ids(input_ids)
self.prev_assistant_ids = assistant_input_ids

min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - assistant_input_ids.shape[-1]), 0)

self._update_past_and_masks(assistant_input_ids, remove_from_pkv)
generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
self.assistant_kwargs.pop("attention_mask", None)

assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
new_target_ids = self._process_assistant_outputs(input_ids, assistant_output.sequences, assistant_input_ids)

# Update state
self.prev_target_ids_len = input_ids.shape[1]
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
self.prev_assistant_ids = assistant_output.sequences

if self.prev_target_ids_len >= new_target_ids.shape[1]:
return input_ids, None

return new_target_ids, None

def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, int]:
"""Converts target input IDs to assistant input IDs, handling discrepancies."""
convert_kwargs = {
"source_tokenizer": self.target_tokenizer,
"destination_tokenizer": self.assistant_tokenizer,
}
remove_from_pkv = 0

# Since re-encoding the tokens may result in tokenization discrepancies, we use 2 look behind values
# (one for each conversion) which mark where to start looking for the overlap between the
# source and target encodings, to ensure the new tokens include the correct prompt suffix.
if self.prev_assistant_ids is not None and input_ids.shape[1] > self.target_lookbehind:
if self.prev_assistant_ids is not None and self.prev_target_ids_len > self.target_lookbehind:
# input_ids contains all target prompt input ids and some new target input ids
start_index_in_target_window = input_ids.shape[1] - self.target_lookbehind
start_index_in_target_window = self.prev_target_ids_len - self.target_lookbehind

new_assistant_ids = self.convert_source_tokens_to_target_tokens(
input_ids[:, start_index_in_target_window:], **convert_kwargs
)
prompt_use_length = new_assistant_ids.shape[1]
prompt_use = self.prev_assistant_ids[:, -prompt_use_length:]

discrepancy_length, new_tokens_only, discrepancy_only = (
AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt_use, new_assistant_ids)
discrepancy_length, new_tokens_only, discrepancy_only = self._get_tokens_diag(
prompt_use, new_assistant_ids
)
assistant_input_ids = self.prev_assistant_ids

Expand All @@ -538,58 +570,29 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
else:
# edge case: in case of no intersection between prompt and new_assistant_ids
assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)

else:
assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
self.prev_target_ids_len = input_ids.shape[1]

self.prev_assistant_ids = assistant_input_ids
new_cur_len = assistant_input_ids.shape[-1]
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)

# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round)
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
if has_past_key_values:
new_cache_size = new_cur_len - 1 - remove_from_pkv
self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
) # the assistant does not have the token after the last match, hence the -1

self.assistant_kwargs = _prepare_attention_mask(
self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
)
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

# 2. Forecast next N tokens using the assistant model.
assistant_generation_kwargs = {
self.input_ids_key: assistant_input_ids,
"min_new_tokens": min_new_tokens,
"max_new_tokens": max_new_tokens,
"generation_config": self.generation_config,
"logits_processor": self.logits_processor,
}

self.assistant_kwargs.pop("attention_mask", None)

assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
return assistant_input_ids, remove_from_pkv

def _process_assistant_outputs(
self, input_ids: torch.LongTensor, assistant_sequences: torch.LongTensor, assistant_input_ids: torch.LongTensor
) -> torch.LongTensor:
"""Processes assistant outputs to obtain target input IDs."""
num_prev_assistant = self.prev_assistant_ids.shape[1]
start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind
if start_assistant_look_index < 0:
start_assistant_look_index = 0

new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
assistant_output.sequences[:, start_assistant_look_index:],
assistant_sequences[:, start_assistant_look_index:],
source_tokenizer=self.assistant_tokenizer,
destination_tokenizer=self.target_tokenizer,
)
target_prompt_use_length = new_target_ids_from_window.shape[1]

target_prompt_use = input_ids[:, -target_prompt_use_length:]

_, target_new_tokens_only, _ = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
target_prompt_use, new_target_ids_from_window
)
_, target_new_tokens_only, _ = self._get_tokens_diag(target_prompt_use, new_target_ids_from_window)

new_target_ids = input_ids

Expand All @@ -603,14 +606,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
if hasattr(self.generation_config, "max_length"):
new_target_ids = new_target_ids[:, : self.generation_config.max_length]

# 3. Update variables for the next round of candidate generation
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

# 4. Prepare variables for output
if input_ids.shape[1] >= new_target_ids.shape[1]:
return input_ids, None

return new_target_ids, None
return new_target_ids


class PromptLookupCandidateGenerator(CandidateGenerator):
Expand Down
43 changes: 43 additions & 0 deletions tests/generation/test_candidate_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import unittest

import numpy as np

from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers


class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
def test_no_intersection(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[4, 5, 6]])
result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens)
self.assertEqual(result, (None, None, None))

def test_complete_overlap(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))

def test_partial_overlap(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[2, 3, 4, 5]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))

def test_no_new_tokens(self):
prompt = np.array([[1, 2, 3]])
prompt_plus_new_tokens = np.array([[1, 2, 3]])
discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
prompt, prompt_plus_new_tokens
)
self.assertEqual(discrep_length, 0)
np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
np.testing.assert_array_equal(discrep_only, np.array([[]]))