From ffde2d002ab617a870a983fde4bf8ead6f5d0f31 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Wed, 18 Mar 2026 15:32:53 +0100 Subject: [PATCH 1/9] Use `VLLMGeneration` in `GOLDTrainer` --- trl/experimental/gold/gold_config.py | 31 ++ trl/experimental/gold/gold_trainer.py | 441 +++----------------------- 2 files changed, 76 insertions(+), 396 deletions(-) diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index d2c5fe72fb5..c32ef45a3e8 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -93,6 +93,17 @@ class GOLDConfig(SFTConfig): Tensor parallel size for the colocated student vLLM engine (if `vllm_mode="colocate"`). vllm_structured_outputs_regex (`str`, *optional*): Regex for vLLM structured outputs for the student model. + vllm_server_base_url (`str`, *optional*): + Base URL for the vLLM server (e.g., `"http://localhost:8001"`). If provided, `vllm_server_host` and + `vllm_server_port` are ignored. + vllm_group_port (`int`, *optional*, defaults to `51216`): + Port for the vLLM weight-update group (NCCL communicator). Unless the port is occupied, there is no need + to change it. + vllm_max_model_length (`int`, *optional*): + Maximum model sequence length for the colocated vLLM engine when `vllm_mode="colocate"`. Defaults to the + model's maximum context length. + vllm_model_impl (`str`, *optional*, defaults to `"vllm"`): + Model implementation backend to use in vLLM. Use `"vllm"` (default) or `"transformers"`. vllm_sync_frequency (`int`, *optional*, defaults to `1`): Frequency (in training steps) to synchronize student model weights to vLLM engine. Set to 1 to sync after every step. @@ -296,6 +307,12 @@ class GOLDConfig(SFTConfig): "help": 'Mode for vLLM integration. Either "server" (connect to a running TRL vLLM server) or "colocate" (run vLLM in the same process).' }, ) + vllm_server_base_url: str | None = field( + default=None, + metadata={ + "help": 'Base URL for the vLLM server (e.g., "http://localhost:8001"). If provided, vllm_server_host and vllm_server_port are ignored.' + }, + ) vllm_server_host: str = field( default="0.0.0.0", metadata={"help": 'Host of the vLLM server when `vllm_mode="server"`.'}, @@ -308,6 +325,10 @@ class GOLDConfig(SFTConfig): default=240.0, metadata={"help": 'Timeout (in seconds) for connecting to the vLLM server when `vllm_mode="server"`.'}, ) + vllm_group_port: int = field( + default=51216, + metadata={"help": "Port for the vLLM weight-update group (NCCL communicator)."}, + ) vllm_gpu_memory_utilization: float = field( default=0.9, metadata={ @@ -318,6 +339,16 @@ class GOLDConfig(SFTConfig): default=1, metadata={"help": 'Tensor parallel size for the colocated vLLM engine when `vllm_mode="colocate"`.'}, ) + vllm_max_model_length: int | None = field( + default=None, + metadata={ + "help": 'Maximum model sequence length for the colocated vLLM engine when `vllm_mode="colocate"`. Defaults to the model\'s maximum context length.' + }, + ) + vllm_model_impl: str = field( + default="vllm", + metadata={"help": 'Model implementation backend to use in vLLM. Use "vllm" (default) or "transformers".'}, + ) vllm_structured_outputs_regex: str | None = field( default=None, metadata={"help": "Regex pattern used for vLLM structured outputs (optional)."}, diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 26c0f72dddd..352a6bf5cd0 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import random import textwrap import warnings @@ -29,9 +28,8 @@ from accelerate import PartialState from accelerate.utils import DistributedType, broadcast_object_list, gather_object, is_peft_model from datasets import Dataset, IterableDataset -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.utils.data import DataLoader -from transformers import AutoTokenizer, TrainerCallback, TrainerControl, TrainerState, is_bitsandbytes_available +from transformers import AutoTokenizer, TrainerCallback, TrainerControl, TrainerState from transformers.data.data_collator import DataCollator from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.generation.configuration_utils import GenerationConfig @@ -51,7 +49,7 @@ from ...data_utils import is_conversational, maybe_convert_to_chatml, pack_dataset, truncate_dataset from ...extras.profiling import profiling_decorator -from ...generation.vllm_client import VLLMClient +from ...generation.vllm_generation import VLLMGeneration from ...import_utils import is_vllm_available from ...models import prepare_deepspeed from ...models.utils import unwrap_model_for_generation @@ -60,7 +58,6 @@ RepeatSampler, create_model_from_path, disable_dropout_in_model, - ensure_master_addr_port, pad, split_tensor_dict, ) @@ -74,9 +71,6 @@ if is_wandb_available(): import wandb -if is_vllm_available(): - from vllm import LLM, SamplingParams - from vllm.sampling_params import StructuredOutputsParams if is_liger_kernel_available(): from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss @@ -87,9 +81,6 @@ from rich.table import Table from rich.text import Text -if is_bitsandbytes_available(): - import bitsandbytes as bnb - def print_prompt_completions_sample_uld( prompts: list[str], @@ -751,25 +742,6 @@ def _get_start_and_size_answers(self, answer_tensors): return answers_index, answers_size -class GOLDVLLMSyncCallback(TrainerCallback): - """Sync the model weights to vLLM after training steps when it's safe to do so.""" - - def __init__(self, trainer): - self.trainer = trainer - - def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs): - """Sync weights after training step when DeepSpeed is stable.""" - if ( - self.trainer.use_vllm - and state.global_step != self.trainer._last_vllm_sync_step - and state.global_step % self.trainer.vllm_sync_frequency == 0 - ): - # Check if this is a step where gradients are synchronized - # This happens at the end of gradient accumulation cycles - if hasattr(self.trainer.accelerator, "sync_gradients") and self.trainer.accelerator.sync_gradients: - self.trainer._move_model_to_vllm() - self.trainer._last_vllm_sync_step = state.global_step - class GOLDTrainer(SFTTrainer): _tag_names = ["trl", "gold"] @@ -964,86 +936,35 @@ def __init__( "vLLM is not available and use_vllm is set to True. Please install vLLM with " "`pip install vllm` to use it." ) - self.vllm_mode = args.vllm_mode - self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size - self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization - self.vllm_enable_sleep_mode = args.vllm_enable_sleep_mode - if self.vllm_mode == "server": - if self.accelerator.is_main_process: - self.vllm_client = VLLMClient( - host=args.vllm_server_host, - server_port=args.vllm_server_port, - connection_timeout=args.vllm_server_timeout, - ) - self.vllm_client.init_communicator() - elif self.vllm_mode == "colocate": - student_model_name_or_path = self.model_name_or_path - - # Make sure tensor_parallel_size divides world size evenly - if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: - raise ValueError( - f"vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size " - f"({self.accelerator.num_processes}) evenly." - ) - - if self.vllm_tensor_parallel_size > 1: - # Create subgroups of ranks for TP - self.vllm_tp_group, _ = torch.distributed.new_subgroups_by_enumeration( - [ - list( - range( - i * self.vllm_tensor_parallel_size, - (i + 1) * self.vllm_tensor_parallel_size, - ) - ) - for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) - ] - ) - - # vLLM requires the environment variables to be set for distributed training. - os.environ["RANK"] = str(self.accelerator.process_index) - os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index) - os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) - ensure_master_addr_port() - - vllm_quantization = None - if is_bitsandbytes_available(): - for _, module in model.named_modules(): - if isinstance(module, bnb.nn.Linear4bit): - vllm_quantization = "bitsandbytes" - break - elif isinstance(module, bnb.nn.Linear8bitLt): - raise ValueError("vLLM does not support in-flight 8-bit quantization.") - - self.vllm_engine = LLM( - model=student_model_name_or_path, - revision=self.model_revision, - tensor_parallel_size=self.vllm_tensor_parallel_size, - gpu_memory_utilization=self.vllm_gpu_memory_utilization, - max_num_seqs=self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps, - max_model_len=args.max_length, - distributed_executor_backend="external_launcher", - # Feed identical seed for tp groups to ensure sampling results are the same across workers - seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, - enable_sleep_mode=self.vllm_enable_sleep_mode, - quantization=vllm_quantization, - ) - - if self.vllm_enable_sleep_mode: - self.vllm_engine.sleep(level=2) - - # When using vLLM, the main process is responsible for loading the model weights. This can cause process - # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we - # synchronize all processes after vLLM has been fully initialized. - self.accelerator.wait_for_everyone() - else: - raise ValueError(f"Unknown vllm_mode: {self.vllm_mode}") - self.vllm_structured_outputs_regex = args.vllm_structured_outputs_regex + self.vllm_generation = VLLMGeneration( + model=self.model, + accelerator=self.accelerator, + is_fsdp_enabled=self.is_fsdp_enabled, + processing_class=self.processing_class, + mode=args.vllm_mode, + structured_outputs_regex=args.vllm_structured_outputs_regex, + server_base_url=args.vllm_server_base_url, + server_host=args.vllm_server_host, + server_port=args.vllm_server_port, + group_port=args.vllm_group_port, + server_timeout=args.vllm_server_timeout, + tensor_parallel_size=args.vllm_tensor_parallel_size, + gpu_memory_utilization=args.vllm_gpu_memory_utilization, + max_model_length=args.vllm_max_model_length, + max_num_seqs=args.per_device_train_batch_size * args.gradient_accumulation_steps, + enable_sleep_mode=args.vllm_enable_sleep_mode, + model_impl=args.vllm_model_impl, + repetition_penalty=getattr(args, "repetition_penalty", 1.0), + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + min_p=getattr(args, "min_p", 0.0), + max_completion_length=args.max_completion_length, + logprobs=None, + ) self.vllm_sync_frequency = args.vllm_sync_frequency self._last_vllm_sync_step = -1 - self.add_callback(GOLDVLLMSyncCallback(self)) - def _set_signature_columns_if_needed(self): super()._set_signature_columns_if_needed() required_columns = [ @@ -1239,209 +1160,39 @@ def _generate_on_policy_for_slices( local_prompts.append(prompt) local_slice_indices.append(slice_idx) - prompts_text_for_vllm = self.processing_class.batch_decode( - torch.stack(local_prompts) if local_prompts else torch.empty(0, dtype=torch.long), - skip_special_tokens=True, - ) - if self.processing_class.pad_token: - prompts_text_for_vllm = [p.replace(self.processing_class.pad_token, "") for p in prompts_text_for_vllm] - prompts_text_with_special = self.processing_class.batch_decode( torch.stack(local_prompts) if local_prompts else torch.empty(0, dtype=torch.long), skip_special_tokens=False, ) - if self.use_vllm: - self._wake_vllm_if_needed() - - max_completion_length = self.generation_config.max_new_tokens - temperature = self.generation_config.temperature - top_k = ( - self.generation_config.top_k if self.generation_config.top_k and self.generation_config.top_k > 0 else -1 - ) - top_p = self.args.top_p if hasattr(self.args, "top_p") else 1.0 - repetition_penalty = self.args.repetition_penalty if hasattr(self.args, "repetition_penalty") else 1.0 - min_p = self.args.min_p if hasattr(self.args, "min_p") else 0.0 - - if self.use_vllm and self.vllm_mode == "server": - completion_ids = self._generate_vllm_server_global( - prompts_text_for_vllm, - max_completion_length, - temperature, - top_k, - top_p, - repetition_penalty, - min_p, - n=self.num_generations, - ) - elif self.use_vllm and self.vllm_mode == "colocate": - completion_ids = self._generate_vllm_colocate( - prompts_text_for_vllm, - max_completion_length, - temperature, - top_k, - top_p, - repetition_penalty, - min_p, - n=self.num_generations, - ) - else: + if not self.use_vllm: self._generate_non_vllm_for_slices(slices, on_policy_indices) return + if ( + self.state.global_step != self._last_vllm_sync_step + and self.state.global_step % self.vllm_sync_frequency == 0 + ): + self.vllm_generation.sync_weights() + self._last_vllm_sync_step = self.state.global_step + + prompt_ids_list = [p.tolist() for p in local_prompts] + _, completion_ids, _, _ = self.vllm_generation.generate( + prompts=prompt_ids_list, + images=None, + num_generations=self.num_generations, + ) + self._process_completions_to_buffer( slices, on_policy_indices, local_slice_indices, completion_ids, - prompts_text_for_vllm, prompts_text_with_special, - max_completion_length, - ) - - @staticmethod - def _deduplicate_prompts( - prompts: list[str], num_generations: int - ) -> tuple[list[str], list[tuple[int, int]]] | None: - """Deduplicate prompts and build a completion remapping.""" - seen: dict[str, list[int]] = {} - unique_prompts: list[str] = [] - dedup_mapping: list[tuple[int, int]] = [] - - for prompt in prompts: - if prompt not in seen: - seen[prompt] = [len(unique_prompts), 0] - unique_prompts.append(prompt) - entry = seen[prompt] - if entry[1] >= num_generations: - return None - dedup_mapping.append((entry[0], entry[1])) - entry[1] += 1 - - return unique_prompts, dedup_mapping - - def _generate_vllm_server_global( - self, - prompts_text: list[str], - max_tokens: int, - temperature: float, - top_k: int, - top_p: float, - repetition_penalty: float, - min_p: float, - n: int = 1, - ) -> list: - all_prompts_text = gather_object(prompts_text) - local_count = len(prompts_text) - - if self.accelerator.is_main_process: - if all_prompts_text: - dedup_mapping = None - if n > 1: - dedup_result = self._deduplicate_prompts(all_prompts_text, n) - if dedup_result is not None: - gen_prompts, dedup_mapping = dedup_result - gen_n = n - else: - gen_prompts = all_prompts_text - gen_n = 1 - else: - gen_prompts = all_prompts_text - gen_n = 1 - - completion_ids = self.vllm_client.generate( - prompts=gen_prompts, - n=gen_n, - repetition_penalty=repetition_penalty, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - max_tokens=max_tokens, - structured_outputs_regex=self.vllm_structured_outputs_regex, - )["completion_ids"] - - if dedup_mapping is not None: - completion_ids = [completion_ids[uid * gen_n + gid] for uid, gid in dedup_mapping] - else: - completion_ids = [] - else: - completion_ids = [None] * len(all_prompts_text) if all_prompts_text else [] - - completion_ids = broadcast_object_list(completion_ids, from_process=0) - process_slice = slice( - self.accelerator.process_index * local_count, - (self.accelerator.process_index + 1) * local_count, - ) - return completion_ids[process_slice] - - def _generate_vllm_colocate( - self, - prompts_text: list[str], - max_tokens: int, - temperature: float, - top_k: int, - top_p: float, - repetition_penalty: float, - min_p: float, - n: int = 1, - ) -> list: - if self.vllm_structured_outputs_regex: - structured_outputs = StructuredOutputsParams(backend="outlines", regex=self.vllm_structured_outputs_regex) - else: - structured_outputs = None - - if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1: - orig_size = len(prompts_text) - gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.vllm_tp_group) - all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - else: - all_prompts_text = prompts_text - - dedup_mapping = None - if n > 1 and all_prompts_text: - dedup_result = self._deduplicate_prompts(all_prompts_text, n) - if dedup_result is not None: - gen_prompts, dedup_mapping = dedup_result - gen_n = n - else: - gen_prompts = all_prompts_text - gen_n = 1 - else: - gen_prompts = all_prompts_text - gen_n = 1 - - sampling_params = SamplingParams( - n=gen_n, - repetition_penalty=repetition_penalty, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - max_tokens=max_tokens, - structured_outputs=structured_outputs, + prompts_text_with_special, + self.generation_config.max_new_tokens, ) - if gen_prompts: - all_outputs = self.vllm_engine.generate(gen_prompts, sampling_params=sampling_params, use_tqdm=False) - completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - else: - completion_ids = [] - - if dedup_mapping is not None: - completion_ids = [completion_ids[uid * gen_n + gid] for uid, gid in dedup_mapping] - - if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1: - local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) - tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - completion_ids = completion_ids[tp_slice] - - if self.vllm_enable_sleep_mode: - self.vllm_engine.sleep(level=2) - - return completion_ids - def _generate_non_vllm_for_slices(self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int]): """Fallback generation without vLLM (uses model.generate per slice).""" with unwrap_model_for_generation( @@ -2160,108 +1911,6 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token return new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts - def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None): - """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with student vLLM.""" - if visited is None: - visited = set() - - for child_name, child_module in module.named_children(): - child_prefix = f"{prefix}.{child_name}" if prefix else child_name - # recurse into the child - self._sync_fsdp_params_to_vllm(child_module, prefix=child_prefix, visited=visited) - - if isinstance(module, FSDP): - with FSDP.summon_full_params(module, recurse=False, writeback=False): - for param_name, param in module.named_parameters(): - full_name = f"{prefix}.{param_name}" if prefix else param_name - for extra in ("_fsdp_wrapped_module.", "_checkpoint_wrapped_module."): - full_name = full_name.replace(extra, "") - - if full_name in visited: - continue # skip FSDP subtrees already traversed - visited.add(full_name) - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(full_name, param.data) - elif self.vllm_mode == "colocate": - llm_model = self.vllm_engine.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(full_name, param.data)]) - - def _move_model_to_vllm(self): - """Synchronize student model weights to vLLM engine.""" - # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations - deepspeed_plugin = self.accelerator.state.deepspeed_plugin - zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 - if zero_stage_3: - import deepspeed - - gather_if_zero3 = deepspeed.zero.GatheredParameters - else: - gather_if_zero3 = nullcontext - - if self.vllm_mode == "colocate" and self.vllm_enable_sleep_mode: - empty_cache() - self.vllm_engine.wake_up(tags=["weights"]) - # Work around for https://github.com/vllm-project/vllm/issues/29341 - self.vllm_engine.collective_rpc("reload_weights") - - if is_peft_model(self.model): - # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as - # merging adapters in a sharded manner is not supported. - with gather_if_zero3(list(self.model.parameters())): - self.model.merge_adapter() - - # Update vLLM weights while parameters are gathered - if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext - # Update vLLM weights while parameters are gathered - # For PEFT with FSDP we need to use the memory efficient post-order traversal - self._sync_fsdp_params_to_vllm(self.model) - else: - # DeepSpeed ZeRO-3 with PEFT - for name, param in self.model.named_parameters(): - # When using PEFT, we need to recover the original parameter name and discard some parameters - name = name.removeprefix("base_model.model.").replace(".base_layer", "") - if self.model.prefix in name: - continue - # When module to save, remove its prefix and discard the original module - if "original_module" in name: - continue - name = name.replace("modules_to_save.default.", "") - - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - elif self.vllm_mode == "colocate": - llm_model = self.vllm_engine.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(name, param.data)]) - # Unmerge adapters while parameters are still gathered - self.model.unmerge_adapter() - # Parameters will automatically be repartitioned when exiting the context - else: - # For non-PEFT models, simply gather (if needed) and update each parameter individually. - if self.is_fsdp_enabled: - # use memory-efficient post-order traversal for FSDP - self._sync_fsdp_params_to_vllm(self.model) - else: - # For DeepSpeed ZeRO-3, gather each parameter individually like GRPO trainer - for name, param in self.model.named_parameters(): - with gather_if_zero3([param]): - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - elif self.vllm_mode == "colocate": - llm_model = self.vllm_engine.llm_engine.model_executor.driver_worker.model_runner.model - llm_model.load_weights([(name, param.data)]) - - # Reset cache on vLLM - if self.vllm_mode == "server" and self.accelerator.is_main_process: - self.vllm_client.reset_prefix_cache() - elif self.vllm_mode == "colocate": - self.vllm_engine.reset_prefix_cache() - - def _wake_vllm_if_needed(self): - if self.vllm_mode == "colocate" and self.vllm_enable_sleep_mode: - empty_cache() - self.vllm_engine.wake_up(tags=["kv_cache"]) - def _get_liger_zero3_lm_head_gather_ctx(self, model: nn.Module): if not self.use_liger_gkd_loss: return nullcontext() From 1797fc14e364479f4ecb5c258002acf3c6e70139 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Thu, 19 Mar 2026 11:22:37 +0100 Subject: [PATCH 2/9] Update with precommit --- trl/experimental/gold/gold_trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 352a6bf5cd0..0c7d3ae8bbe 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -26,10 +26,10 @@ import torch.nn as nn import torch.nn.functional as F from accelerate import PartialState -from accelerate.utils import DistributedType, broadcast_object_list, gather_object, is_peft_model +from accelerate.utils import DistributedType, broadcast_object_list, gather_object from datasets import Dataset, IterableDataset from torch.utils.data import DataLoader -from transformers import AutoTokenizer, TrainerCallback, TrainerControl, TrainerState +from transformers import AutoTokenizer, TrainerCallback from transformers.data.data_collator import DataCollator from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.generation.configuration_utils import GenerationConfig @@ -742,7 +742,6 @@ def _get_start_and_size_answers(self, answer_tensors): return answers_index, answers_size - class GOLDTrainer(SFTTrainer): _tag_names = ["trl", "gold"] _name = "GOLD" From b629987a0dbc24ffbacef7ca331c5d46ba626e36 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 23 Mar 2026 14:55:53 +0100 Subject: [PATCH 3/9] Fix how we handle padding and special tokens --- trl/experimental/gold/gold_trainer.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 0c7d3ae8bbe..d41a146661c 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -1159,11 +1159,20 @@ def _generate_on_policy_for_slices( local_prompts.append(prompt) local_slice_indices.append(slice_idx) + stacked_prompts = torch.stack(local_prompts) if local_prompts else torch.empty(0, dtype=torch.long) + prompts_text_with_special = self.processing_class.batch_decode( - torch.stack(local_prompts) if local_prompts else torch.empty(0, dtype=torch.long), + stacked_prompts, skip_special_tokens=False, ) + prompts_text = self.processing_class.batch_decode( + stacked_prompts, + skip_special_tokens=True, + ) + if self.processing_class.pad_token: + prompts_text = [p.replace(self.processing_class.pad_token, "") for p in prompts_text] + if not self.use_vllm: self._generate_non_vllm_for_slices(slices, on_policy_indices) return @@ -1175,7 +1184,10 @@ def _generate_on_policy_for_slices( self.vllm_generation.sync_weights() self._last_vllm_sync_step = self.state.global_step - prompt_ids_list = [p.tolist() for p in local_prompts] + pad_token_id = self.processing_class.pad_token_id + prompt_ids_list = [ + [tok for tok in p.tolist() if tok != pad_token_id] for p in local_prompts + ] _, completion_ids, _, _ = self.vllm_generation.generate( prompts=prompt_ids_list, images=None, @@ -1187,7 +1199,7 @@ def _generate_on_policy_for_slices( on_policy_indices, local_slice_indices, completion_ids, - prompts_text_with_special, + prompts_text, prompts_text_with_special, self.generation_config.max_new_tokens, ) From cdc31965e71d8a1c3a3e4a1a980115cff0c6797a Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 23 Mar 2026 15:19:35 +0100 Subject: [PATCH 4/9] Address concern about vllm weight sync --- trl/experimental/gold/gold_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index d41a146661c..719dec860e9 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -962,7 +962,7 @@ def __init__( logprobs=None, ) self.vllm_sync_frequency = args.vllm_sync_frequency - self._last_vllm_sync_step = -1 + self._last_vllm_sync_step = -self.vllm_sync_frequency def _set_signature_columns_if_needed(self): super()._set_signature_columns_if_needed() @@ -1179,7 +1179,7 @@ def _generate_on_policy_for_slices( if ( self.state.global_step != self._last_vllm_sync_step - and self.state.global_step % self.vllm_sync_frequency == 0 + and self.state.global_step >= self._last_vllm_sync_step + self.vllm_sync_frequency ): self.vllm_generation.sync_weights() self._last_vllm_sync_step = self.state.global_step From f4c193e36f940ea33b3e348492db41eda5a64c64 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 23 Mar 2026 15:20:09 +0100 Subject: [PATCH 5/9] Run precommit --- trl/experimental/gold/gold_trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 719dec860e9..d76056ffad4 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -1185,9 +1185,7 @@ def _generate_on_policy_for_slices( self._last_vllm_sync_step = self.state.global_step pad_token_id = self.processing_class.pad_token_id - prompt_ids_list = [ - [tok for tok in p.tolist() if tok != pad_token_id] for p in local_prompts - ] + prompt_ids_list = [[tok for tok in p.tolist() if tok != pad_token_id] for p in local_prompts] _, completion_ids, _, _ = self.vllm_generation.generate( prompts=prompt_ids_list, images=None, From 2b41f84603274c2e9619777cc7187f3c388e3229 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 23 Mar 2026 15:32:35 +0100 Subject: [PATCH 6/9] Fix max len behavior for generation --- trl/experimental/gold/gold_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index d76056ffad4..49f9fdcb07c 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -949,7 +949,7 @@ def __init__( server_timeout=args.vllm_server_timeout, tensor_parallel_size=args.vllm_tensor_parallel_size, gpu_memory_utilization=args.vllm_gpu_memory_utilization, - max_model_length=args.vllm_max_model_length, + max_model_length=args.vllm_max_model_length or args.max_length, max_num_seqs=args.per_device_train_batch_size * args.gradient_accumulation_steps, enable_sleep_mode=args.vllm_enable_sleep_mode, model_impl=args.vllm_model_impl, From 91715cb5285274c025fc2a8e9096988215192451 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 23 Mar 2026 15:33:17 +0100 Subject: [PATCH 7/9] Format docstring --- trl/experimental/gold/gold_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index c32ef45a3e8..1af9eeae332 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -97,8 +97,8 @@ class GOLDConfig(SFTConfig): Base URL for the vLLM server (e.g., `"http://localhost:8001"`). If provided, `vllm_server_host` and `vllm_server_port` are ignored. vllm_group_port (`int`, *optional*, defaults to `51216`): - Port for the vLLM weight-update group (NCCL communicator). Unless the port is occupied, there is no need - to change it. + Port for the vLLM weight-update group (NCCL communicator). Unless the port is occupied, there is no need to + change it. vllm_max_model_length (`int`, *optional*): Maximum model sequence length for the colocated vLLM engine when `vllm_mode="colocate"`. Defaults to the model's maximum context length. From b94fc1fd36f350119ea915507c3eb51102c3a3c7 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 24 Mar 2026 16:08:17 +0100 Subject: [PATCH 8/9] Remove decode -> re-tokenization roundtrip --- tests/experimental/test_gold_trainer.py | 319 ++++++++++++++++++++---- trl/experimental/gold/gold_trainer.py | 100 +++++--- 2 files changed, 333 insertions(+), 86 deletions(-) diff --git a/tests/experimental/test_gold_trainer.py b/tests/experimental/test_gold_trainer.py index 50800c0a136..d7e32056323 100644 --- a/tests/experimental/test_gold_trainer.py +++ b/tests/experimental/test_gold_trainer.py @@ -19,6 +19,7 @@ from datasets import load_dataset from transformers import AutoTokenizer +from trl.experimental.gold import gold_trainer as gold_trainer_module from trl.experimental.gold.gold_trainer import GOLDTrainer, ULDLoss, build_teacher_inputs_from_texts from trl.experimental.utils import DataCollatorForChatML @@ -289,58 +290,11 @@ def pad_labels(labels, target_length): return labels + [-100] * (target_length - len(labels)) -def test_process_completions_to_buffer_left_pads_prompt_retokenization(): - class DummyBatch: - def __init__(self, input_ids): - self.input_ids = input_ids - - def to(self, device): - self.input_ids = self.input_ids.to(device) - return self - +def test_process_completions_to_buffer_left_pads_prompt_ids(): class RecordingTokenizer: pad_token_id = 0 pad_token = "" - def __init__(self): - self.padding_side = "right" - self.calls = [] - self._prompt_ids = { - "short": [11], - "longer": [21, 22], - } - - def __call__( - self, - texts, - return_tensors, - padding, - truncation, - max_length, - add_special_tokens, - padding_side=None, - ): - assert return_tensors == "pt" - assert padding == "longest" - assert not truncation - assert max_length is None - assert not add_special_tokens - self.calls.append(padding_side) - - side = padding_side or self.padding_side - encoded = [torch.tensor(self._prompt_ids[text], dtype=torch.long) for text in texts] - max_len = max(len(ids) for ids in encoded) - - padded = [] - for ids in encoded: - pad_width = max_len - len(ids) - if pad_width: - pad = torch.full((pad_width,), self.pad_token_id, dtype=torch.long) - ids = torch.cat([pad, ids]) if side == "left" else torch.cat([ids, pad]) - padded.append(ids) - - return DummyBatch(torch.stack(padded)) - def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=False): del skip_special_tokens, clean_up_tokenization_spaces return [" ".join(str(token) for token in sequence) for sequence in sequences] @@ -358,19 +312,282 @@ def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenizati on_policy_indices=[0], local_slice_indices=[0, 0], completion_ids=[[31], [41]], - prompts_text=["short", "longer"], prompts_text_with_special=["short", "longer"], + prompt_ids_list=[[11], [21, 22]], + prompts_text=["short", "longer"], max_completion_length=1, ) buffered_inputs = trainer._buffered_inputs[0] - assert trainer.processing_class.calls == ["left"] - assert trainer.processing_class.padding_side == "right" assert torch.equal(buffered_inputs["input_ids"], torch.tensor([[0, 11, 31], [21, 22, 41]], dtype=torch.long)) assert torch.equal(buffered_inputs["attention_mask"], torch.tensor([[0, 1, 1], [1, 1, 1]], dtype=torch.long)) assert torch.equal(buffered_inputs["labels"], torch.tensor([[-100, -100, 31], [-100, -100, 41]])) +def test_generate_on_policy_for_slices_uses_prompt_attention_mask_for_vllm_prompts(): + class RecordingVLLMGeneration: + def __init__(self): + self.prompts = None + self.sync_calls = 0 + + def sync_weights(self): + self.sync_calls += 1 + + def generate(self, prompts, images, num_generations): + self.prompts = prompts + assert images is None + assert num_generations == 1 + return None, [[42]], None, None + + class RecordingTokenizer: + pad_token_id = 9 + pad_token = "" + + def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=False): + del clean_up_tokenization_spaces + decoded = [] + token_map = {5: "A", 6: "B", 9: ""} + for sequence in sequences: + tokens = [] + for token in sequence: + token = int(token) + if skip_special_tokens and token == 9: + continue + tokens.append(token_map[token]) + decoded.append(" ".join(tokens)) + return decoded + + captured = {} + + def capture_process_completions( + slices, + on_policy_indices, + local_slice_indices, + completion_ids, + prompt_ids_list, + prompts_text_with_special, + prompts_text, + max_completion_length, + ): + captured["slices"] = slices + captured["on_policy_indices"] = on_policy_indices + captured["local_slice_indices"] = local_slice_indices + captured["completion_ids"] = completion_ids + captured["prompt_ids_list"] = prompt_ids_list + captured["prompts_text"] = prompts_text + captured["prompts_text_with_special"] = prompts_text_with_special + captured["max_completion_length"] = max_completion_length + + trainer = GOLDTrainer.__new__(GOLDTrainer) + trainer.accelerator = SimpleNamespace(is_main_process=True) + trainer.args = SimpleNamespace(report_to=[]) + trainer.processing_class = RecordingTokenizer() + trainer.use_vllm = True + trainer.vllm_generation = RecordingVLLMGeneration() + trainer.vllm_sync_frequency = 1 + trainer._last_vllm_sync_step = -1 + trainer.state = SimpleNamespace(global_step=0) + trainer.num_generations = 1 + trainer.generation_config = SimpleNamespace(max_new_tokens=1) + trainer._process_completions_to_buffer = capture_process_completions + + slices = [ + { + "prompts": torch.tensor([[9, 9, 5, 9, 6]], dtype=torch.long), + "prompt_attention_mask": torch.tensor([[0, 0, 1, 1, 1]], dtype=torch.long), + } + ] + + GOLDTrainer._generate_on_policy_for_slices(trainer, slices, [0]) + + assert trainer.vllm_generation.prompts == [[5, 9, 6]] + assert trainer.vllm_generation.sync_calls == 1 + assert captured["completion_ids"] == [[42]] + assert captured["prompt_ids_list"] == [[5, 9, 6]] + assert captured["prompts_text"] == ["A B"] + assert captured["prompts_text_with_special"] == ["A B"] + + +def test_generate_on_policy_for_slices_reconstructs_prompt_with_special_tokens(): + class RecordingVLLMGeneration: + def __init__(self): + self.prompts = None + self.sync_calls = 0 + + def sync_weights(self): + self.sync_calls += 1 + + def generate(self, prompts, images, num_generations): + self.prompts = prompts + assert images is None + assert num_generations == 1 + return None, [[42]], None, None + + class RecordingTokenizer: + pad_token_id = 0 + pad_token = "" + + def __init__(self): + self.truncation_side = "right" + + def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=False): + del clean_up_tokenization_spaces + token_map = {0: "", 5: "A", 6: "B", 13: "", 42: "C"} + decoded = [] + for sequence in sequences: + tokens = [] + for token in sequence: + token = int(token) + if skip_special_tokens and token == 13: + continue + if token == 0: + continue + tokens.append(token_map[token]) + decoded.append(" ".join(tokens)) + return decoded + + trainer = GOLDTrainer.__new__(GOLDTrainer) + trainer.accelerator = SimpleNamespace(device=torch.device("cpu"), is_main_process=True) + trainer.processing_class = RecordingTokenizer() + trainer.args = SimpleNamespace(max_length=None, report_to=[]) + trainer.use_vllm = True + trainer.vllm_generation = RecordingVLLMGeneration() + trainer.vllm_sync_frequency = 1 + trainer._last_vllm_sync_step = -1 + trainer.state = SimpleNamespace(global_step=0) + trainer.num_generations = 1 + trainer.generation_config = SimpleNamespace(max_new_tokens=1) + trainer._buffered_inputs = [None] + trainer._buffered_text_logs = [None] + + slices = [ + { + "slice": "original", + "prompts": torch.tensor([[0, 0, 5, 13, 6]], dtype=torch.long), + "prompt_attention_mask": torch.tensor([[0, 0, 1, 1, 1]], dtype=torch.long), + } + ] + + GOLDTrainer._generate_on_policy_for_slices(trainer, slices, [0]) + + buffered_inputs = trainer._buffered_inputs[0] + assert trainer.vllm_generation.prompts == [[5, 13, 6]] + assert trainer.vllm_generation.sync_calls == 1 + assert torch.equal(buffered_inputs["input_ids"], torch.tensor([[5, 13, 6, 42]], dtype=torch.long)) + assert torch.equal(buffered_inputs["attention_mask"], torch.tensor([[1, 1, 1, 1]], dtype=torch.long)) + assert torch.equal(buffered_inputs["labels"], torch.tensor([[-100, -100, -100, 42]], dtype=torch.long)) + assert buffered_inputs["original_prompt_text"] == ["A B"] + assert buffered_inputs["original_completion_text"] == ["C"] + assert trainer._buffered_text_logs[0] == (["A B"], ["C"]) + + +def test_gold_trainer_init_defaults_vllm_max_model_length_to_max_length(monkeypatch): + captured = {} + + class DummyStudentModel: + def __init__(self): + self.config = SimpleNamespace(_name_or_path="student", vocab_size=17) + self.generation_config = SimpleNamespace(eos_token_id=2) + self.name_or_path = "student" + + class DummyTeacherModel: + def __init__(self): + self.resized_to = None + + def resize_token_embeddings(self, vocab_size): + self.resized_to = vocab_size + + class DummyProcessingClass: + pad_token_id = 0 + + def fake_sft_init( + self, + model, + args=None, + data_collator=None, + train_dataset=None, + eval_dataset=None, + processing_class=None, + compute_metrics=None, + callbacks=None, + optimizers=None, + preprocess_logits_for_metrics=None, + peft_config=None, + ): + del data_collator, train_dataset, eval_dataset, compute_metrics, callbacks, optimizers + del preprocess_logits_for_metrics, peft_config + self.model = model + self.args = args + self.processing_class = processing_class + self.accelerator = SimpleNamespace( + device=torch.device("cpu"), + num_processes=1, + prepare_model=lambda module, evaluation_mode=True: module, + ) + self.is_deepspeed_enabled = False + self.is_fsdp_enabled = False + + class CapturingVLLMGeneration: + def __init__(self, **kwargs): + captured.update(kwargs) + + monkeypatch.setattr(gold_trainer_module.SFTTrainer, "__init__", fake_sft_init) + monkeypatch.setattr(gold_trainer_module, "is_vllm_available", lambda: True) + monkeypatch.setattr(gold_trainer_module, "VLLMGeneration", CapturingVLLMGeneration) + + args = SimpleNamespace( + model_init_kwargs=None, + max_length=128, + use_liger_kernel=False, + teacher_model_init_kwargs=None, + use_uld_loss=False, + teacher_tokenizer_name_or_path=None, + teacher_model_revision=None, + disable_dropout=False, + lmbda=1.0, + beta=0.5, + temperature=1.0, + top_p=1.0, + seq_kd=False, + num_generations=1, + use_transformers_paged=False, + max_completion_length=16, + top_k=0, + log_completions=False, + log_completions_steps=100, + wandb_log_unique_prompts=True, + num_completions_to_print=None, + per_device_train_batch_size=1, + gradient_accumulation_steps=1, + use_vllm=True, + vllm_mode="colocate", + vllm_structured_outputs_regex=None, + vllm_server_base_url=None, + vllm_server_host="0.0.0.0", + vllm_server_port=8001, + vllm_group_port=51216, + vllm_server_timeout=240.0, + vllm_tensor_parallel_size=1, + vllm_gpu_memory_utilization=0.2, + vllm_max_model_length=None, + vllm_enable_sleep_mode=False, + vllm_model_impl="vllm", + vllm_sync_frequency=1, + ) + + teacher_model = DummyTeacherModel() + GOLDTrainer( + model=DummyStudentModel(), + teacher_model=teacher_model, + args=args, + data_collator=object(), + processing_class=DummyProcessingClass(), + ) + + assert teacher_model.resized_to == 17 + assert captured["max_model_length"] == 128 + + def test_alignment_groups_cover_all_tokens(llama_tokenizer, qwen_tokenizer): config = build_config() loss = ULDLoss(config, student_tokenizer=llama_tokenizer, teacher_tokenizer=qwen_tokenizer) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 49f9fdcb07c..39462fbaa1e 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -1095,20 +1095,26 @@ def _ensure_original_text_fields( @staticmethod def _build_sequence_batch( - new_input_ids: torch.Tensor, prompt_lengths: torch.Tensor, pad_token_id: int | None + new_input_ids: torch.Tensor, + prompt_lengths: torch.Tensor, + pad_token_id: int | None, + attention_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Build attention mask and labels from full sequences and prompt lengths.""" prompt_lengths = prompt_lengths.to(device=new_input_ids.device, dtype=torch.long) positions = torch.arange(new_input_ids.shape[1], device=new_input_ids.device).unsqueeze(0) completion_mask = positions >= prompt_lengths.unsqueeze(1) - new_attention_mask = torch.ones_like(new_input_ids) - if pad_token_id is not None: - new_attention_mask[new_input_ids == pad_token_id] = 0 + if attention_mask is not None: + new_attention_mask = attention_mask.to(device=new_input_ids.device, dtype=new_input_ids.dtype) + else: + new_attention_mask = torch.ones_like(new_input_ids) + if pad_token_id is not None: + new_attention_mask[new_input_ids == pad_token_id] = 0 new_labels = torch.full_like(new_input_ids, -100) - new_labels[completion_mask] = new_input_ids[completion_mask] - if pad_token_id is not None: + new_labels[completion_mask & new_attention_mask.bool()] = new_input_ids[completion_mask & new_attention_mask.bool()] + if attention_mask is None and pad_token_id is not None: new_labels[new_input_ids == pad_token_id] = -100 return new_attention_mask, new_labels @@ -1151,27 +1157,25 @@ def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_s def _generate_on_policy_for_slices( self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int] ): - local_prompts = [] + prompt_ids_list = [] local_slice_indices = [] for slice_idx in on_policy_indices: slice_inputs = slices[slice_idx] - for prompt in slice_inputs["prompts"]: - local_prompts.append(prompt) + prompt_attention_mask = slice_inputs.get("prompt_attention_mask") + for prompt_idx, prompt in enumerate(slice_inputs["prompts"]): + if prompt_attention_mask is not None: + prompt = prompt[prompt_attention_mask[prompt_idx].bool()] + prompt_ids_list.append(prompt.tolist()) local_slice_indices.append(slice_idx) - stacked_prompts = torch.stack(local_prompts) if local_prompts else torch.empty(0, dtype=torch.long) - - prompts_text_with_special = self.processing_class.batch_decode( - stacked_prompts, - skip_special_tokens=False, - ) - prompts_text = self.processing_class.batch_decode( - stacked_prompts, + prompt_ids_list, skip_special_tokens=True, ) - if self.processing_class.pad_token: - prompts_text = [p.replace(self.processing_class.pad_token, "") for p in prompts_text] + prompts_text_with_special = self.processing_class.batch_decode( + prompt_ids_list, + skip_special_tokens=False, + ) if not self.use_vllm: self._generate_non_vllm_for_slices(slices, on_policy_indices) @@ -1184,8 +1188,6 @@ def _generate_on_policy_for_slices( self.vllm_generation.sync_weights() self._last_vllm_sync_step = self.state.global_step - pad_token_id = self.processing_class.pad_token_id - prompt_ids_list = [[tok for tok in p.tolist() if tok != pad_token_id] for p in local_prompts] _, completion_ids, _, _ = self.vllm_generation.generate( prompts=prompt_ids_list, images=None, @@ -1197,8 +1199,9 @@ def _generate_on_policy_for_slices( on_policy_indices, local_slice_indices, completion_ids, - prompts_text, + prompt_ids_list, prompts_text_with_special, + prompts_text, self.generation_config.max_new_tokens, ) @@ -1235,8 +1238,9 @@ def _process_completions_to_buffer( on_policy_indices: list[int], local_slice_indices: list[int], completion_ids: list, - prompts_text: list[str], + prompt_ids_list: list[list[int]], prompts_text_with_special: list[str], + prompts_text: list[str], max_completion_length: int, ): """ @@ -1246,40 +1250,50 @@ def _process_completions_to_buffer( pad_token_id = self.processing_class.pad_token_id if self.processing_class.pad_token_id is not None else 0 slice_completions = {idx: [] for idx in on_policy_indices} + slice_prompt_ids = {idx: [] for idx in on_policy_indices} slice_prompts = {idx: [] for idx in on_policy_indices} slice_prompts_special = {idx: [] for idx in on_policy_indices} for i, slice_idx in enumerate(local_slice_indices): slice_completions[slice_idx].append(completion_ids[i]) - slice_prompts[slice_idx].append(prompts_text[i]) + slice_prompt_ids[slice_idx].append(prompt_ids_list[i]) slice_prompts_special[slice_idx].append(prompts_text_with_special[i]) + slice_prompts[slice_idx].append(prompts_text[i]) for slice_idx in on_policy_indices: slice_inputs = slices[slice_idx] completion_ids_for_slice = slice_completions[slice_idx] + prompt_ids_for_slice = slice_prompt_ids[slice_idx] prompt_txts = slice_prompts[slice_idx] prompt_txts_with_special = slice_prompts_special[slice_idx] prompt_max_length = max(1, self.args.max_length - max_completion_length) if self.args.max_length else None - prompt_tokenized = self.processing_class( - prompt_txts, - return_tensors="pt", - padding="longest", - padding_side="left", - truncation=True if prompt_max_length else False, - max_length=prompt_max_length, - add_special_tokens=False, - ).to(device) - prompt_ids = prompt_tokenized.input_ids + truncated_prompt_ids = [] + prompt_attention_masks = [] + truncation_side = getattr(self.processing_class, "truncation_side", "right") + for prompt_ids in prompt_ids_for_slice: + if prompt_max_length and len(prompt_ids) > prompt_max_length: + if truncation_side == "left": + prompt_ids = prompt_ids[-prompt_max_length:] + else: + prompt_ids = prompt_ids[:prompt_max_length] + prompt_tensor = torch.tensor(prompt_ids, device=device, dtype=torch.long) + truncated_prompt_ids.append(prompt_tensor) + prompt_attention_masks.append(torch.ones(len(prompt_ids), device=device, dtype=torch.long)) + + prompt_ids = pad(truncated_prompt_ids, padding_side="left", padding_value=pad_token_id) + prompt_attention_mask = pad(prompt_attention_masks, padding_side="left", padding_value=0) completion_ids_tensors = [torch.tensor(ids, device=device) for ids in completion_ids_for_slice] completion_ids_for_text: list[list[int]] = [] padded_completion_ids_list = [] + completion_attention_masks = [] for completion_tensor in completion_ids_tensors: if len(completion_tensor) > max_completion_length: truncated_completion_tensor = completion_tensor[:max_completion_length] padded_completion_ids_list.append(truncated_completion_tensor) completion_ids_for_text.append(truncated_completion_tensor.tolist()) + completion_attention_masks.append(torch.ones(len(truncated_completion_tensor), device=device, dtype=torch.long)) elif len(completion_tensor) < max_completion_length: padding_needed = max_completion_length - len(completion_tensor) padded_tensor = torch.cat( @@ -1295,15 +1309,31 @@ def _process_completions_to_buffer( ) padded_completion_ids_list.append(padded_tensor) completion_ids_for_text.append(completion_tensor.tolist()) + completion_attention_masks.append( + torch.cat( + [ + torch.ones(len(completion_tensor), device=device, dtype=torch.long), + torch.zeros(padding_needed, device=device, dtype=torch.long), + ] + ) + ) else: padded_completion_ids_list.append(completion_tensor) completion_ids_for_text.append(completion_tensor.tolist()) + completion_attention_masks.append(torch.ones(len(completion_tensor), device=device, dtype=torch.long)) completion_ids_padded = torch.stack(padded_completion_ids_list) + completion_attention_mask = torch.stack(completion_attention_masks) new_input_ids = torch.cat([prompt_ids, completion_ids_padded], dim=1) + new_attention_mask = torch.cat([prompt_attention_mask, completion_attention_mask], dim=1) prompt_lengths = torch.full((prompt_ids.shape[0],), prompt_ids.shape[1], device=device) - new_attention_mask, new_labels = self._build_sequence_batch(new_input_ids, prompt_lengths, pad_token_id) + new_attention_mask, new_labels = self._build_sequence_batch( + new_input_ids, + prompt_lengths, + pad_token_id, + attention_mask=new_attention_mask, + ) completion_texts = self.processing_class.batch_decode( completion_ids_for_text, From dcfce594fc5fea7bf3faee793cd1fc7296455e22 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 24 Mar 2026 16:10:11 +0100 Subject: [PATCH 9/9] Run precommit --- trl/experimental/gold/gold_trainer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 39462fbaa1e..bfabbd7e1bb 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -1113,7 +1113,9 @@ def _build_sequence_batch( new_attention_mask[new_input_ids == pad_token_id] = 0 new_labels = torch.full_like(new_input_ids, -100) - new_labels[completion_mask & new_attention_mask.bool()] = new_input_ids[completion_mask & new_attention_mask.bool()] + new_labels[completion_mask & new_attention_mask.bool()] = new_input_ids[ + completion_mask & new_attention_mask.bool() + ] if attention_mask is None and pad_token_id is not None: new_labels[new_input_ids == pad_token_id] = -100 @@ -1293,7 +1295,9 @@ def _process_completions_to_buffer( truncated_completion_tensor = completion_tensor[:max_completion_length] padded_completion_ids_list.append(truncated_completion_tensor) completion_ids_for_text.append(truncated_completion_tensor.tolist()) - completion_attention_masks.append(torch.ones(len(truncated_completion_tensor), device=device, dtype=torch.long)) + completion_attention_masks.append( + torch.ones(len(truncated_completion_tensor), device=device, dtype=torch.long) + ) elif len(completion_tensor) < max_completion_length: padding_needed = max_completion_length - len(completion_tensor) padded_tensor = torch.cat( @@ -1320,7 +1324,9 @@ def _process_completions_to_buffer( else: padded_completion_ids_list.append(completion_tensor) completion_ids_for_text.append(completion_tensor.tolist()) - completion_attention_masks.append(torch.ones(len(completion_tensor), device=device, dtype=torch.long)) + completion_attention_masks.append( + torch.ones(len(completion_tensor), device=device, dtype=torch.long) + ) completion_ids_padded = torch.stack(padded_completion_ids_list) completion_attention_mask = torch.stack(completion_attention_masks)