Fix prepare + apply #7
```diff
@@ -568,7 +568,7 @@ class AssistantToTargetTranslator:
     def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase"):
         self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
         self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
-        self._assistant_to_target_input_ids: dict[int, int] = self._get_assistant_to_target_input_ids()
+        self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids()
         self.suppress_input_ids: list[int] = self._get_suppress_input_ids()
         self.logits_processors: LogitsProcessorList = LogitsProcessorList(
             [
```
```diff
@@ -577,22 +577,21 @@ def __init__(self, target_tokenizer: "PreTrainedTokenizerBase", assistant_tokeni
             ]
         )

-    def _get_assistant_to_target_input_ids(self) -> dict[int, int]:
-        """
-        Get a mapping from assistant tokens to target tokens based on vocabularies.
-        """
+    def _get_assistant_to_target_input_ids(self):
         target_vocab = self._target_tokenizer.get_vocab()
         assistant_vocab = self._assistant_tokenizer.get_vocab()
-        return {
-            assistant_vocab[tok]: target_vocab[tok] for tok in set(target_vocab.keys()) & set(assistant_vocab.keys())
-        }
+        max_assistant_index = max(assistant_vocab.values())
+        assistant_to_target_input_ids = torch.full((max_assistant_index+1,), -1, dtype=int)  # -1 means not in target vocab
+        for tok, idx in assistant_vocab.items():
+            if tok in target_vocab:
+                assistant_to_target_input_ids[idx] = target_vocab[tok]
+        return assistant_to_target_input_ids

     def _get_suppress_input_ids(self) -> list[int]:
         """
         Get the input ids that are in the assistant vocab but not in the target vocab.
         """
-        assistant_vocab = self._assistant_tokenizer.get_vocab()
-        return list(set(assistant_vocab.values()) - set(self._assistant_to_target_input_ids.keys()))
+        return torch.where(self._assistant_to_target_input_ids==-1)[0]

     def get_target_ids(
         self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor
```
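The heart of this hunk is replacing the dict-based token mapping with a dense lookup tensor indexed by assistant token id, where -1 marks tokens that have no counterpart in the target vocab. A minimal sketch of the idea, using hypothetical toy vocabularies in place of real tokenizers:

```python
import torch

# Hypothetical toy vocabularies (token -> id); the real ones come from get_vocab()
target_vocab = {"a": 0, "b": 1, "c": 2}
assistant_vocab = {"a": 0, "c": 1, "z": 2}

# Dense lookup table: index = assistant id, value = target id (-1 = unmapped)
mapping = torch.full((max(assistant_vocab.values()) + 1,), -1, dtype=torch.long)
for tok, idx in assistant_vocab.items():
    if tok in target_vocab:
        mapping[idx] = target_vocab[tok]

print(mapping)                        # tensor([ 0,  2, -1])
print(torch.where(mapping == -1)[0])  # suppressed assistant ids: tensor([2])
```

Unlike a dict, the table supports vectorized fancy indexing (`mapping[ids]`) directly on tensors, which is exactly what the hunks below rely on.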
```diff
@@ -602,14 +601,16 @@ def get_target_ids(
         Note that we have already the target ids for the prompt and we only need to find the target ids for the new tokens.
         Moreover, assistant ids of the original prompt does not necessarily appear in _assistant_to_target_input_ids.
         """
-        device = assistant_candidate_ids.device
-        target_candidate_ids = (
-            assistant_candidate_ids[0, -(len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]) :]
-            .cpu()
-            .apply_(lambda x: self._assistant_to_target_input_ids.get(x, x))
-            .to(device)
-        )
-        return torch.cat((target_input_ids, target_candidate_ids.unsqueeze(0)), dim=1)
+        i = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]
+        if i == 0:
+            return target_input_ids
+        else:
+            device = assistant_candidate_ids.device
+            #assert len(assistant_candidate_ids[0]) > assistant_input_ids.shape[1]
+            transformed_slice = self._assistant_to_target_input_ids[assistant_candidate_ids[0, -(len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]) :].cpu()].to(device)
+            #assert torch.all(transformed_slice != -1)
+            return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1)

     def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
         """
```
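The old code relied on `Tensor.apply_`, which runs a Python callback per element and only works on CPU tensors, forcing a device round trip on every call. The new code does a single vectorized gather. A hedged before/after sketch with toy ids (the `mapping` table is assumed to be built as in the previous example):

```python
import torch

mapping = torch.tensor([0, 2, -1])      # assistant id -> target id, -1 = unmapped
candidate_ids = torch.tensor([[0, 1]])  # toy candidate slice

# Old style: per-element Python callback, CPU-only
old = candidate_ids[0].clone().apply_(lambda x: int(mapping[x]))

# New style: one vectorized gather, no Python loop
new = mapping[candidate_ids[0]]

assert torch.equal(old, new)
```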
```diff
@@ -622,12 +623,8 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT
         assistant_logits_supported_mask: torch.BoolTensor = assistant_logits > -float("inf")
         assistant_logits_supported_indices: torch.IntTensor = assistant_logits_supported_mask.nonzero(as_tuple=True)[
             -1
-        ]
-        target_logits_supported_indices: torch.IntTensor = (
-            assistant_logits_supported_indices.cpu()
-            .apply_(lambda x: self._assistant_to_target_input_ids[x])
-            .to(device)
-        )
+        ].cpu()
+        target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_logits_supported_indices].to(device)
         target_logits[..., target_logits_supported_indices] = assistant_logits[..., assistant_logits_supported_mask]
         return target_logits
```
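The same gather pattern covers logits: find the assistant ids whose logits are finite, translate them through the table in one indexing operation, and scatter the values into the matching target-vocab slots. A toy sketch with made-up vocab sizes:

```python
import torch

mapping = torch.tensor([2, 0, 5, -1])  # assistant id -> target id
assistant_logits = torch.tensor([[0.3, -float("inf"), 1.2, -float("inf")]])

target_logits = torch.full((1, 6), -float("inf"))  # target vocab size 6
mask = assistant_logits > -float("inf")
supported = mask.nonzero(as_tuple=True)[-1]        # assistant ids kept: [0, 2]
target_logits[..., mapping[supported]] = assistant_logits[..., mask]
```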
```diff
@@ -708,7 +705,7 @@ def __init__(
             logits_processor,
         )
         # Track sequence lengths and previous assistant IDs
-        self._prev_target_seq_len: int = 0
+        self._candidates_target_seq_len: int = 0
         self._prev_assistant_ids: Optional[torch.LongTensor] = None

     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
```
```diff
@@ -732,15 +729,17 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         generation_args["generation_config"].return_dict_in_generate = True

         # Generate and process outputs using translator
+        generation_args['logits_processor'] = self._atm_translator.logits_processors
         assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
         self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

         candidate_logits = torch.stack(assistant_output.scores, dim=1)

         # Use translator to convert tokens and logits
         candidate_ids = assistant_output.sequences
-        candidate_logits = self._atm_translator.logits_processors(input_ids=candidate_ids, scores=candidate_logits)
+        self._prev_assistant_ids = candidate_ids
         target_ids = self._atm_translator.get_target_ids(assistant_input_ids, target_input_ids, candidate_ids)
+        self._candidates_target_seq_len = target_ids.shape[-1]
         target_logits = self._atm_translator.get_target_logits(candidate_logits)

         return target_ids, target_logits
```
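Passing the processors to `generate` masks the unmappable tokens while candidates are being sampled, instead of rewriting the scores after the fact, so the assistant can never propose a token the target vocab lacks. A self-contained sketch of that pattern (the model and suppressed ids are illustrative, not taken from the PR):

```python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessorList,
    SuppressTokensLogitsProcessor,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

suppress_ids = [50256]  # hypothetical ids missing from the target vocab
processors = LogitsProcessorList([SuppressTokensLogitsProcessor(suppress_ids)])

inputs = tokenizer("Hello", return_tensors="pt")
out = model.generate(**inputs, logits_processor=processors, max_new_tokens=5)
print(tokenizer.decode(out[0]))
```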
```diff
@@ -751,9 +750,11 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to
         """
         # Calculate new tokens since last call
         target_seq_len = target_input_ids.shape[-1]
-        new_token_count = target_seq_len - self._prev_target_seq_len
+        if self._candidates_target_seq_len == 0:
+            new_token_count = target_seq_len
+        else:
+            new_token_count = 1
         target_new_ids = target_input_ids[:, -new_token_count:]
-        self._prev_target_seq_len = target_seq_len

         # Convert only the new tokens
         target_new_text = self.target_tokenizer.batch_decode(
```
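The branch encodes one invariant: on the first call nothing has been sent to the assistant yet, so the whole target prompt is new; on every later call the verifier extends the sequence by exactly one token beyond the candidates, so only that last token needs re-decoding. A toy illustration (numbers invented):

```python
import torch

target_input_ids = torch.tensor([[11, 12, 13, 14]])
candidates_target_seq_len = 0  # state before the first call

new_token_count = target_input_ids.shape[-1] if candidates_target_seq_len == 0 else 1
target_new_ids = target_input_ids[:, -new_token_count:]  # whole prompt on first call
```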
```diff
@@ -765,11 +766,15 @@ def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> to
         # Update or initialize assistant IDs
         if self._prev_assistant_ids is None:
-            self._prev_assistant_ids = assistant_new_ids
+            assistant_input_ids = assistant_new_ids
         else:
-            self._prev_assistant_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
-
-        return self._prev_assistant_ids
+            i = self._candidates_target_seq_len+1-target_seq_len
+            if i > 0:
+                self._prev_assistant_ids = self._prev_assistant_ids[:,:-i]
+            assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
+        assistant_input_ids = assistant_input_ids.to(torch.int)
```
Review thread on the `.to(torch.int)` cast:

Owner: According to the documentation, […]

Collaborator (author): do you mean adding before […]

Owner: Wdyt about ensuring we only assign […]

Collaborator (author): we get all the IDs from the tokenizer and their type is […]
```diff

+        return assistant_input_ids


 class PromptLookupCandidateGenerator(CandidateGenerator):
```
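The trimming arithmetic in the last hunk handles rejection: if the target model discarded some candidates, `self._prev_assistant_ids` still ends with the rejected assistant tokens, and `self._candidates_target_seq_len + 1 - target_seq_len` counts how many to drop (the `+ 1` accounts for the single bonus token the verifier appends). A toy walk-through with invented lengths:

```python
import torch

prev_assistant_ids = torch.tensor([[5, 6, 7, 8, 9]])  # ids behind the last candidates
candidates_target_seq_len = 7  # target length when the candidates were proposed
target_seq_len = 6             # target length after verification (some rejected)

i = candidates_target_seq_len + 1 - target_seq_len  # -> 2 stale assistant ids
if i > 0:
    prev_assistant_ids = prev_assistant_ids[:, :-i]  # drop the rejected tail
print(prev_assistant_ids)  # tensor([[5, 6, 7]])
```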