Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
db61d3e
Set tokenizer attribute in GRPO
albertvillanova Apr 9, 2026
fea92ab
Replace self.pad_token_id and self.eos_token_id
albertvillanova Apr 9, 2026
a6ad78d
Use self._tokenizer
albertvillanova Apr 9, 2026
d479816
Simplify self.processing_class with self._tokenizer
albertvillanova Apr 9, 2026
e5b5031
Fix assignment
albertvillanova Apr 9, 2026
319571f
Update tests
albertvillanova Apr 9, 2026
a898089
Merge remote-tracking branch 'upstream/main' into set-tokenizer-attri…
albertvillanova Apr 15, 2026
45f50d8
Remove redundant condition
albertvillanova Apr 15, 2026
854ef8d
Remove dead code from tests
albertvillanova Apr 15, 2026
cd10af5
Set tokenizer attribute in DPO
albertvillanova Apr 15, 2026
f5ed9f2
Use self._tokenizer
albertvillanova Apr 15, 2026
6a072e5
Set tokenizer attribute in RLOO
albertvillanova Apr 15, 2026
b4db781
Replace self.pad_token_id and self.eos_token_id
albertvillanova Apr 15, 2026
8a0ac8b
Set tokenizer attribute in SFT
albertvillanova Apr 15, 2026
62508a2
Use self._tokenizer
albertvillanova Apr 15, 2026
8f665e4
Merge remote-tracking branch 'upstream/main' into set-tokenizer-attri…
albertvillanova Apr 15, 2026
d01f99d
Use self._tokenizer in DPPO
albertvillanova Apr 15, 2026
2242686
Make parse_response accept only tokenizer
albertvillanova Apr 15, 2026
55ab670
Revert
albertvillanova Apr 15, 2026
67c4ace
Merge remote-tracking branch 'upstream/main' into set-tokenizer-attri…
albertvillanova Apr 15, 2026
db1b067
Replace self.pad_token_id and self.eos_token_id in DPPO
albertvillanova Apr 16, 2026
aaed316
Replace self.pad_token_id and self.eos_token_id in GFPO
albertvillanova Apr 16, 2026
9f9347a
Replace self.pad_token_id and self.eos_token_id in GRPOWithReplayBuffer
albertvillanova Apr 16, 2026
e6c73d4
Merge remote-tracking branch 'upstream/main' into set-tokenizer-attri…
albertvillanova Apr 16, 2026
5cb4c39
Merge remote-tracking branch 'upstream/main' into set-tokenizer-attri…
albertvillanova Apr 17, 2026
59b217e
Pass processing_class to add_response_schema
albertvillanova Apr 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,8 @@ def _make_trainer(self):
trainer.processing_class = SimpleNamespace(
batch_decode=MagicMock(return_value=["decoded"]),
)
trainer._tokenizer = SimpleNamespace(eos_token_id=2, pad_token_id=0)
trainer.tools = None
trainer.eos_token_id = 2
trainer.pad_token_id = 0
trainer._metrics = {
"train": {
"num_tokens": [],
Expand Down
23 changes: 10 additions & 13 deletions trl/experimental/dppo/dppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
return prompt_ids, completion_ids, sampled_logprobs, topk_logprobs, topk_token_ids
else:
prompt_tensors = [torch.tensor(ids) for ids in prompt_ids]
padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left")
padded_ids = pad(prompt_tensors, padding_value=self._tokenizer.pad_token_id, padding_side="left")
attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left")
generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask}
for key, value in multimodal_fields.items():
Expand Down Expand Up @@ -396,7 +396,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
topk_logps_chunks.append(topk_lp_t.cpu())

# Mask everything after the first EOS token
is_eos = completion_ids == self.eos_token_id
is_eos = completion_ids == self._tokenizer.eos_token_id
has_eos = is_eos.any(dim=1)
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
eos_idx[has_eos] = is_eos.int().argmax(dim=1)[has_eos]
Expand Down Expand Up @@ -592,9 +592,7 @@ async def _run_async_tools(async_coros):
completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx]

