Skip to content
Merged
Changes from 5 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
db61d3e
Set tokenizer attribute in GRPO
albertvillanova Apr 9, 2026
fea92ab
Replace self.pad_token_id and self.eos_token_id
albertvillanova Apr 9, 2026
a6ad78d
Use self._tokenizer
albertvillanova Apr 9, 2026
d479816
Simplify self.processing_class with self._tokenizer
albertvillanova Apr 9, 2026
e5b5031
Fix assignment
albertvillanova Apr 9, 2026
319571f
Update tests
albertvillanova Apr 9, 2026
a898089
Merge remote-tracking branch 'upstream/main' into set-tokenizer-attri…
albertvillanova Apr 15, 2026
45f50d8
Remove redundant condition
albertvillanova Apr 15, 2026
854ef8d
Remove dead code from tests
albertvillanova Apr 15, 2026
cd10af5
Set tokenizer attribute in DPO
albertvillanova Apr 15, 2026
f5ed9f2
Use self._tokenizer
albertvillanova Apr 15, 2026
6a072e5
Set tokenizer attribute in RLOO
albertvillanova Apr 15, 2026
b4db781
Replace self.pad_token_id and self.eos_token_id
albertvillanova Apr 15, 2026
8a0ac8b
Set tokenizer attribute in SFT
albertvillanova Apr 15, 2026
62508a2
Use self._tokenizer
albertvillanova Apr 15, 2026
8f665e4
Merge remote-tracking branch 'upstream/main' into set-tokenizer-attri…
albertvillanova Apr 15, 2026
d01f99d
Use self._tokenizer in DPPO
albertvillanova Apr 15, 2026
2242686
Make parse_response accept only tokenizer
albertvillanova Apr 15, 2026
55ab670
Revert
albertvillanova Apr 15, 2026
67c4ace
Merge remote-tracking branch 'upstream/main' into set-tokenizer-attri…
albertvillanova Apr 15, 2026
db1b067
Replace self.pad_token_id and self.eos_token_id in DPPO
albertvillanova Apr 16, 2026
aaed316
Replace self.pad_token_id and self.eos_token_id in GFPO
albertvillanova Apr 16, 2026
9f9347a
Replace self.pad_token_id and self.eos_token_id in GRPOWithReplayBuffer
albertvillanova Apr 16, 2026
e6c73d4
Merge remote-tracking branch 'upstream/main' into set-tokenizer-attri…
albertvillanova Apr 16, 2026
5cb4c39
Merge remote-tracking branch 'upstream/main' into set-tokenizer-attri…
albertvillanova Apr 17, 2026
59b217e
Pass processing_class to add_response_schema
albertvillanova Apr 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 25 additions & 41 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,20 +319,18 @@ def __init__(

# Handle pad token for processors or tokenizers
if isinstance(processing_class, ProcessorMixin):
tokenizer = processing_class.tokenizer
self._tokenizer = processing_class.tokenizer
self._is_vlm = True
self._vision_token_ids_cache = None # populated lazily by _get_vision_token_ids
elif isinstance(processing_class, PreTrainedTokenizerBase):
tokenizer = processing_class
self._tokenizer = processing_class
self._is_vlm = False
self._vision_token_ids_cache = None
else:
raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")

if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
self.pad_token_id = tokenizer.pad_token_id
self.eos_token_id = tokenizer.eos_token_id
if self._tokenizer.pad_token is None:
self._tokenizer.pad_token = self._tokenizer.eos_token
Comment thread
cursor[bot] marked this conversation as resolved.
Comment thread
cursor[bot] marked this conversation as resolved.

if is_peft_available() and is_peft_model(model) and peft_config is not None:
raise ValueError(
Expand Down Expand Up @@ -517,15 +515,13 @@ def __init__(
# known chat templates.
# We need `getattr` until the base class sets a default None value for response_schema
# For VLM processors, check the inner tokenizer too (response_schema lives on the tokenizer)
has_response_schema = getattr(processing_class, "response_schema", None) or (
self._is_vlm and getattr(processing_class.tokenizer, "response_schema", None)
)
has_response_schema = getattr(self._tokenizer, "response_schema", None)
Comment thread
qgallouedec marked this conversation as resolved.
Outdated
if self.tools and not has_response_schema:
processing_class = add_response_schema(processing_class)
self._tokenizer = add_response_schema(self._tokenizer)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We add the response schema by checking the chat template in add_response_schema. But processor.chat_template and processor.tokenizer.chat_template are independent attributes, and processor.apply_chat_template reads only the former, so they can silently diverge if either is mutated. In practice they're populated from the same source files on load and match today for all tested VLMs, but technically, patching one doesn't patch the other, which is why I think we should keep

Suggested change
self._tokenizer = add_response_schema(self._tokenizer)
processing_class = add_response_schema(processing_class)

and extend add_response_schema to processor: #5520

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I understood you said it was OK to pass just tokenizer: #5489 (comment)

  • add_response_schema: tokenizer-only is fine. The processor doesn't expose a schema, and I'd expect schema handling to stay tokenizer-level even if upstream adds it — no chat-template-style side effects.

@albertvillanova albertvillanova Apr 16, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I read your PR and I'm not sure I fully understand it: doesn't it overlap with this PR?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah well, actually I was not correct: add_response_schema should actually take the processor, not because of the schema, but because of the template: there is no guarantee that processor.chat_template == processor.tokenizer.chat_template. In other words

# this would be correct in the vast majority of cases
def add_response_schema(tokenizer):
    tokenizer.response_schema = SCHEMAS[tokenizer.chat_template]

# but this is better
def add_response_schema(processor):
    processor.tokenizer.response_schema = SCHEMAS[processor.chat_template]

@albertvillanova albertvillanova Apr 17, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After having thought longer about this, I'm not sure about the logic in transformers' implementation of this...

  • Shouldn't they enforce that both chat templates (the processor's and the tokenizer's) are the same?
  • Does it make sense that a tokenizer.response_schema is aligned with its processor.chat_template, but not with its own tokenizer.chat_template?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment thread
cursor[bot] marked this conversation as resolved.
Outdated
# In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template
# isn't, we replace it at initialization with a training-safe, prefix-preserving template.
if self.tools and not is_chat_template_prefix_preserving(processing_class):
self.chat_template = get_training_chat_template(processing_class)
if self.tools and not is_chat_template_prefix_preserving(self._tokenizer):
self.chat_template = get_training_chat_template(self._tokenizer)
Comment thread
albertvillanova marked this conversation as resolved.
Outdated
else:
self.chat_template = None

Expand Down Expand Up @@ -776,9 +772,9 @@ def cast_outputs_to_original_dtype(module, args, output):
generation_kwargs = {
"max_new_tokens": self.max_completion_length,
"do_sample": True,
"pad_token_id": tokenizer.pad_token_id,
"bos_token_id": tokenizer.bos_token_id,
"eos_token_id": tokenizer.eos_token_id,
"pad_token_id": self._tokenizer.pad_token_id,
"bos_token_id": self._tokenizer.bos_token_id,
"eos_token_id": self._tokenizer.eos_token_id,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
Expand Down Expand Up @@ -1366,7 +1362,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
else:
# Regular generation path: left-pad token IDs into tensors
prompt_tensors = [torch.tensor(ids) for ids in prompt_ids]
padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left")
padded_ids = pad(prompt_tensors, padding_value=self._tokenizer.pad_token_id, padding_side="left")
attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left")
generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask}
# For VLMs, include multimodal fields as tensors (pixel_values, image_grid_thw, etc.)
Expand Down Expand Up @@ -1399,7 +1395,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
completion_ids = prompt_completion_ids[:, prompt_length:]

# Mask everything after the first EOS token
is_eos = completion_ids == self.eos_token_id
is_eos = completion_ids == self._tokenizer.eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
Expand Down Expand Up @@ -1455,7 +1451,7 @@ def _get_tool_suffix_ids(self, tool_messages):
# When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to
# EOS (not EOS + newline). Templates that don't use EOS as end-of-turn (e.g. Gemma uses
# <turn|>) skip this trimming.
eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.eos_token_id]
eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self._tokenizer.eos_token_id]
if eos_positions:
prefix_ids = prefix_ids[: eos_positions[-1] + 1]

Expand All @@ -1473,7 +1469,6 @@ def _get_vision_token_ids(self):
if self._vision_token_ids_cache is None:
cache = {"vision_start": None, "vision_end": None, "image_pad": None, "video_pad": None}
if self._is_vlm:
tok = self.processing_class.tokenizer
# Try multiple token strings per role to support different VLM families
for name, candidates in {
"vision_start": ["<|vision_start|>", "<|image>"],
Expand All @@ -1482,8 +1477,8 @@ def _get_vision_token_ids(self):
"video_pad": ["<|video_pad|>"],
}.items():
for token_str in candidates:
tid = tok.convert_tokens_to_ids(token_str)
if tid != tok.unk_token_id:
tid = self._tokenizer.convert_tokens_to_ids(token_str)
if tid != self._tokenizer.unk_token_id:
cache[name] = tid
break
self._vision_token_ids_cache = cache
Expand Down Expand Up @@ -1728,9 +1723,7 @@ async def _run_async_tools(async_coros):
completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx]

