-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Set _tokenizer as trainer attribute #5489
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
db61d3e
fea92ab
a6ad78d
d479816
e5b5031
319571f
a898089
45f50d8
854ef8d
cd10af5
f5ed9f2
6a072e5
b4db781
8a0ac8b
62508a2
8f665e4
d01f99d
2242686
55ab670
67c4ace
db1b067
aaed316
9f9347a
e6c73d4
5cb4c39
59b217e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -319,20 +319,18 @@ def __init__( | |||||||
|
|
||||||||
| # Handle pad token for processors or tokenizers | ||||||||
| if isinstance(processing_class, ProcessorMixin): | ||||||||
| tokenizer = processing_class.tokenizer | ||||||||
| self._tokenizer = processing_class.tokenizer | ||||||||
| self._is_vlm = True | ||||||||
| self._vision_token_ids_cache = None # populated lazily by _get_vision_token_ids | ||||||||
| elif isinstance(processing_class, PreTrainedTokenizerBase): | ||||||||
| tokenizer = processing_class | ||||||||
| self._tokenizer = processing_class | ||||||||
| self._is_vlm = False | ||||||||
| self._vision_token_ids_cache = None | ||||||||
| else: | ||||||||
| raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") | ||||||||
|
|
||||||||
| if tokenizer.pad_token is None: | ||||||||
| tokenizer.pad_token = tokenizer.eos_token | ||||||||
| self.pad_token_id = tokenizer.pad_token_id | ||||||||
| self.eos_token_id = tokenizer.eos_token_id | ||||||||
| if self._tokenizer.pad_token is None: | ||||||||
| self._tokenizer.pad_token = self._tokenizer.eos_token | ||||||||
|
cursor[bot] marked this conversation as resolved.
|
||||||||
|
|
||||||||
| if is_peft_available() and is_peft_model(model) and peft_config is not None: | ||||||||
| raise ValueError( | ||||||||
|
|
@@ -517,15 +515,13 @@ def __init__( | |||||||
| # known chat templates. | ||||||||
| # We need `getattr`` until the base class sets a default None value for response_schema | ||||||||
| # For VLM processors, check the inner tokenizer too (response_schema lives on the tokenizer) | ||||||||
| has_response_schema = getattr(processing_class, "response_schema", None) or ( | ||||||||
| self._is_vlm and getattr(processing_class.tokenizer, "response_schema", None) | ||||||||
| ) | ||||||||
| has_response_schema = getattr(self._tokenizer, "response_schema", None) | ||||||||
|
qgallouedec marked this conversation as resolved.
Outdated
|
||||||||
| if self.tools and not has_response_schema: | ||||||||
| processing_class = add_response_schema(processing_class) | ||||||||
| self._tokenizer = add_response_schema(self._tokenizer) | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we add the response schema by checking the chat template in
Suggested change
and extend
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, I understood you said it was OK to pass just tokenizer: #5489 (comment)
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I read your PR and I'm not sure of fully understanding it: isn't it overlapping with this PR?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah well actually I was not correct: # this would be correct in the vast majority of cases
def add_response_schema(tokenizer):
tokenizer.response_schema = SCHEMAS[tokenizer.chat_template]
# but this is better
def add_response_schema(processor):
processor.tokenizer.response_schema = SCHEMAS[processor.chat_template]
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After having thought longer about this, I'm not sure about the logic in transformers' implementation of this...
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done.
cursor[bot] marked this conversation as resolved.
Outdated
|
||||||||
| # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template | ||||||||
| # isn't, we replace it at initialization with a training-safe, prefix-preserving template. | ||||||||
| if self.tools and not is_chat_template_prefix_preserving(processing_class): | ||||||||
| self.chat_template = get_training_chat_template(processing_class) | ||||||||
| if self.tools and not is_chat_template_prefix_preserving(self._tokenizer): | ||||||||
| self.chat_template = get_training_chat_template(self._tokenizer) | ||||||||
|
albertvillanova marked this conversation as resolved.
Outdated
|
||||||||
| else: | ||||||||
| self.chat_template = None | ||||||||
|
|
||||||||
|
|
@@ -776,9 +772,9 @@ def cast_outputs_to_original_dtype(module, args, output): | |||||||
| generation_kwargs = { | ||||||||
| "max_new_tokens": self.max_completion_length, | ||||||||
| "do_sample": True, | ||||||||
| "pad_token_id": tokenizer.pad_token_id, | ||||||||
| "bos_token_id": tokenizer.bos_token_id, | ||||||||
| "eos_token_id": tokenizer.eos_token_id, | ||||||||
| "pad_token_id": self._tokenizer.pad_token_id, | ||||||||
| "bos_token_id": self._tokenizer.bos_token_id, | ||||||||
| "eos_token_id": self._tokenizer.eos_token_id, | ||||||||
| "temperature": self.temperature, | ||||||||
| "top_p": self.top_p, | ||||||||
| "top_k": self.top_k, | ||||||||
|
|
@@ -1366,7 +1362,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): | |||||||
| else: | ||||||||
| # Regular generation path: left-pad token IDs into tensors | ||||||||
| prompt_tensors = [torch.tensor(ids) for ids in prompt_ids] | ||||||||
| padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left") | ||||||||
| padded_ids = pad(prompt_tensors, padding_value=self._tokenizer.pad_token_id, padding_side="left") | ||||||||
| attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left") | ||||||||
| generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask} | ||||||||
| # For VLMs, include multimodal fields as tensors (pixel_values, image_grid_thw, etc.) | ||||||||
|
|
@@ -1399,7 +1395,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): | |||||||
| completion_ids = prompt_completion_ids[:, prompt_length:] | ||||||||
|
|
||||||||
| # Mask everything after the first EOS token | ||||||||
| is_eos = completion_ids == self.eos_token_id | ||||||||
| is_eos = completion_ids == self._tokenizer.eos_token_id | ||||||||
| eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) | ||||||||
| eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] | ||||||||
| sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) | ||||||||
|
|
@@ -1455,7 +1451,7 @@ def _get_tool_suffix_ids(self, tool_messages): | |||||||
| # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to | ||||||||
| # EOS (not EOS + newline). Templates that don't use EOS as end-of-turn (e.g. Gemma uses | ||||||||
| # <turn|>) skip this trimming. | ||||||||
| eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.eos_token_id] | ||||||||
| eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self._tokenizer.eos_token_id] | ||||||||
| if eos_positions: | ||||||||
| prefix_ids = prefix_ids[: eos_positions[-1] + 1] | ||||||||
|
|
||||||||
|
|
@@ -1473,7 +1469,6 @@ def _get_vision_token_ids(self): | |||||||
| if self._vision_token_ids_cache is None: | ||||||||
| cache = {"vision_start": None, "vision_end": None, "image_pad": None, "video_pad": None} | ||||||||
| if self._is_vlm: | ||||||||
| tok = self.processing_class.tokenizer | ||||||||
| # Try multiple token strings per role to support different VLM families | ||||||||
| for name, candidates in { | ||||||||
| "vision_start": ["<|vision_start|>", "<|image>"], | ||||||||
|
|
@@ -1482,8 +1477,8 @@ def _get_vision_token_ids(self): | |||||||
| "video_pad": ["<|video_pad|>"], | ||||||||
| }.items(): | ||||||||
| for token_str in candidates: | ||||||||
| tid = tok.convert_tokens_to_ids(token_str) | ||||||||
| if tid != tok.unk_token_id: | ||||||||
| tid = self._tokenizer.convert_tokens_to_ids(token_str) | ||||||||
| if tid != self._tokenizer.unk_token_id: | ||||||||
| cache[name] = tid | ||||||||
| break | ||||||||
| self._vision_token_ids_cache = cache | ||||||||
|
|
@@ -1728,9 +1723,7 @@ async def _run_async_tools(async_coros): | |||||||
| completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] | ||||||||
|
|
||||||||
| # Decode post-tool completions. | ||||||||
| post_tool_completions = [ | ||||||||
| parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids | ||||||||
| ] | ||||||||
| post_tool_completions = [parse_response(self._tokenizer, ids) if ids else {} for ids in post_tool_ids] | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note there are already other functions that only accept tokenizer instances: More broadly, the underlying goal of this PR is to centralize the processor/tokenizer disambiguation within processing_class in a single place, so that the rest of the code can rely on a well-defined and consistent interface, with a clear expected class instance. In that sense, the current change in calling
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good, let's revert the VLM support in For the record, a few things to keep in mind going forward. Accessing the inner tokenizer isn't always safe: for
|
||||||||
|
|
||||||||
| # Add post-tool completions to the existing completions | ||||||||
| for idx in range(len(idxs_with_tool)): | ||||||||
|
|
@@ -1798,22 +1791,13 @@ def _generate(self, prompts: list): | |||||||
|
|
||||||||
| # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. | ||||||||
| if is_conversational({"prompt": prompts[0]}): | ||||||||
| parsing_class = self.processing_class | ||||||||
| # For VLM processors, propagate response_schema to the inner tokenizer if needed | ||||||||
| if self._is_vlm: | ||||||||
| if getattr(self.processing_class, "response_schema", None) and not getattr( | ||||||||
| self.processing_class.tokenizer, "response_schema", None | ||||||||
| ): | ||||||||
| self.processing_class.tokenizer.response_schema = self.processing_class.response_schema | ||||||||
| # parse_response handles VLM processors internally (uses inner tokenizer) | ||||||||
| tokenizer = getattr(parsing_class, "tokenizer", parsing_class) | ||||||||
| if ( | ||||||||
| Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 | ||||||||
| and isinstance(tokenizer, PreTrainedTokenizerBase) | ||||||||
| and hasattr(tokenizer, "response_schema") # attribute not set by default for now | ||||||||
| and tokenizer.response_schema is not None # only works if the tokenizer has a schema | ||||||||
| and isinstance(self._tokenizer, PreTrainedTokenizerBase) | ||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this condition is redundant now.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yep
cursor[bot] marked this conversation as resolved.
Outdated
|
||||||||
| and hasattr(self._tokenizer, "response_schema") # attribute not set by default for now | ||||||||
| and self._tokenizer.response_schema is not None # only works if the tokenizer has a schema | ||||||||
|
Comment on lines
+1710
to
+1711
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think these 2 conditions can be combined into 1:
Suggested change
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. agree, although I feel it's less readable like this
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the current can be read smoothly: "the tokenizer has an attribute "response_schema" and this attribute response_schema is not None" |
||||||||
| ): | ||||||||
| completions = [[parse_response(parsing_class, ids)] for ids in completion_ids] | ||||||||
| completions = [[parse_response(self._tokenizer, ids)] for ids in completion_ids] | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||
| else: | ||||||||
| contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) | ||||||||
| completions = [[{"role": "assistant", "content": content}] for content in contents] | ||||||||
|
|
@@ -1867,7 +1851,7 @@ def _generate(self, prompts: list): | |||||||
| self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) | ||||||||
|
|
||||||||
| # Identify sequences that terminated with EOS and log their lengths | ||||||||
| eos_and_pad = [self.eos_token_id, self.pad_token_id] | ||||||||
| eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] | ||||||||
| is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) | ||||||||
| agg_is_truncated = self.accelerator.gather(is_truncated) | ||||||||
| self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) | ||||||||
|
|
@@ -1965,7 +1949,7 @@ def _generate_and_score_completions( | |||||||
| prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] | ||||||||
| prompt_ids = pad( | ||||||||
| prompt_ids, | ||||||||
| padding_value=self.pad_token_id, | ||||||||
| padding_value=self._tokenizer.pad_token_id, | ||||||||
| padding_side="left", | ||||||||
| pad_to_multiple_of=self.pad_to_multiple_of, | ||||||||
| ).to(device=device) | ||||||||
|
|
@@ -1976,7 +1960,7 @@ def _generate_and_score_completions( | |||||||
| completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] | ||||||||
| completion_ids = pad( | ||||||||
| completion_ids, | ||||||||
| padding_value=self.pad_token_id, | ||||||||
| padding_value=self._tokenizer.pad_token_id, | ||||||||
| padding_side="right", | ||||||||
| pad_to_multiple_of=self.pad_to_multiple_of, | ||||||||
| ).to(device=device) | ||||||||
|
|
@@ -2003,7 +1987,7 @@ def _generate_and_score_completions( | |||||||
|
|
||||||||
| # If mask_truncated_completions is enabled, zero out truncated completions for attention and loss masking | ||||||||
| if self.mask_truncated_completions: | ||||||||
| eos_and_pad = [self.eos_token_id, self.pad_token_id] | ||||||||
| eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] | ||||||||
| is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) | ||||||||
| # Mask completion_mask for attention masking | ||||||||
| completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() | ||||||||
|
|
||||||||
Uh oh!
There was an error while loading. Please reload this page.