# Decode post-tool completions
post_tool_completions = [
parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids
]
post_tool_completions = [parse_response(self._tokenizer, ids) if ids else {} for ids in post_tool_ids]

for idx in range(len(idxs_with_tool)):
idx_with_tool = idxs_with_tool[idx]
Expand Down Expand Up @@ -668,13 +666,12 @@ def _generate(self, prompts: list):

# Decode completions. It's important to use `parse_response` when possible, because it handles tool calls.
if is_conversational({"prompt": prompts[0]}):
tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class
if (
Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5
and hasattr(tokenizer, "response_schema") # attribute not set by default for now
and tokenizer.response_schema is not None # only works if the tokenizer has a schema
and hasattr(self._tokenizer, "response_schema") # attribute not set by default for now
and self._tokenizer.response_schema is not None # only works if the tokenizer has a schema
):
completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids]
completions = [[parse_response(self._tokenizer, ids)] for ids in completion_ids]
else:
contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
completions = [[{"role": "assistant", "content": content}] for content in contents]
Expand Down Expand Up @@ -717,7 +714,7 @@ def _generate(self, prompts: list):
self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())

eos_and_pad = [self.eos_token_id, self.pad_token_id]
eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id]
is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device)
agg_is_truncated = self.accelerator.gather(is_truncated)
self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item())
Expand Down Expand Up @@ -898,7 +895,7 @@ def _generate_and_score_completions(
prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
prompt_ids = pad(
prompt_ids,
padding_value=self.pad_token_id,
padding_value=self._tokenizer.pad_token_id,
padding_side="left",
pad_to_multiple_of=self.pad_to_multiple_of,
).to(device=device)
Expand All @@ -909,7 +906,7 @@ def _generate_and_score_completions(
completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
completion_ids = pad(
completion_ids,
padding_value=self.pad_token_id,
padding_value=self._tokenizer.pad_token_id,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
).to(device=device)
Expand Down Expand Up @@ -951,7 +948,7 @@ def _generate_and_score_completions(

# If mask_truncated_completions is enabled, zero out truncated completions for attention and loss masking
if self.mask_truncated_completions:
eos_and_pad = [self.eos_token_id, self.pad_token_id]
eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id]
is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
# Mask completion_mask for attention masking
completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()
Expand Down
6 changes: 3 additions & 3 deletions trl/experimental/gfpo/gfpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def _generate_and_score_completions(self, inputs):
prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
prompt_ids = pad(
prompt_ids,
padding_value=self.pad_token_id,
padding_value=self._tokenizer.pad_token_id,
padding_side="left",
pad_to_multiple_of=self.pad_to_multiple_of,
).to(device=device)
Expand All @@ -132,7 +132,7 @@ def _generate_and_score_completions(self, inputs):
completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
completion_ids = pad(
completion_ids,
padding_value=self.pad_token_id,
padding_value=self._tokenizer.pad_token_id,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
).to(device=device)
Expand All @@ -155,7 +155,7 @@ def _generate_and_score_completions(self, inputs):

# If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
if self.mask_truncated_completions:
eos_and_pad = [self.eos_token_id, self.pad_token_id]
eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id]
is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _generate_and_score_completions(
prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
prompt_ids = pad(
prompt_ids,
padding_value=self.pad_token_id,
padding_value=self._tokenizer.pad_token_id,
padding_side="left",
pad_to_multiple_of=self.pad_to_multiple_of,
).to(device=device)
Expand All @@ -130,7 +130,7 @@ def _generate_and_score_completions(
completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
completion_ids = pad(
completion_ids,
padding_value=self.pad_token_id,
padding_value=self._tokenizer.pad_token_id,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
).to(device=device)
Expand Down Expand Up @@ -161,7 +161,7 @@ def _generate_and_score_completions(

# If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
if self.mask_truncated_completions:
eos_and_pad = [self.eos_token_id, self.pad_token_id]
eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id]
is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()

Expand Down Expand Up @@ -665,7 +665,7 @@ def update_with_replay_buffer(
if target_prompt_len > current_batch_prompt_seq_len:
prompt_ids = pad(
list(prompt_ids.unbind(0)),
padding_value=self.pad_token_id,
padding_value=self._tokenizer.pad_token_id,
pad_to_multiple_of=target_prompt_len,
padding_side="left",
)
Expand All @@ -676,7 +676,7 @@ def update_with_replay_buffer(
if target_completion_len > current_batch_completion_seq_len:
completion_ids = pad(
list(completion_ids.unbind(0)),
padding_value=self.pad_token_id,
padding_value=self._tokenizer.pad_token_id,
pad_to_multiple_of=target_completion_len,
padding_side="right",
)
Expand Down Expand Up @@ -711,7 +711,7 @@ def update_with_replay_buffer(
if sampled_data["prompt_ids"][i].size(1) < target_prompt_len:
sampled_data["prompt_ids"][i] = pad(
sampled_data["prompt_ids"][i],
padding_value=self.pad_token_id,
padding_value=self._tokenizer.pad_token_id,
pad_to_multiple_of=target_prompt_len,
padding_side="left",
)
Expand All @@ -726,7 +726,7 @@ def update_with_replay_buffer(
if sampled_data["completion_ids"][i].size(1) < target_completion_len:
sampled_data["completion_ids"][i] = pad(
sampled_data["completion_ids"][i],
padding_value=self.pad_token_id,
padding_value=self._tokenizer.pad_token_id,
pad_to_multiple_of=target_completion_len,
padding_side="right",
)
Expand Down
19 changes: 9 additions & 10 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,16 +557,16 @@ def __init__(

# Handle pad token for processors or tokenizers
if isinstance(processing_class, ProcessorMixin):
tokenizer = processing_class.tokenizer
self._tokenizer = processing_class.tokenizer
self._is_vlm = True
elif isinstance(processing_class, PreTrainedTokenizerBase):
tokenizer = processing_class
self._tokenizer = processing_class
self._is_vlm = False
else:
raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")

if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if self._tokenizer.pad_token is None:
self._tokenizer.pad_token = self._tokenizer.eos_token

if is_peft_available() and is_peft_model(model) and peft_config is not None:
raise ValueError(
Expand Down Expand Up @@ -629,16 +629,16 @@ def __init__(
if data_collator is None and not self._is_vision_dataset:
# Get the pad token: if not provided, use the one from the processing class or the eos token
# if the processing class does not have a pad token.
pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
if pad_token not in tokenizer.get_vocab():
pad_token = args.pad_token or self._tokenizer.pad_token or self._tokenizer.eos_token
if pad_token not in self._tokenizer.get_vocab():
raise ValueError(
f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
"in the vocabulary before using it as a padding token."
)
tokenizer.pad_token = pad_token
self._tokenizer.pad_token = pad_token
data_collator = DataCollatorForPreference(
pad_token_id=tokenizer.pad_token_id,
pad_token_id=self._tokenizer.pad_token_id,
max_length=args.max_length,
truncation_mode=args.truncation_mode,
pad_to_multiple_of=args.pad_to_multiple_of,
Expand Down Expand Up @@ -893,8 +893,7 @@ def add_eos(example, eos_token):
example["rejected"] = example["rejected"] + eos_token
return example

eos_token = processing_class.tokenizer.eos_token if self._is_vlm else processing_class.eos_token
dataset = dataset.map(add_eos, fn_kwargs={"eos_token": eos_token}, **map_kwargs)
dataset = dataset.map(add_eos, fn_kwargs={"eos_token": self._tokenizer.eos_token}, **map_kwargs)

# Tokenize the dataset
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
Expand Down
Loading
Loading