# Decode post-tool completions.
post_tool_completions = [
parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids
]
post_tool_completions = [parse_response(self._tokenizer, ids) if ids else {} for ids in post_tool_ids]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parse_response now handles both processor and tokenizer; this should be reverted.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parse_response only needs a tokenizer instance, but it had to handle both because we did not have a simple way to pass only the tokenizer. Once we implement self._tokenizer in all trainers, parse_response could be simplified to accept only tokenizer instances.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note there are already other functions that only accept tokenizer instances: add_response_schema, is_chat_template_prefix_preserving, get_training_chat_template.

More broadly, the underlying goal of this PR is to centralize the processor/tokenizer disambiguation within processing_class in a single place, so that the rest of the code can rely on a well-defined and consistent interface, with a clear expected class instance.

In that sense, the current change in calling parse_response is an intermediate step toward that simplification, rather than a deviation from it.

@qgallouedec qgallouedec Apr 15, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, let's revert the VLM support in parse_response then.

For the record, a few things to keep in mind going forward. Accessing the inner tokenizer isn't always safe: for apply_chat_template, the processor does extra work (image token expansion), so processor.apply_chat_template is not just an alias for processor.tokenizer.apply_chat_template.
Applying that to each helper:

  • add_response_schema: tokenizer-only is fine. The processor doesn't expose a schema, and I'd expect schema handling to stay tokenizer-level even if upstream adds it — no chat-template-style side effects.
    EDIT: because we check the chat template in add_response_schema, we should actually pass the processor, even if we only set processor.tokenizer.response_schema

  • parse_response: same story, parse_response only exists on the tokenizer today, so we have no choice. And even if upstream eventually moves it to the processor, I'm not too worried: parsing is not the same kind of operation as applying a chat template, so the image-token caveat doesn't really apply here. Revert VLM support in parse_response #5561

  • is_chat_template_prefix_preserving: this one is actually wrong as-is. It should support VLMs (for the reason above), and the call site change in this PR should and will be reverted:

    - if self.tools and not is_chat_template_prefix_preserving(self._tokenizer):
    + if self.tools and not is_chat_template_prefix_preserving(processing_class):

    We'll extend the function in a follow-up. Support VLM processors in is_chat_template_prefix_preserving #5558

  • get_training_chat_template: contrary to what the type hint suggests (needs updating), it already supports both tokenizer and processor, and should be called on the processor. My rule of thumb: with a processor, never manipulate the inner chat template directly. Accept processor in get_training_chat_template #5560


