@@ -641,7 +641,7 @@ def __init__(
         self.target_vocab_size: int = len(self._target_tokenizer.get_vocab())
         self.filter_value: float = filter_value
         self.suppress_tokens_id: int = suppress_tokens_id
-        self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids()
+        self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = self._get_assistant_to_target_input_ids()
         self._suppress_input_ids: list[int] = self._get_suppress_input_ids()
         self.logits_processors: Optional[LogitsProcessorList] = None
         if len(self._suppress_input_ids) > 0:
@@ -677,10 +677,13 @@ def _get_assistant_to_target_input_ids(self):
 
         max_assistant_index = max(assistant_vocab.values())
         assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.suppress_tokens_id, dtype=int)
-        for tok, idx in assistant_vocab.items():
-            if tok in target_vocab:
-                assistant_to_target_input_ids[idx] = target_vocab[tok]
-        return assistant_to_target_input_ids.to(self._assistant_model_device)
+        target_to_assistant_input_ids: Dict[int, int] = {}
+        for tok, assistant_id in assistant_vocab.items():
+            target_id = target_vocab.get(tok)
+            if target_id is not None:
+                assistant_to_target_input_ids[assistant_id] = target_id
+                target_to_assistant_input_ids[target_id] = assistant_id
+        return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids
 
     def _get_suppress_input_ids(self) -> list[int]:
         """
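For context, here is a minimal sketch of the dual-mapping construction this hunk introduces, run on toy vocabularies rather than real tokenizers; `SUPPRESS_ID` and the toy dicts are illustrative stand-ins, not the actual attributes.

```python
from typing import Dict

import torch

SUPPRESS_ID = -1  # stand-in for self.suppress_tokens_id

target_vocab = {"hello": 0, "world": 1, "!": 2}
assistant_vocab = {"hello": 0, "!": 1, "maybe": 2}

max_assistant_index = max(assistant_vocab.values())
assistant_to_target = torch.full((max_assistant_index + 1,), SUPPRESS_ID, dtype=int)
target_to_assistant: Dict[int, int] = {}

for tok, assistant_id in assistant_vocab.items():
    target_id = target_vocab.get(tok)
    if target_id is not None:
        assistant_to_target[assistant_id] = target_id   # dense: assistant id -> target id
        target_to_assistant[target_id] = assistant_id   # sparse: target id -> assistant id

print(assistant_to_target)   # tensor([ 0,  2, -1]); "maybe" has no target counterpart
print(target_to_assistant)   # {0: 0, 2: 1}
```

The forward map stays a dense tensor so assistant ids can be remapped by plain indexing, while the new reverse map is an ordinary dict, which is enough for the one-token-at-a-time lookups added further down.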
@@ -864,13 +867,20 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to
             new_token_count = 1
         target_new_ids = target_input_ids[:, -new_token_count:]
 
-        # Convert only the new tokens
-        target_new_text = self.target_tokenizer.batch_decode(
-            target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
-        )
-        assistant_new_ids = self.assistant_tokenizer(target_new_text, add_special_tokens=False, return_tensors="pt")[
-            "input_ids"
-        ].to(self.assistant_model.device)
+        # Convert the new tokens
+        assistant_new_ids = None
+        if self._target_seq_len_with_candidates > 0:
+            # we have only one new token and we can directly convert it
+            assistant_new_ids = self._atm_translator.target_to_assistant_input_ids.get(target_new_ids[0].item())
+        if assistant_new_ids is None:
+            target_new_text = self.target_tokenizer.batch_decode(
+                target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
+            )
+            assistant_new_ids = self.assistant_tokenizer(
+                target_new_text, add_special_tokens=False, return_tensors="pt"
+            )["input_ids"].to(self.assistant_model.device)
+        else:
+            assistant_new_ids = torch.tensor([[assistant_new_ids]], device=self.assistant_model.device)
 
         # Update or initialize assistant IDs
         if self._prev_assistant_ids is None:
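A hedged sketch of the fast path this last hunk adds: once candidates exist, exactly one new target token arrives per call, so a direct dict lookup replaces the decode-and-retokenize round trip, which remains as the fallback when the token has no one-to-one counterpart. `target_to_assistant` and `decode_and_retokenize` are illustrative stand-ins for the translator's reverse map and the tokenizer round trip.

```python
import torch

# Reverse map built as in the previous sketch (target id -> assistant id).
target_to_assistant = {0: 0, 2: 1}

def decode_and_retokenize(target_new_ids: torch.Tensor) -> torch.Tensor:
    # Stand-in for the batch_decode + assistant_tokenizer round trip in the real code.
    return torch.tensor([[99]])

def convert_new_tokens(target_new_ids: torch.Tensor, have_prior_candidates: bool) -> torch.Tensor:
    assistant_new_ids = None
    if have_prior_candidates:
        # Only one new token since the last call: try the O(1) lookup first.
        assistant_new_ids = target_to_assistant.get(target_new_ids[0].item())
    if assistant_new_ids is None:
        # No direct counterpart (or first call): fall back to re-tokenization.
        return decode_and_retokenize(target_new_ids)
    return torch.tensor([[assistant_new_ids]])

print(convert_new_tokens(torch.tensor([[2]]), have_prior_candidates=True))  # tensor([[1]])
print(convert_new_tokens(torch.tensor([[5]]), have_prior_candidates=True))  # tensor([[99]]) via fallback
```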