diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 199148a13cd..bbb861eec14 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -177,9 +177,8 @@ def _make_trainer(self): trainer.processing_class = SimpleNamespace( batch_decode=MagicMock(return_value=["decoded"]), ) + trainer._tokenizer = SimpleNamespace(eos_token_id=2, pad_token_id=0) trainer.tools = None - trainer.eos_token_id = 2 - trainer.pad_token_id = 0 trainer._metrics = { "train": { "num_tokens": [], diff --git a/trl/experimental/dppo/dppo_trainer.py b/trl/experimental/dppo/dppo_trainer.py index 260417a1aa6..fb566d2234b 100644 --- a/trl/experimental/dppo/dppo_trainer.py +++ b/trl/experimental/dppo/dppo_trainer.py @@ -333,7 +333,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): return prompt_ids, completion_ids, sampled_logprobs, topk_logprobs, topk_token_ids else: prompt_tensors = [torch.tensor(ids) for ids in prompt_ids] - padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left") + padded_ids = pad(prompt_tensors, padding_value=self._tokenizer.pad_token_id, padding_side="left") attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left") generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask} for key, value in multimodal_fields.items(): @@ -396,7 +396,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): topk_logps_chunks.append(topk_lp_t.cpu()) # Mask everything after the first EOS token - is_eos = completion_ids == self.eos_token_id + is_eos = completion_ids == self._tokenizer.eos_token_id has_eos = is_eos.any(dim=1) eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) eos_idx[has_eos] = is_eos.int().argmax(dim=1)[has_eos] @@ -592,9 +592,7 @@ async def _run_async_tools(async_coros): completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] # Decode post-tool completions - post_tool_completions = [ - parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids - ] + post_tool_completions = [parse_response(self._tokenizer, ids) if ids else {} for ids in post_tool_ids] for idx in range(len(idxs_with_tool)): idx_with_tool = idxs_with_tool[idx] @@ -668,13 +666,12 @@ def _generate(self, prompts: list): # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}): - tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class if ( Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 - and hasattr(tokenizer, "response_schema") # attribute not set by default for now - and tokenizer.response_schema is not None # only works if the tokenizer has a schema + and hasattr(self._tokenizer, "response_schema") # attribute not set by default for now + and self._tokenizer.response_schema is not None # only works if the tokenizer has a schema ): - completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] + completions = [[parse_response(self._tokenizer, ids)] for ids in completion_ids] else: contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) completions = [[{"role": "assistant", "content": content}] for content in contents] @@ -717,7 +714,7 @@ def _generate(self, prompts: list): self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) agg_is_truncated = self.accelerator.gather(is_truncated) self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) @@ -898,7 +895,7 @@ def _generate_and_score_completions( prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] prompt_ids = pad( prompt_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="left", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -909,7 +906,7 @@ def _generate_and_score_completions( completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] completion_ids = pad( completion_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -951,7 +948,7 @@ def _generate_and_score_completions( # If mask_truncated_completions is enabled, zero out truncated completions for attention and loss masking if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) # Mask completion_mask for attention masking completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index e5faa5fdce2..c5695be3cf0 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -118,7 +118,7 @@ def _generate_and_score_completions(self, inputs): prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] prompt_ids = pad( prompt_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="left", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -132,7 +132,7 @@ def _generate_and_score_completions(self, inputs): completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] completion_ids = pad( completion_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -155,7 +155,7 @@ def _generate_and_score_completions(self, inputs): # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py index e6b26e1778c..523363d689a 100644 --- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py @@ -116,7 +116,7 @@ def _generate_and_score_completions( prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] prompt_ids = pad( prompt_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="left", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -130,7 +130,7 @@ def _generate_and_score_completions( completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] completion_ids = pad( completion_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -161,7 +161,7 @@ def _generate_and_score_completions( # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() @@ -665,7 +665,7 @@ def update_with_replay_buffer( if target_prompt_len > current_batch_prompt_seq_len: prompt_ids = pad( list(prompt_ids.unbind(0)), - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, pad_to_multiple_of=target_prompt_len, padding_side="left", ) @@ -676,7 +676,7 @@ def update_with_replay_buffer( if target_completion_len > current_batch_completion_seq_len: completion_ids = pad( list(completion_ids.unbind(0)), - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, pad_to_multiple_of=target_completion_len, padding_side="right", ) @@ -711,7 +711,7 @@ def update_with_replay_buffer( if sampled_data["prompt_ids"][i].size(1) < target_prompt_len: sampled_data["prompt_ids"][i] = pad( sampled_data["prompt_ids"][i], - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, pad_to_multiple_of=target_prompt_len, padding_side="left", ) @@ -726,7 +726,7 @@ def update_with_replay_buffer( if sampled_data["completion_ids"][i].size(1) < target_completion_len: sampled_data["completion_ids"][i] = pad( sampled_data["completion_ids"][i], - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, pad_to_multiple_of=target_completion_len, padding_side="right", ) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 0a7f7912c9c..df8b93efefa 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -557,16 +557,16 @@ def __init__( # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer + self._tokenizer = processing_class.tokenizer self._is_vlm = True elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class + self._tokenizer = processing_class self._is_vlm = False else: raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if self._tokenizer.pad_token is None: + self._tokenizer.pad_token = self._tokenizer.eos_token if is_peft_available() and is_peft_model(model) and peft_config is not None: raise ValueError( @@ -629,16 +629,16 @@ def __init__( if data_collator is None and not self._is_vision_dataset: # Get the pad token: if not provided, use the one from the processing class or the eos token # if the processing class does not have a pad token. - pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token - if pad_token not in tokenizer.get_vocab(): + pad_token = args.pad_token or self._tokenizer.pad_token or self._tokenizer.eos_token + if pad_token not in self._tokenizer.get_vocab(): raise ValueError( f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " "in the vocabulary before using it as a padding token." ) - tokenizer.pad_token = pad_token + self._tokenizer.pad_token = pad_token data_collator = DataCollatorForPreference( - pad_token_id=tokenizer.pad_token_id, + pad_token_id=self._tokenizer.pad_token_id, max_length=args.max_length, truncation_mode=args.truncation_mode, pad_to_multiple_of=args.pad_to_multiple_of, @@ -893,8 +893,7 @@ def add_eos(example, eos_token): example["rejected"] = example["rejected"] + eos_token return example - eos_token = processing_class.tokenizer.eos_token if self._is_vlm else processing_class.eos_token - dataset = dataset.map(add_eos, fn_kwargs={"eos_token": eos_token}, **map_kwargs) + dataset = dataset.map(add_eos, fn_kwargs={"eos_token": self._tokenizer.eos_token}, **map_kwargs) # Tokenize the dataset if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 295770890b1..0f400e14fcd 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -319,33 +319,31 @@ def __init__( # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer + self._tokenizer = processing_class.tokenizer self._is_vlm = True elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class + self._tokenizer = processing_class self._is_vlm = False else: raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") + if self._tokenizer.pad_token is None: + self._tokenizer.pad_token = self._tokenizer.eos_token + # Resolve vision placeholder token IDs once. Used by the forward pass to rebuild mm_token_type_ids # when tool responses inject images into the completion (see _generate forward_kwargs block). self._image_pad_token_id = None self._video_pad_token_id = None if self._is_vlm: for candidate in ("<|image_pad|>", "<|image|>"): - tid = tokenizer.convert_tokens_to_ids(candidate) - if tid != tokenizer.unk_token_id: + tid = self._tokenizer.convert_tokens_to_ids(candidate) + if tid != self._tokenizer.unk_token_id: self._image_pad_token_id = tid break - tid = tokenizer.convert_tokens_to_ids("<|video_pad|>") - if tid != tokenizer.unk_token_id: + tid = self._tokenizer.convert_tokens_to_ids("<|video_pad|>") + if tid != self._tokenizer.unk_token_id: self._video_pad_token_id = tid - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - self.pad_token_id = tokenizer.pad_token_id - self.eos_token_id = tokenizer.eos_token_id - if is_peft_available() and is_peft_model(model) and peft_config is not None: raise ValueError( "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " @@ -528,8 +526,7 @@ def __init__( # While waiting for broader adoption, we provide this utility function to manually set the response schema for # known chat templates. `response_schema` lives on the (inner) tokenizer, since `parse_response` is a tokenizer # method that reads `self.response_schema`. - tokenizer = processing_class.tokenizer if self._is_vlm else processing_class - if self.tools and getattr(tokenizer, "response_schema", None) is None: + if self.tools and getattr(self._tokenizer, "response_schema", None) is None: processing_class = add_response_schema(processing_class) # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template # isn't, we replace it at initialization with a training-safe, prefix-preserving template. @@ -785,9 +782,9 @@ def cast_outputs_to_original_dtype(module, args, output): generation_kwargs = { "max_new_tokens": self.max_completion_length, "do_sample": True, - "pad_token_id": tokenizer.pad_token_id, - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": tokenizer.eos_token_id, + "pad_token_id": self._tokenizer.pad_token_id, + "bos_token_id": self._tokenizer.bos_token_id, + "eos_token_id": self._tokenizer.eos_token_id, "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k, @@ -1370,7 +1367,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): else: # Regular generation path: left-pad token IDs into tensors prompt_tensors = [torch.tensor(ids) for ids in prompt_ids] - padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left") + padded_ids = pad(prompt_tensors, padding_value=self._tokenizer.pad_token_id, padding_side="left") attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left") generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask} # For VLMs, include multimodal fields as tensors (pixel_values, image_grid_thw, etc.) @@ -1403,7 +1400,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): completion_ids = prompt_completion_ids[:, prompt_length:] # Mask everything after the first EOS token - is_eos = completion_ids == self.eos_token_id + is_eos = completion_ids == self._tokenizer.eos_token_id eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) @@ -1459,7 +1456,7 @@ def _get_tool_suffix_ids(self, tool_messages): # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to # EOS (not EOS + newline). Templates that don't use EOS as end-of-turn (e.g. Gemma uses # ) skip this trimming. - eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.eos_token_id] + eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self._tokenizer.eos_token_id] if eos_positions: prefix_ids = prefix_ids[: eos_positions[-1] + 1] @@ -1658,9 +1655,7 @@ async def _run_async_tools(async_coros): completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] # Decode post-tool completions. - post_tool_completions = [ - parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids - ] + post_tool_completions = [parse_response(self._tokenizer, ids) if ids else {} for ids in post_tool_ids] # Add post-tool completions to the existing completions for idx in range(len(idxs_with_tool)): @@ -1710,13 +1705,12 @@ def _generate(self, prompts: list): # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}): - tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class if ( Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 - and hasattr(tokenizer, "response_schema") # attribute not set by default for now - and tokenizer.response_schema is not None # only works if the tokenizer has a schema + and hasattr(self._tokenizer, "response_schema") # attribute not set by default for now + and self._tokenizer.response_schema is not None # only works if the tokenizer has a schema ): - completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] + completions = [[parse_response(self._tokenizer, ids)] for ids in completion_ids] else: contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) completions = [[{"role": "assistant", "content": content}] for content in contents] @@ -1770,7 +1764,7 @@ def _generate(self, prompts: list): self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) # Identify sequences that terminated with EOS and log their lengths - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) agg_is_truncated = self.accelerator.gather(is_truncated) self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) @@ -1868,7 +1862,7 @@ def _generate_and_score_completions( prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] prompt_ids = pad( prompt_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="left", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -1879,7 +1873,7 @@ def _generate_and_score_completions( completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] completion_ids = pad( completion_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -1906,7 +1900,7 @@ def _generate_and_score_completions( # If mask_truncated_completions is enabled, zero out truncated completions for attention and loss masking if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) # Mask completion_mask for attention masking completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 6924332a336..05facb61528 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -267,16 +267,14 @@ def __init__( # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer + self._tokenizer = processing_class.tokenizer elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class + self._tokenizer = processing_class else: raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - self.pad_token_id = tokenizer.pad_token_id - self.eos_token_id = tokenizer.eos_token_id + if self._tokenizer.pad_token is None: + self._tokenizer.pad_token = self._tokenizer.eos_token if is_peft_available() and is_peft_model(model) and peft_config is not None: raise ValueError( @@ -532,9 +530,9 @@ def __init__( generation_kwargs = { "max_new_tokens": self.max_completion_length, "do_sample": True, - "pad_token_id": tokenizer.pad_token_id, - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": tokenizer.eos_token_id, + "pad_token_id": self._tokenizer.pad_token_id, + "bos_token_id": self._tokenizer.bos_token_id, + "eos_token_id": self._tokenizer.eos_token_id, "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k, @@ -993,7 +991,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): else: # Regular generation path: left-pad token IDs into tensors prompt_tensors = [torch.tensor(ids) for ids in prompt_ids] - padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left") + padded_ids = pad(prompt_tensors, padding_value=self._tokenizer.pad_token_id, padding_side="left") attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left") generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask} # For VLMs, include multimodal fields as tensors (pixel_values, image_grid_thw, etc.) @@ -1026,7 +1024,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): completion_ids = prompt_completion_ids[:, prompt_length:] # Mask everything after the first EOS token - is_eos = completion_ids == self.eos_token_id + is_eos = completion_ids == self._tokenizer.eos_token_id eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) @@ -1073,7 +1071,7 @@ def _generate(self, prompts: list): self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) # Identify sequences that terminated with EOS and log their lengths - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) agg_is_truncated = self.accelerator.gather(is_truncated) self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) @@ -1127,7 +1125,7 @@ def _generate_and_score_completions( prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] prompt_ids = pad( prompt_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="left", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -1138,7 +1136,7 @@ def _generate_and_score_completions( completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] completion_ids = pad( completion_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -1148,7 +1146,7 @@ def _generate_and_score_completions( # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] # Mask completion_mask for attention masking is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 68e78ec48bd..9d3c6286060 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -702,22 +702,22 @@ def __init__( # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer + self._tokenizer = processing_class.tokenizer self._is_vlm = True elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class + self._tokenizer = processing_class self._is_vlm = False else: raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") if args.eos_token is not None: - if args.eos_token not in tokenizer.get_vocab(): + if args.eos_token not in self._tokenizer.get_vocab(): raise ValueError( f"The specified `eos_token` ('{args.eos_token}') is not found in the vocabulary of the given " f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists " "in the vocabulary before using it as an EOS token." ) - tokenizer.eos_token = args.eos_token + self._tokenizer.eos_token = args.eos_token if args.chat_template_path is not None: if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")): @@ -880,16 +880,16 @@ def __init__( if data_collator is None and not self._is_vision_dataset: # Get the pad token: if not provided, use the one from the processing class or the eos token # if the processing class does not have a pad token. - pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token - if pad_token not in tokenizer.get_vocab(): + pad_token = args.pad_token or self._tokenizer.pad_token or self._tokenizer.eos_token + if pad_token not in self._tokenizer.get_vocab(): raise ValueError( f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " "in the vocabulary before using it as a padding token." ) - tokenizer.pad_token = pad_token + self._tokenizer.pad_token = pad_token data_collator = DataCollatorForLanguageModeling( - pad_token_id=tokenizer.pad_token_id, + pad_token_id=self._tokenizer.pad_token_id, max_length=None if self.padding_free else args.max_length, truncation_mode=args.truncation_mode, completion_only_loss=self.completion_only_loss, @@ -1115,10 +1115,9 @@ def add_eos(example, eos_token): example["completion"] = example["completion"] + eos_token return example - eos_token = processing_class.tokenizer.eos_token if self._is_vlm else processing_class.eos_token dataset = dataset.map( add_eos, - fn_kwargs={"eos_token": eos_token}, + fn_kwargs={"eos_token": self._tokenizer.eos_token}, remove_columns="messages" if "messages" in column_names else None, # renamed to "text" **map_kwargs, )