# Add post-tool completions to the existing completions
for idx in range(len(idxs_with_tool)):
Expand Down Expand Up @@ -1798,22 +1791,13 @@ def _generate(self, prompts: list):

# Decode completions. It's important to use `parse_response` when possible, because it handles tool calls.
if is_conversational({"prompt": prompts[0]}):
parsing_class = self.processing_class
# For VLM processors, propagate response_schema to the inner tokenizer if needed
if self._is_vlm:
if getattr(self.processing_class, "response_schema", None) and not getattr(
self.processing_class.tokenizer, "response_schema", None
):
self.processing_class.tokenizer.response_schema = self.processing_class.response_schema
# parse_response handles VLM processors internally (uses inner tokenizer)
tokenizer = getattr(parsing_class, "tokenizer", parsing_class)
if (
Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5
and isinstance(tokenizer, PreTrainedTokenizerBase)
and hasattr(tokenizer, "response_schema") # attribute not set by default for now
and tokenizer.response_schema is not None # only works if the tokenizer has a schema
and isinstance(self._tokenizer, PreTrainedTokenizerBase)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this condition is redundant now.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep

Comment thread
cursor[bot] marked this conversation as resolved.
Outdated
and hasattr(self._tokenizer, "response_schema") # attribute not set by default for now
and self._tokenizer.response_schema is not None # only works if the tokenizer has a schema
Comment on lines +1710 to +1711

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these 2 conditions can be combined into 1:

Suggested change
and hasattr(self._tokenizer, "response_schema") # attribute not set by default for now
and self._tokenizer.response_schema is not None # only works if the tokenizer has a schema
and getattr(self._tokenizer, "response_schema", None) # only works if the tokenizer has a schema: attribute not set by default for now

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree, although I feel it's less readable like this

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the current version reads smoothly: "the tokenizer has an attribute `response_schema` and this attribute `response_schema` is not None"

):
completions = [[parse_response(parsing_class, ids)] for ids in completion_ids]
completions = [[parse_response(self._tokenizer, ids)] for ids in completion_ids]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else:
contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
completions = [[{"role": "assistant", "content": content}] for content in contents]
Expand Down Expand Up @@ -1867,7 +1851,7 @@ def _generate(self, prompts: list):
self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())

# Identify sequences that terminated with EOS and log their lengths
eos_and_pad = [self.eos_token_id, self.pad_token_id]
eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id]
is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device)
agg_is_truncated = self.accelerator.gather(is_truncated)
self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item())
Expand Down Expand Up @@ -1965,7 +1949,7 @@ def _generate_and_score_completions(
prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
prompt_ids = pad(
prompt_ids,
padding_value=self.pad_token_id,
padding_value=self._tokenizer.pad_token_id,
padding_side="left",
pad_to_multiple_of=self.pad_to_multiple_of,
).to(device=device)
Expand All @@ -1976,7 +1960,7 @@ def _generate_and_score_completions(
completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
completion_ids = pad(
completion_ids,
padding_value=self.pad_token_id,
padding_value=self._tokenizer.pad_token_id,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
).to(device=device)
Expand All @@ -2003,7 +1987,7 @@ def _generate_and_score_completions(

# If mask_truncated_completions is enabled, zero out truncated completions for attention and loss masking
if self.mask_truncated_completions:
eos_and_pad = [self.eos_token_id, self.pad_token_id]
eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id]
is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
# Mask completion_mask for attention masking
completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()
Expand Down
Loading