Skip to content

Commit a556947

Browse files
jmamou and keyboardAnt authored
BUG fix in _prepare_assistant_input_ids (huggingface#14)
* fix _prepare_assistant_input_ids
* target_to_assistant_input_ids
* Update src/transformers/generation/candidate_generator.py

Co-authored-by: Nadav Timor <[email protected]>

---------

Co-authored-by: Nadav Timor <[email protected]>
1 parent a2a2882 commit a556947

File tree

1 file changed

+22
-12
lines changed

src/transformers/generation/candidate_generator.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ def __init__(
641641
self.target_vocab_size: int = len(self._target_tokenizer.get_vocab())
642642
self.filter_value: float = filter_value
643643
self.suppress_tokens_id: int = suppress_tokens_id
644-
self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids()
644+
self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = self._get_assistant_to_target_input_ids()
645645
self._suppress_input_ids: list[int] = self._get_suppress_input_ids()
646646
self.logits_processors: Optional[LogitsProcessorList] = None
647647
if len(self._suppress_input_ids) > 0:
@@ -677,10 +677,13 @@ def _get_assistant_to_target_input_ids(self):
677677

678678
max_assistant_index = max(assistant_vocab.values())
679679
assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.suppress_tokens_id, dtype=int)
680-
for tok, idx in assistant_vocab.items():
681-
if tok in target_vocab:
682-
assistant_to_target_input_ids[idx] = target_vocab[tok]
683-
return assistant_to_target_input_ids.to(self._assistant_model_device)
680+
target_to_assistant_input_ids: Dict[int, int] = {}
681+
for tok, assistant_id in assistant_vocab.items():
682+
target_id = target_vocab.get(tok)
683+
if target_id is not None:
684+
assistant_to_target_input_ids[assistant_id] = target_id
685+
target_to_assistant_input_ids[target_id] = assistant_id
686+
return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids
684687

685688
def _get_suppress_input_ids(self) -> list[int]:
686689
"""
@@ -864,13 +867,20 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to
864867
new_token_count = 1
865868
target_new_ids = target_input_ids[:, -new_token_count:]
866869

867-
# Convert only the new tokens
868-
target_new_text = self.target_tokenizer.batch_decode(
869-
target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
870-
)
871-
assistant_new_ids = self.assistant_tokenizer(target_new_text, add_special_tokens=False, return_tensors="pt")[
872-
"input_ids"
873-
].to(self.assistant_model.device)
870+
# Convert the new tokens
871+
assistant_new_ids = None
872+
if self._target_seq_len_with_candidates > 0:
873+
# we have only one new token and we can directly convert it
874+
assistant_new_ids = self._atm_translator.target_to_assistant_input_ids.get(target_new_ids[0].item())
875+
if assistant_new_ids is None:
876+
target_new_text = self.target_tokenizer.batch_decode(
877+
target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
878+
)
879+
assistant_new_ids = self.assistant_tokenizer(
880+
target_new_text, add_special_tokens=False, return_tensors="pt"
881+
)["input_ids"].to(self.assistant_model.device)
882+
else:
883+
assistant_new_ids = torch.tensor([[assistant_new_ids]], device=self.assistant_model.device)
874884

875885
# Update or initialize assistant IDs
876886
if self._prev_assistant_ids is None:

0 commit comments

Comments (0)