From 719c644b1f97abcfc48db50a443cf92906d07f7c Mon Sep 17 00:00:00 2001 From: cmpatino Date: Wed, 18 Feb 2026 12:30:47 +0100 Subject: [PATCH 01/32] Implement buffer for GOLDTrainer --- docs/source/gold_trainer.md | 2 + trl/experimental/gold/gold_config.py | 48 +++ trl/experimental/gold/gold_trainer.py | 560 ++++++++++++++++++++++++-- 3 files changed, 576 insertions(+), 34 deletions(-) diff --git a/docs/source/gold_trainer.md b/docs/source/gold_trainer.md index dbe7e7b01e6..a95e7d084a9 100644 --- a/docs/source/gold_trainer.md +++ b/docs/source/gold_trainer.md @@ -29,6 +29,8 @@ messages). Important configuration flags on [`GOLDConfig`] include: matched/unmatched loss. * `beta`, `lmbda`, `seq_kd` – inherited from [`experimental.gkd.GKDConfig`], controlling the generalized JSD interpolation and on-policy sampling ratio. +* `steps_per_generation`, `num_generations`, `generation_batch_size` – control buffered rollout generation across + gradient accumulation windows, including multi-generation sampling per prompt. A minimal end-to-end example: diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index 827b639dec8..ea5e621afc8 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass, field from typing import Any @@ -53,6 +54,14 @@ class GOLDConfig(SFTConfig): seq_kd (`bool`, *optional*, defaults to `False`): Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated output). + steps_per_generation (`int` or `None`, *optional*, defaults to `None`): + Number of optimization steps per generation. If `None`, it defaults to + `gradient_accumulation_steps`. + num_generations (`int`, *optional*, defaults to `1`): + Number of generations per prompt. Each prompt is repeated this many times in the generation batch. + generation_batch_size (`int` or `None`, *optional*, defaults to `None`): + Number of prompts per generation batch (global, across all processes). If `None`, it is computed from + `per_device_train_batch_size * world_size * steps_per_generation`. use_uld_loss (`bool`, *optional*, defaults to `False`): Whether to use Universal Logit Distillation (ULD) loss instead of Generalized Jensen-Shannon Divergence loss. @@ -184,6 +193,19 @@ class GOLDConfig(SFTConfig): "help": "Number of optimization steps per generation. If `None`, it defaults to gradient_accumulation_steps." }, ) + num_generations: int = field( + default=1, + metadata={ + "help": "Number of generations per prompt. Each prompt is repeated this many times in the batch." + }, + ) + generation_batch_size: int | None = field( + default=None, + metadata={ + "help": "Number of prompts per generation batch (global, across all processes). " + "If None, computed from per_device_train_batch_size * num_processes * steps_per_generation." + }, + ) # ULD Loss parameters use_uld_loss: bool = field( @@ -392,6 +414,32 @@ def __post_init__(self): if self.steps_per_generation is None: self.steps_per_generation = self.gradient_accumulation_steps + if self.generation_batch_size is None: + self.generation_batch_size = ( + self.per_device_train_batch_size * self.world_size * self.steps_per_generation + ) + + if self.num_generations < 1: + raise ValueError(f"num_generations must be at least 1, got {self.num_generations}.") + if self.generation_batch_size % self.num_generations != 0: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations " + f"({self.num_generations})." + ) + if self.generation_batch_size // self.num_generations < 1: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be at least num_generations " + f"({self.num_generations}) so that each generation batch contains at least one unique prompt." + ) + if self.num_generations > 1 and self.lmbda < 1.0: + warnings.warn( + f"num_generations={self.num_generations} with lmbda={self.lmbda} means off-policy batches will " + f"contain {self.num_generations} identical copies of each dataset sample. Consider setting " + f"lmbda=1.0 (fully on-policy) when using num_generations > 1.", + UserWarning, + stacklevel=2, + ) + # Validate ULD parameters if self.use_uld_loss: if self.uld_crossentropy_weight < 0.0: diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index ebe74b0d7f2..99256b362a8 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -19,8 +19,10 @@ from collections import defaultdict, deque from collections.abc import Callable from contextlib import nullcontext +from functools import partial from typing import Any, Optional +import datasets import torch import torch.distributed as dist import torch.nn as nn @@ -29,6 +31,7 @@ from accelerate.utils import DistributedType, broadcast_object_list, gather_object, is_peft_model from datasets import Dataset, IterableDataset from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.utils.data import DataLoader from transformers import AutoTokenizer, TrainerCallback, TrainerControl, TrainerState, is_bitsandbytes_available from transformers.data.data_collator import DataCollator from transformers.feature_extraction_utils import FeatureExtractionMixin @@ -38,8 +41,9 @@ from transformers.modeling_utils import PreTrainedModel from transformers.processing_utils import ProcessorMixin from transformers.tokenization_utils_base import PreTrainedTokenizerBase -from transformers.trainer_utils import EvalPrediction +from transformers.trainer_utils import EvalPrediction, seed_worker from transformers.utils import ( + is_datasets_available, is_flash_attn_2_available, is_liger_kernel_available, is_peft_available, @@ -54,11 +58,13 @@ from ...models.utils import unwrap_model_for_generation from ...trainer.sft_trainer import SFTTrainer from ...trainer.utils import ( + RepeatSampler, create_model_from_path, disable_dropout_in_model, empty_cache, ensure_master_addr_port, pad, + split_tensor_dict, ) from ..utils import DataCollatorForChatML from .gold_config import GOLDConfig @@ -887,6 +893,7 @@ def __init__( self.temperature = args.temperature self.top_p = args.top_p self.seq_kd = args.seq_kd + self.num_generations = args.num_generations # Track per-step loss statistics for on/off-policy batches (used in logging) self._on_policy_loss_total = 0.0 @@ -894,6 +901,12 @@ def __init__( self._on_policy_step_equiv = 0.0 self._off_policy_step_equiv = 0.0 + # Buffering for rollouts across gradient accumulation steps + self._buffered_inputs = None + self._buffered_on_policy = None + self._buffered_text_logs = None + self._step = 0 + # Hybrid ULD matched/unmatched accumulators (logged every step when ULD hybrid is used) self._matched_sum = 0.0 self._unmatched_sum = 0.0 @@ -1050,6 +1063,508 @@ def _set_signature_columns_if_needed(self): if column not in self._signature_columns: self._signature_columns.append(column) + def on_epoch_begin(self, args, state, control, **kwargs): + self._step = 0 + self._buffered_inputs = None + self._buffered_on_policy = None + self._buffered_text_logs = None + return control + + def _get_train_sampler(self, dataset=None): + if dataset is None: + dataset = self.train_dataset + return RepeatSampler( + data_source=dataset, + mini_repeat_count=self.num_generations, + batch_size=self.args.generation_batch_size // self.num_generations, + repeat_count=self.args.steps_per_generation, + shuffle=True, + seed=self.args.seed, + ) + + def get_train_dataloader(self): + """ + Override Trainer.get_train_dataloader to load a generation batch covering one optimizer window. + + Instead of returning a standard per-step batch (i.e., `per_device_batch_size`), this dataloader loads + a batch of size `per_device_batch_size * steps_per_generation`. Combined with the `RepeatSampler` + (which inflates the sampler length by `steps_per_generation`), this prevents the Trainer from + double-dividing by `gradient_accumulation_steps` when computing optimizer steps per epoch. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size * self.args.steps_per_generation, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = partial( + seed_worker, + num_workers=self.args.dataloader_num_workers, + rank=self.args.process_index, + ) + if self.args.dataloader_num_workers > 0: + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + @profiling_decorator + def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: + if not self.model.training: + return generation_batch + + spg = self.args.steps_per_generation + if self._step % spg == 0 or self._buffered_inputs is None: + self._fill_buffer(generation_batch, spg) + + slice_idx = self._step % spg + inputs = self._buffered_inputs[slice_idx] + self._step += 1 + return inputs + + def _decode_completion_texts_from_labels(self, slice_inputs: dict[str, torch.Tensor | Any]) -> list[str] | None: + """ + Decode completion-only text from labels for cross-tokenizer ULD when raw text is not available. + """ + labels = slice_inputs.get("labels") + if labels is None or not isinstance(labels, torch.Tensor): + return None + + labels_cpu = labels.detach().cpu() + decoded_completion_tokens: list[list[int]] = [] + for row in labels_cpu: + token_ids = row[row != -100].tolist() + if self.processing_class.pad_token_id is not None: + token_ids = [tok for tok in token_ids if tok != self.processing_class.pad_token_id] + decoded_completion_tokens.append(token_ids) + + return self.processing_class.batch_decode( + decoded_completion_tokens, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + + def _ensure_original_text_fields(self, slice_inputs: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: + """ + Ensure original prompt/completion text fields are available for ULD loss. + """ + if "original_prompt_text" in slice_inputs and "original_completion_text" in slice_inputs: + return slice_inputs + + prompts = slice_inputs.get("prompts") + if prompts is None or not isinstance(prompts, torch.Tensor): + return slice_inputs + + prompt_texts = self.processing_class.batch_decode( + prompts, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + completion_texts = self._decode_completion_texts_from_labels(slice_inputs) + if completion_texts is None: + return slice_inputs + + updated_slice = dict(slice_inputs) + updated_slice["original_prompt_text"] = prompt_texts + updated_slice["original_completion_text"] = completion_texts + return updated_slice + + @profiling_decorator + def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], spg: int): + slices = split_tensor_dict(generation_batch, spg) + + if self.accelerator.is_main_process: + on_policy_flags = [random.random() <= self.lmbda for _ in range(spg)] + else: + on_policy_flags = [False] * spg + + on_policy_flags = broadcast_object_list(on_policy_flags, from_process=0) + on_policy_indices = [i for i, flag in enumerate(on_policy_flags) if flag] + + self._buffered_inputs = [None] * spg + self._buffered_on_policy = on_policy_flags + self._buffered_text_logs = [None] * spg + + for i, flag in enumerate(on_policy_flags): + if not flag: + slice_inputs = slices[i] + + if self.use_uld_loss and self.teacher_tokenizer is not None: + slice_inputs = self._ensure_original_text_fields(slice_inputs) + if "original_prompt_text" not in slice_inputs or "original_completion_text" not in slice_inputs: + raise ValueError( + "Off-policy batch missing 'original_prompt_text' or 'original_completion_text' fields. " + "When using ULD loss with cross-tokenizer alignment, datasets must be prepared with " + "_prepare_dataset_with_original_text(). Ensure your dataset includes these fields." + ) + + self._buffered_inputs[i] = slice_inputs + + if on_policy_indices: + self._generate_on_policy_for_slices(slices, on_policy_indices) + + @profiling_decorator + def _generate_on_policy_for_slices( + self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int] + ): + local_prompts = [] + local_slice_info = [] + for slice_idx in on_policy_indices: + slice_inputs = slices[slice_idx] + for j in range(slice_inputs["prompts"].shape[0]): + local_prompts.append(slice_inputs["prompts"][j]) + local_slice_info.append((slice_idx, j)) + + prompts_text_for_vllm = self.processing_class.batch_decode( + torch.stack(local_prompts) if local_prompts else torch.empty(0, dtype=torch.long), + skip_special_tokens=True, + ) + if self.processing_class.pad_token: + prompts_text_for_vllm = [p.replace(self.processing_class.pad_token, "") for p in prompts_text_for_vllm] + + prompts_text_with_special = self.processing_class.batch_decode( + torch.stack(local_prompts) if local_prompts else torch.empty(0, dtype=torch.long), + skip_special_tokens=False, + ) + + if self.use_vllm: + self._wake_vllm_if_needed() + + max_completion_length = self.generation_config.max_new_tokens + temperature = self.generation_config.temperature + top_k = self.generation_config.top_k if self.generation_config.top_k and self.generation_config.top_k > 0 else -1 + top_p = self.args.top_p if hasattr(self.args, "top_p") else 1.0 + repetition_penalty = self.args.repetition_penalty if hasattr(self.args, "repetition_penalty") else 1.0 + min_p = self.args.min_p if hasattr(self.args, "min_p") else 0.0 + + if self.use_vllm and self.vllm_mode == "server": + completion_ids = self._generate_vllm_server_global( + prompts_text_for_vllm, + max_completion_length, + temperature, + top_k, + top_p, + repetition_penalty, + min_p, + n=self.num_generations, + ) + elif self.use_vllm and self.vllm_mode == "colocate": + completion_ids = self._generate_vllm_colocate( + prompts_text_for_vllm, + max_completion_length, + temperature, + top_k, + top_p, + repetition_penalty, + min_p, + n=self.num_generations, + ) + else: + self._generate_non_vllm_for_slices(slices, on_policy_indices) + return + + self._process_completions_to_buffer( + slices, + on_policy_indices, + local_slice_info, + completion_ids, + prompts_text_for_vllm, + prompts_text_with_special, + max_completion_length, + ) + + @staticmethod + def _deduplicate_prompts( + prompts: list[str], num_generations: int + ) -> tuple[list[str], list[tuple[int, int]]] | None: + """ + Deduplicate repeated prompts and build a mapping for remapping completions. + + When ``num_generations > 1``, the ``RepeatSampler`` produces K copies of each prompt. + Instead of sending K copies to vLLM with ``n=1``, we send unique prompts with ``n=K``, + which is more efficient. + """ + seen: dict[str, list[int]] = {} + unique_prompts: list[str] = [] + dedup_mapping: list[tuple[int, int]] = [] + + for prompt in prompts: + if prompt not in seen: + seen[prompt] = [len(unique_prompts), 0] + unique_prompts.append(prompt) + entry = seen[prompt] + if entry[1] >= num_generations: + return None + dedup_mapping.append((entry[0], entry[1])) + entry[1] += 1 + + return unique_prompts, dedup_mapping + + def _generate_vllm_server_global( + self, + prompts_text: list[str], + max_tokens: int, + temperature: float, + top_k: int, + top_p: float, + repetition_penalty: float, + min_p: float, + n: int = 1, + ) -> list: + all_prompts_text = gather_object(prompts_text) + local_count = len(prompts_text) + + if self.accelerator.is_main_process: + if all_prompts_text: + dedup_mapping = None + if n > 1: + dedup_result = self._deduplicate_prompts(all_prompts_text, n) + if dedup_result is not None: + gen_prompts, dedup_mapping = dedup_result + gen_n = n + else: + gen_prompts = all_prompts_text + gen_n = 1 + else: + gen_prompts = all_prompts_text + gen_n = 1 + + completion_ids = self.vllm_client.generate( + prompts=gen_prompts, + n=gen_n, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + max_tokens=max_tokens, + structured_outputs_regex=self.vllm_structured_outputs_regex, + )["completion_ids"] + + if dedup_mapping is not None: + completion_ids = [completion_ids[uid * gen_n + gid] for uid, gid in dedup_mapping] + else: + completion_ids = [] + else: + completion_ids = [None] * len(all_prompts_text) if all_prompts_text else [] + + completion_ids = broadcast_object_list(completion_ids, from_process=0) + process_slice = slice( + self.accelerator.process_index * local_count, + (self.accelerator.process_index + 1) * local_count, + ) + return completion_ids[process_slice] + + def _generate_vllm_colocate( + self, + prompts_text: list[str], + max_tokens: int, + temperature: float, + top_k: int, + top_p: float, + repetition_penalty: float, + min_p: float, + n: int = 1, + ) -> list: + if self.vllm_structured_outputs_regex: + structured_outputs = StructuredOutputsParams(backend="outlines", regex=self.vllm_structured_outputs_regex) + else: + structured_outputs = None + + if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1: + orig_size = len(prompts_text) + gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.vllm_tp_group) + all_prompts_text = [p for sublist in gathered_prompts for p in sublist] + else: + all_prompts_text = prompts_text + + dedup_mapping = None + if n > 1 and all_prompts_text: + dedup_result = self._deduplicate_prompts(all_prompts_text, n) + if dedup_result is not None: + gen_prompts, dedup_mapping = dedup_result + gen_n = n + else: + gen_prompts = all_prompts_text + gen_n = 1 + else: + gen_prompts = all_prompts_text + gen_n = 1 + + sampling_params = SamplingParams( + n=gen_n, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + max_tokens=max_tokens, + structured_outputs=structured_outputs, + ) + + if gen_prompts: + all_outputs = self.vllm_engine.generate(gen_prompts, sampling_params=sampling_params, use_tqdm=False) + completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + else: + completion_ids = [] + + if dedup_mapping is not None: + completion_ids = [completion_ids[uid * gen_n + gid] for uid, gid in dedup_mapping] + + if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1: + local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) + tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) + completion_ids = completion_ids[tp_slice] + + if self.vllm_enable_sleep_mode: + self.vllm_engine.sleep(level=2) + + return completion_ids + + def _generate_non_vllm_for_slices( + self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int] + ): + """Fallback generation without vLLM (uses model.generate per slice).""" + for slice_idx in on_policy_indices: + slice_inputs = slices[slice_idx] + with unwrap_model_for_generation( + self.model, + self.accelerator, + generation_kwargs=self.generation_kwargs, + ) as unwrapped_model: + result = self.generate_on_policy_outputs( + unwrapped_model, + slice_inputs, + self.generation_config, + self.processing_class.pad_token_id, + ) + new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts = result + + updated_slice = dict(slice_inputs) + updated_slice["input_ids"] = new_input_ids + updated_slice["attention_mask"] = new_attention_mask + updated_slice["labels"] = new_labels + updated_slice["original_prompt_text"] = prompt_texts + updated_slice["original_completion_text"] = completion_texts + + self._buffered_inputs[slice_idx] = updated_slice + self._buffered_text_logs[slice_idx] = (prompt_texts, completion_texts) + + def _process_completions_to_buffer( + self, + slices: list[dict[str, torch.Tensor | Any]], + on_policy_indices: list[int], + local_slice_info: list[tuple[int, int]], + completion_ids: list, + prompts_text: list[str], + prompts_text_with_special: list[str], + max_completion_length: int, + ): + """ + Process vLLM completions and update buffered inputs for on-policy slices. + """ + device = self.accelerator.device + pad_token_id = self.processing_class.pad_token_id if self.processing_class.pad_token_id is not None else 0 + + slice_completions = {idx: [] for idx in on_policy_indices} + slice_prompts = {idx: [] for idx in on_policy_indices} + slice_prompts_special = {idx: [] for idx in on_policy_indices} + + for i, (slice_idx, _) in enumerate(local_slice_info): + slice_completions[slice_idx].append(completion_ids[i]) + slice_prompts[slice_idx].append(prompts_text[i]) + slice_prompts_special[slice_idx].append(prompts_text_with_special[i]) + + for slice_idx in on_policy_indices: + slice_inputs = slices[slice_idx] + completion_ids_for_slice = slice_completions[slice_idx] + prompt_txts = slice_prompts[slice_idx] + prompt_txts_with_special = slice_prompts_special[slice_idx] + + prompt_max_length = max(1, self.args.max_length - max_completion_length) if self.args.max_length else None + prompt_tokenized = self.processing_class( + prompt_txts, + return_tensors="pt", + padding="longest", + truncation=True if prompt_max_length else False, + max_length=prompt_max_length, + add_special_tokens=False, + ).to(device) + prompt_ids = prompt_tokenized.input_ids + + completion_ids_tensors = [torch.tensor(ids, device=device) for ids in completion_ids_for_slice] + completion_ids_for_text: list[list[int]] = [] + padded_completion_ids_list = [] + for completion_tensor in completion_ids_tensors: + if len(completion_tensor) > max_completion_length: + truncated_completion_tensor = completion_tensor[:max_completion_length] + padded_completion_ids_list.append(truncated_completion_tensor) + completion_ids_for_text.append(truncated_completion_tensor.tolist()) + elif len(completion_tensor) < max_completion_length: + padding_needed = max_completion_length - len(completion_tensor) + padded_tensor = torch.cat( + [ + completion_tensor, + torch.full( + (padding_needed,), + pad_token_id, + device=device, + dtype=completion_tensor.dtype, + ), + ] + ) + padded_completion_ids_list.append(padded_tensor) + completion_ids_for_text.append(completion_tensor.tolist()) + else: + padded_completion_ids_list.append(completion_tensor) + completion_ids_for_text.append(completion_tensor.tolist()) + + completion_ids_padded = torch.stack(padded_completion_ids_list) + + new_input_ids = torch.cat([prompt_ids, completion_ids_padded], dim=1) + new_attention_mask = torch.ones_like(new_input_ids) + if self.processing_class.pad_token_id is not None: + new_attention_mask[new_input_ids == self.processing_class.pad_token_id] = 0 + + prompt_lengths = (prompt_ids != pad_token_id).sum(dim=1) + new_labels = torch.full_like(new_input_ids, -100) + for idx in range(new_input_ids.shape[0]): + length = int(prompt_lengths[idx].item()) + new_labels[idx, length:] = new_input_ids[idx, length:] + if self.processing_class.pad_token_id is not None: + new_labels[new_input_ids == self.processing_class.pad_token_id] = -100 + + completion_texts = self.processing_class.batch_decode( + completion_ids_for_text, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + + updated_slice = dict(slice_inputs) + updated_slice["input_ids"] = new_input_ids + updated_slice["attention_mask"] = new_attention_mask + updated_slice["labels"] = new_labels + updated_slice["original_prompt_text"] = prompt_txts_with_special + updated_slice["original_completion_text"] = completion_texts + + self._buffered_inputs[slice_idx] = updated_slice + self._buffered_text_logs[slice_idx] = (prompt_txts, completion_texts) + def _prepare_dataset( self, dataset: Dataset | IterableDataset, @@ -1960,46 +2475,23 @@ def training_step( `self.lmbda`, it generates new responses using the student model, which are then used for training instead of the offline original inputs. """ - on_policy = False - if random.random() <= self.lmbda: - on_policy = True - if self.use_vllm: - self._wake_vllm_if_needed() - result = self._generate_on_policy_outputs_vllm( - inputs, self.generation_config, self.processing_class.pad_token_id - ) - new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts = result - else: - with ( - unwrap_model_for_generation( - model, - self.accelerator, - generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762 - ) as unwrapped_model - ): - result = self.generate_on_policy_outputs( - unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id - ) - new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts = result + spg = self.args.steps_per_generation + ga = max(1, int(self.args.gradient_accumulation_steps)) + + loss = super().training_step(model, inputs, num_items_in_batch) - inputs["input_ids"] = new_input_ids - inputs["attention_mask"] = new_attention_mask - inputs["labels"] = new_labels + slice_idx = (self._step - 1) % spg - # CRITICAL: Preserve original text for cross-tokenizer ULD loss - # This ensures both off-policy (dataset) and on-policy (generated) samples - # can use proper text-based alignment for different tokenizers - inputs["original_prompt_text"] = prompt_texts - inputs["original_completion_text"] = completion_texts + on_policy = False + if self._buffered_on_policy is not None and slice_idx < len(self._buffered_on_policy): + on_policy = self._buffered_on_policy[slice_idx] - # Log prompt and completion texts + if on_policy and self._buffered_text_logs is not None and self._buffered_text_logs[slice_idx] is not None: + prompt_texts, completion_texts = self._buffered_text_logs[slice_idx] self._textual_logs["prompt"].extend(gather_object(prompt_texts)) self._textual_logs["completion"].extend(gather_object(completion_texts)) - loss = super().training_step(model, inputs, num_items_in_batch) - loss_scalar = float(loss.detach()) - ga = max(1, int(self.args.gradient_accumulation_steps)) step_equiv = 1.0 / ga if on_policy: From 904378bb659c8e3cc8b39d7f683a8476779bb03d Mon Sep 17 00:00:00 2001 From: cmpatino Date: Wed, 18 Feb 2026 12:40:43 +0100 Subject: [PATCH 02/32] Clean up code from KD buffer --- trl/experimental/gold/gold_trainer.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 99256b362a8..1d1192008f3 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -22,7 +22,6 @@ from functools import partial from typing import Any, Optional -import datasets import torch import torch.distributed as dist import torch.nn as nn @@ -1063,13 +1062,6 @@ def _set_signature_columns_if_needed(self): if column not in self._signature_columns: self._signature_columns.append(column) - def on_epoch_begin(self, args, state, control, **kwargs): - self._step = 0 - self._buffered_inputs = None - self._buffered_on_policy = None - self._buffered_text_logs = None - return control - def _get_train_sampler(self, dataset=None): if dataset is None: dataset = self.train_dataset @@ -1096,7 +1088,7 @@ def get_train_dataloader(self): train_dataset = self.train_dataset data_collator = self.data_collator - if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + if is_datasets_available() and isinstance(train_dataset, Dataset): train_dataset = self._remove_unused_columns(train_dataset, description="training") else: data_collator = self._get_collator_with_removed_columns(data_collator, description="training") @@ -1291,13 +1283,7 @@ def _generate_on_policy_for_slices( def _deduplicate_prompts( prompts: list[str], num_generations: int ) -> tuple[list[str], list[tuple[int, int]]] | None: - """ - Deduplicate repeated prompts and build a mapping for remapping completions. - - When ``num_generations > 1``, the ``RepeatSampler`` produces K copies of each prompt. - Instead of sending K copies to vLLM with ``n=1``, we send unique prompts with ``n=K``, - which is more efficient. - """ + """Deduplicate prompts and build a completion remapping.""" seen: dict[str, list[int]] = {} unique_prompts: list[str] = [] dedup_mapping: list[tuple[int, int]] = [] From 6a2ece5faeb74b908c4837c5eea4fd840d94f34a Mon Sep 17 00:00:00 2001 From: cmpatino Date: Wed, 18 Feb 2026 14:26:19 +0100 Subject: [PATCH 03/32] Test scripts for trial run --- scripts/slurm/gold_buffer_test.slurm | 235 +++++++++++++++++++ trl/experimental/gold/gold_buffer_test.py | 274 ++++++++++++++++++++++ 2 files changed, 509 insertions(+) create mode 100755 scripts/slurm/gold_buffer_test.slurm create mode 100644 trl/experimental/gold/gold_buffer_test.py diff --git a/scripts/slurm/gold_buffer_test.slurm b/scripts/slurm/gold_buffer_test.slurm new file mode 100755 index 00000000000..033a53c2ffe --- /dev/null +++ b/scripts/slurm/gold_buffer_test.slurm @@ -0,0 +1,235 @@ +#!/bin/bash +#SBATCH --job-name=trl-gold-buffer-test +#SBATCH --ntasks-per-node=1 +#SBATCH --exclusive +#SBATCH --gres=gpu:8 +#SBATCH --partition=hopper-prod +#SBATCH --output=/fsx/h4/logs/%x-%j.out +#SBATCH --error=/fsx/h4/logs/%x-%j.err +#SBATCH --requeue +#SBATCH --time=0-12:00:00 + +set -euo pipefail + +if [[ "$*" == *"--help"* ]]; then + cat <<'EOF' +Usage: sbatch scripts/slurm/gold_buffer_test.slurm [options] + +Required: + --config PATH YAML config passed to gold_buffer_test.py + +Optional: + --accelerator NAME|PATH Accelerate config name (default: zero3) or explicit YAML path + --dp N vLLM server data parallel size (default: 1) + --tp N vLLM server tensor parallel size (default: 1) + --gpus-per-node N GPUs per node for training world-size math (default: 8) + --max-steps N Inject --max_steps=N if config/args do not already set it (default: 5) + --venv PATH Virtual env to activate before launch + --args "ARGS" Extra args appended to gold_buffer_test.py + +Examples: + sbatch scripts/slurm/gold_buffer_test.slurm \ + --config /path/to/config.yaml \ + --accelerator zero3 \ + --args "--bf16 --logging_steps 1" +EOF + exit 0 +fi + +# Cluster/environment setup (same style as internal launcher) +module load cuda/12.9 || true +source ~/.bashrc || true + +START_TIME=$(date +%s) +echo "START TIME: $(date)" + +CONFIG_FILE="" +ACCELERATOR="zero3" +DP=1 +TP=1 +GPUS_PER_NODE=8 +MAX_STEPS_DEFAULT=5 +OPTIONAL_ARGS="" +VENV_PATH="" + +while [[ $# -gt 0 ]]; do + case "$1" in + --config) + CONFIG_FILE="$2" + shift 2 + ;; + --accelerator) + ACCELERATOR="$2" + shift 2 + ;; + --dp) + DP="$2" + shift 2 + ;; + --tp) + TP="$2" + shift 2 + ;; + --gpus-per-node) + GPUS_PER_NODE="$2" + shift 2 + ;; + --max-steps) + MAX_STEPS_DEFAULT="$2" + shift 2 + ;; + --venv) + VENV_PATH="$2" + shift 2 + ;; + --args) + OPTIONAL_ARGS="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" + echo "Run with --help for usage." + exit 1 + ;; + esac +done + +if [[ -z "$CONFIG_FILE" ]]; then + echo "Error: --config is required." + exit 1 +fi + +if [[ ! -f "$CONFIG_FILE" ]]; then + echo "Error: config file not found: $CONFIG_FILE" + exit 1 +fi + +if [[ -n "$VENV_PATH" ]]; then + source "$VENV_PATH/bin/activate" +fi + +if ! command -v accelerate >/dev/null 2>&1; then + echo "Error: accelerate is not available in PATH." + exit 1 +fi + +if ! command -v trl >/dev/null 2>&1; then + echo "Error: trl CLI is not available in PATH." + exit 1 +fi + +# Resolve accelerate config. +if [[ -f "$ACCELERATOR" ]]; then + ACCEL_CONFIG="$ACCELERATOR" +elif [[ -f "trl/accelerate_configs/${ACCELERATOR}.yaml" ]]; then + ACCEL_CONFIG="trl/accelerate_configs/${ACCELERATOR}.yaml" +elif [[ -f "examples/accelerate_configs/${ACCELERATOR}.yaml" ]]; then + ACCEL_CONFIG="examples/accelerate_configs/${ACCELERATOR}.yaml" +else + echo "Error: could not resolve accelerate config from '$ACCELERATOR'." + exit 1 +fi + +GRAD_ACC_STEPS=$(grep -E '^\s*gradient_accumulation_steps:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') +GRAD_ACC_STEPS=${GRAD_ACC_STEPS:-1} + +# Allow CLI override from --args. +if [[ "$OPTIONAL_ARGS" =~ --gradient_accumulation_steps=([0-9]+) ]]; then + GRAD_ACC_STEPS="${BASH_REMATCH[1]}" +fi +if [[ "$OPTIONAL_ARGS" =~ --gradient_accumulation_steps[[:space:]]+([0-9]+) ]]; then + GRAD_ACC_STEPS="${BASH_REMATCH[1]}" +fi + +STUDENT_MODEL=$(grep -E '^\s*model_name_or_path:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') +REVISION=$(grep -E '^\s*model_revision:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') +if [[ -z "${REVISION:-}" ]]; then + REVISION=$(grep -E '^\s*student_model_revision:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') +fi + +if [[ -z "${SLURM_JOB_NODELIST:-}" ]]; then + echo "Error: this launcher must run inside a SLURM allocation." + exit 1 +fi + +NUM_NODES=${SLURM_NNODES:-1} +WORLD_SIZE=$((NUM_NODES * GPUS_PER_NODE)) +NODELIST=($(scontrol show hostnames "$SLURM_JOB_NODELIST")) +MASTER_ADDR=${NODELIST[0]} +MASTER_PORT=${MASTER_PORT:-6000} +TRAIN_NODES=("${NODELIST[@]}") + +USE_VLLM="false" +if grep -qE '^\s*use_vllm:\s*true' "$CONFIG_FILE" && grep -qE '^\s*vllm_mode:\s*server' "$CONFIG_FILE"; then + USE_VLLM="true" +fi + +if [[ "$USE_VLLM" == "true" ]]; then + if (( NUM_NODES < 2 )); then + echo "Error: vLLM server mode requires at least 2 nodes (one reserved for server)." + exit 1 + fi + + TRAIN_NODES=("${NODELIST[@]:0:$((NUM_NODES - 1))}") + VLLM_NODE=${NODELIST[-1]} + WORLD_SIZE=$((WORLD_SIZE - GPUS_PER_NODE)) + NUM_NODES=$((NUM_NODES - 1)) + + VLLM_PORT=$(grep -E '^\s*vllm_server_port:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') + VLLM_PORT=${VLLM_PORT:-8001} + + srun --nodes=1 --ntasks=1 --nodelist="$VLLM_NODE" \ + trl vllm-serve \ + --model "$STUDENT_MODEL" \ + ${REVISION:+--revision "$REVISION"} \ + --tensor_parallel_size "$TP" \ + --data_parallel_size "$DP" \ + --host "$VLLM_NODE" \ + --port "$VLLM_PORT" & + + OPTIONAL_ARGS="$OPTIONAL_ARGS --vllm_server_host=$VLLM_NODE --vllm_server_port=$VLLM_PORT" +fi + +# For a test launcher, inject short max_steps if caller/config didn't already set one. +if (( MAX_STEPS_DEFAULT > 0 )); then + if ! grep -qE '^\s*max_steps:' "$CONFIG_FILE" && [[ "$OPTIONAL_ARGS" != *"--max_steps"* ]]; then + OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps=$MAX_STEPS_DEFAULT" + fi +fi + +NODELIST_CSV=$(IFS=,; echo "${TRAIN_NODES[*]}") + +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 + +SCRIPT_PATH="trl/experimental/gold/gold_buffer_test.py" +LAUNCH_CMD="ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \ + --config_file $ACCEL_CONFIG \ + --gradient_accumulation_steps $GRAD_ACC_STEPS \ + --num_machines $NUM_NODES \ + --num_processes $WORLD_SIZE \ + --main_process_ip $MASTER_ADDR \ + --main_process_port $MASTER_PORT \ + --machine_rank \$SLURM_PROCID \ + --rdzv_backend=c10d \ + --max_restarts 1 \ + --tee 3 \ + $SCRIPT_PATH --config $CONFIG_FILE $OPTIONAL_ARGS" + +SRUN_ARGS=" \ + --wait=60 \ + --kill-on-bad-exit=1 \ + --nodes=$NUM_NODES \ + --ntasks=$NUM_NODES \ + --nodelist=$NODELIST_CSV" + +set -x +clear +srun $SRUN_ARGS bash -lc "$LAUNCH_CMD" 2>&1 + +END_TIME=$(date +%s) +echo "END TIME: $(date)" +ELAPSED_SECONDS=$((END_TIME - START_TIME)) +HOURS=$((ELAPSED_SECONDS / 3600)) +MINUTES=$(((ELAPSED_SECONDS % 3600) / 60)) +SECONDS=$((ELAPSED_SECONDS % 60)) +echo "TOTAL JOB TIME: ${HOURS}h ${MINUTES}m ${SECONDS}s (${ELAPSED_SECONDS} seconds)" diff --git a/trl/experimental/gold/gold_buffer_test.py b/trl/experimental/gold/gold_buffer_test.py new file mode 100644 index 00000000000..bb80d7aa4a0 --- /dev/null +++ b/trl/experimental/gold/gold_buffer_test.py @@ -0,0 +1,274 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "peft", +# "trackio", +# ] +# /// + +""" +Buffered GOLD trainer smoke-test script. + +Example (CLI args): +python trl/experimental/gold/gold_buffer_test.py \ + --model_name_or_path HuggingFaceH4/KD-Thinky \ + --teacher_model_name_or_path Qwen/Qwen3-8B \ + --dataset_name HuggingFaceH4/DeepMath-103K \ + --dataset_config trl_all \ + --output_dir data/gold-buffer-test \ + --max_steps 5 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --steps_per_generation 4 \ + --num_generations 4 \ + --lmbda 1.0 \ + --bf16 + +Example (YAML config inspired by internal recipes): +python trl/experimental/gold/gold_buffer_test.py --config path/to/config.yaml +""" + +import logging +from dataclasses import dataclass, field +from typing import Any + +from datasets import Dataset, DatasetDict, IterableDataset, load_dataset +from transformers import AutoTokenizer, TrainerCallback + +from trl import ( + DatasetMixtureConfig, + ModelConfig, + ScriptArguments, + TrlParser, + get_dataset, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from trl.experimental.gold import GOLDConfig, GOLDTrainer + + +logger = logging.getLogger(__name__) + + +@dataclass +class GoldBufferTestArguments: + dataset_mixture: dict[str, Any] | None = field( + default=None, + metadata={ + "help": ( + "Dataset mixture config. Supports both public format (`datasets`) and internal-like format " + "(`dataset_mixture.datasets` with `id`/`config`)." + ) + }, + ) + max_train_samples: int | None = field( + default=64, + metadata={"help": "Optional cap on train samples for quick smoke tests."}, + ) + max_eval_samples: int | None = field( + default=32, + metadata={"help": "Optional cap on eval samples for quick smoke tests."}, + ) + require_buffer_usage: bool = field( + default=True, + metadata={"help": "Fail if buffered generation path is not observed when steps_per_generation > 1."}, + ) + + +class BufferSanityCallback(TrainerCallback): + def __init__(self, trainer: GOLDTrainer, require_buffer_usage: bool = True): + self.trainer = trainer + self.require_buffer_usage = require_buffer_usage + self.buffer_seen = False + + def on_step_end(self, args, state, control, **kwargs): + steps_per_generation = max(1, int(self.trainer.args.steps_per_generation)) + if steps_per_generation <= 1: + return control + buffered_inputs = getattr(self.trainer, "_buffered_inputs", None) + buffered_flags = getattr(self.trainer, "_buffered_on_policy", None) + if ( + isinstance(buffered_inputs, list) + and isinstance(buffered_flags, list) + and len(buffered_inputs) == steps_per_generation + and len(buffered_flags) == steps_per_generation + ): + self.buffer_seen = True + return control + + def on_train_end(self, args, state, control, **kwargs): + steps_per_generation = max(1, int(self.trainer.args.steps_per_generation)) + if self.require_buffer_usage and steps_per_generation > 1 and not self.buffer_seen: + raise RuntimeError( + "Buffer sanity check failed: trainer did not expose buffered rollout state while " + "steps_per_generation > 1." + ) + return control + + +def _normalize_internal_like_mixture(raw: dict[str, Any]) -> DatasetMixtureConfig: + datasets_raw = raw.get("datasets", []) + normalized_datasets = [] + for entry in datasets_raw: + path = entry.get("path", entry.get("id")) + name = entry.get("name", entry.get("config")) + if path is None: + raise ValueError(f"Each dataset entry must provide `path` or `id`. Got: {entry}") + if "weight" in entry: + logger.warning("Ignoring dataset `weight`=%s for %s in smoke-test script.", entry["weight"], path) + normalized_datasets.append( + { + "path": path, + "name": name, + "data_dir": entry.get("data_dir"), + "data_files": entry.get("data_files"), + "split": entry.get("split", "train"), + "columns": entry.get("columns"), + } + ) + + return DatasetMixtureConfig( + datasets=normalized_datasets, + streaming=raw.get("streaming", False), + test_split_size=raw.get("test_split_size"), + ) + + +def _resolve_dataset( + script_args: ScriptArguments, + test_args: GoldBufferTestArguments, +) -> DatasetDict: + if test_args.dataset_mixture is not None: + mixture = _normalize_internal_like_mixture(test_args.dataset_mixture) + return get_dataset(mixture) + + if script_args.dataset_name is None: + raise ValueError("Either `dataset_name` or `dataset_mixture` must be provided.") + return load_dataset( + script_args.dataset_name, + name=script_args.dataset_config, + streaming=script_args.dataset_streaming, + ) + + +def _cap_dataset_size(dataset: Dataset | IterableDataset, cap: int | None): + if cap is None: + return dataset + if isinstance(dataset, IterableDataset): + return dataset.take(cap) + cap = min(cap, len(dataset)) + return dataset.select(range(cap)) + + +def _pick_split(dataset_dict: DatasetDict, split_name: str, fallbacks: tuple[str, ...]): + if split_name in dataset_dict: + return dataset_dict[split_name] + for candidate in fallbacks: + if candidate in dataset_dict: + logger.warning("Split `%s` not found. Falling back to `%s`.", split_name, candidate) + return dataset_dict[candidate] + raise ValueError(f"Split `{split_name}` not found. Available splits: {list(dataset_dict.keys())}") + + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, GOLDConfig, ModelConfig, GoldBufferTestArguments)) + script_args, training_args, model_args, test_args, _ = parser.parse_args_and_config(return_remaining_strings=True) + + if training_args.student_model_revision in (None, "main") and model_args.model_revision is not None: + training_args.student_model_revision = model_args.model_revision + + quantization_config = get_quantization_config(model_args) + + model_kwargs = dict(training_args.model_init_kwargs or {}) + model_kwargs.setdefault("revision", training_args.student_model_revision) + model_kwargs.setdefault("trust_remote_code", model_args.trust_remote_code) + model_kwargs.setdefault("attn_implementation", model_args.attn_implementation) + model_kwargs.setdefault("torch_dtype", model_args.dtype) + model_kwargs.setdefault("use_cache", False if training_args.gradient_checkpointing else True) + if quantization_config is not None: + model_kwargs.setdefault("device_map", get_kbit_device_map()) + model_kwargs.setdefault("quantization_config", quantization_config) + training_args.model_init_kwargs = model_kwargs + + if training_args.teacher_model_name_or_path is None: + training_args.teacher_model_name_or_path = model_args.model_name_or_path + if training_args.use_uld_loss and training_args.teacher_tokenizer_name_or_path is None: + training_args.teacher_tokenizer_name_or_path = training_args.teacher_model_name_or_path + + teacher_model_kwargs = dict(training_args.teacher_model_init_kwargs or {}) + teacher_model_kwargs.setdefault("revision", model_args.model_revision) + teacher_model_kwargs.setdefault("trust_remote_code", model_args.trust_remote_code) + teacher_model_kwargs.setdefault("attn_implementation", model_args.attn_implementation) + teacher_model_kwargs.setdefault("torch_dtype", model_args.dtype) + teacher_model_kwargs.setdefault("use_cache", True) + if quantization_config is not None: + teacher_model_kwargs.setdefault("device_map", get_kbit_device_map()) + teacher_model_kwargs.setdefault("quantization_config", quantization_config) + training_args.teacher_model_init_kwargs = teacher_model_kwargs + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + padding_side="left", + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + dataset_dict = _resolve_dataset(script_args, test_args) + train_dataset = _pick_split(dataset_dict, script_args.dataset_train_split, ("train",)) + train_dataset = _cap_dataset_size(train_dataset, test_args.max_train_samples) + + eval_dataset = None + if training_args.eval_strategy != "no": + eval_dataset = _pick_split(dataset_dict, script_args.dataset_test_split, ("validation", "dev", "test")) + eval_dataset = _cap_dataset_size(eval_dataset, test_args.max_eval_samples) + + trainer = GOLDTrainer( + model=model_args.model_name_or_path, + teacher_model=training_args.teacher_model_name_or_path, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + buffer_sanity = BufferSanityCallback(trainer, require_buffer_usage=test_args.require_buffer_usage) + trainer.add_callback(buffer_sanity) + + logger.info( + "Starting GOLD buffer test: steps_per_generation=%s, num_generations=%s, lmbda=%s, use_vllm=%s", + training_args.steps_per_generation, + training_args.num_generations, + training_args.lmbda, + training_args.use_vllm, + ) + + train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + metrics = dict(train_result.metrics) + metrics["buffer_seen"] = int(buffer_sanity.buffer_seen) + if not isinstance(train_dataset, IterableDataset): + metrics["train_samples"] = len(train_dataset) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) From ee07aece6bb455ed1c04556b1b88c239072acafa Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 23 Feb 2026 16:17:09 +0000 Subject: [PATCH 04/32] Apply fixes to the Liger loss setting Avoid crashing when using DeepSpeed ZeRO-3 and set up the correct values for `weight_hard_loss` and `weight_soft_loss` --- trl/experimental/gold/gold_trainer.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 1d1192008f3..676130dd62a 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -824,6 +824,8 @@ def __init__( ignore_index=-100, temperature=args.temperature, compiled=False, + weight_hard_loss=0.0, + weight_soft_loss=1.0, ) self.use_liger_gkd_loss = True @@ -1803,6 +1805,7 @@ def tokenize_with_original_text(example, processing_class, dataset_text_field, a "attention_mask", "position_ids", "completion_mask", + "messages", "assistant_masks", "original_prompt_text", "original_completion_text", @@ -1982,12 +1985,19 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # Release full outputs to free memory del student_outputs, teacher_outputs + # Flatten to (batch_size * seq_len, hidden_size) so liger_jsd_loss chunks + # on the token dimension (shape[0]) rather than the batch dimension. + # Without this, num_chunks = max(1, batch_size // 1024) = 1, meaning the + # full [batch_size, seq_len, vocab_size] logit tensor is materialised at once. + student_hidden = student_hidden.reshape(-1, student_hidden.shape[-1]) + teacher_hidden = teacher_hidden.reshape(-1, teacher_hidden.shape[-1]) + # labels mask and labels (shifted) labels_mask = inputs["labels"] != -100 masked_input_ids = torch.where( labels_mask, inputs["input_ids"], torch.full_like(inputs["input_ids"], -100) ) - true_labels = masked_input_ids[:, 1:].contiguous() + true_labels = masked_input_ids[:, 1:].contiguous().reshape(-1) # heads student_head = unwrapped_student.get_output_embeddings() @@ -2029,6 +2039,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N teacher_logits=shifted_teacher_logits, labels=shifted_labels, beta=self.beta, + temperature=self.temperature, ) if self.use_uld_loss: From a3fd2af6e2575e3330911db669219bbc78b392f7 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Wed, 25 Feb 2026 18:00:22 +0000 Subject: [PATCH 05/32] Remove test scripts --- scripts/slurm/gold_buffer_test.slurm | 235 ------------------- trl/experimental/gold/gold_buffer_test.py | 274 ---------------------- 2 files changed, 509 deletions(-) delete mode 100755 scripts/slurm/gold_buffer_test.slurm delete mode 100644 trl/experimental/gold/gold_buffer_test.py diff --git a/scripts/slurm/gold_buffer_test.slurm b/scripts/slurm/gold_buffer_test.slurm deleted file mode 100755 index 033a53c2ffe..00000000000 --- a/scripts/slurm/gold_buffer_test.slurm +++ /dev/null @@ -1,235 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=trl-gold-buffer-test -#SBATCH --ntasks-per-node=1 -#SBATCH --exclusive -#SBATCH --gres=gpu:8 -#SBATCH --partition=hopper-prod -#SBATCH --output=/fsx/h4/logs/%x-%j.out -#SBATCH --error=/fsx/h4/logs/%x-%j.err -#SBATCH --requeue -#SBATCH --time=0-12:00:00 - -set -euo pipefail - -if [[ "$*" == *"--help"* ]]; then - cat <<'EOF' -Usage: sbatch scripts/slurm/gold_buffer_test.slurm [options] - -Required: - --config PATH YAML config passed to gold_buffer_test.py - -Optional: - --accelerator NAME|PATH Accelerate config name (default: zero3) or explicit YAML path - --dp N vLLM server data parallel size (default: 1) - --tp N vLLM server tensor parallel size (default: 1) - --gpus-per-node N GPUs per node for training world-size math (default: 8) - --max-steps N Inject --max_steps=N if config/args do not already set it (default: 5) - --venv PATH Virtual env to activate before launch - --args "ARGS" Extra args appended to gold_buffer_test.py - -Examples: - sbatch scripts/slurm/gold_buffer_test.slurm \ - --config /path/to/config.yaml \ - --accelerator zero3 \ - --args "--bf16 --logging_steps 1" -EOF - exit 0 -fi - -# Cluster/environment setup (same style as internal launcher) -module load cuda/12.9 || true -source ~/.bashrc || true - -START_TIME=$(date +%s) -echo "START TIME: $(date)" - -CONFIG_FILE="" -ACCELERATOR="zero3" -DP=1 -TP=1 -GPUS_PER_NODE=8 -MAX_STEPS_DEFAULT=5 -OPTIONAL_ARGS="" -VENV_PATH="" - -while [[ $# -gt 0 ]]; do - case "$1" in - --config) - CONFIG_FILE="$2" - shift 2 - ;; - --accelerator) - ACCELERATOR="$2" - shift 2 - ;; - --dp) - DP="$2" - shift 2 - ;; - --tp) - TP="$2" - shift 2 - ;; - --gpus-per-node) - GPUS_PER_NODE="$2" - shift 2 - ;; - --max-steps) - MAX_STEPS_DEFAULT="$2" - shift 2 - ;; - --venv) - VENV_PATH="$2" - shift 2 - ;; - --args) - OPTIONAL_ARGS="$2" - shift 2 - ;; - *) - echo "Unknown option: $1" - echo "Run with --help for usage." - exit 1 - ;; - esac -done - -if [[ -z "$CONFIG_FILE" ]]; then - echo "Error: --config is required." - exit 1 -fi - -if [[ ! -f "$CONFIG_FILE" ]]; then - echo "Error: config file not found: $CONFIG_FILE" - exit 1 -fi - -if [[ -n "$VENV_PATH" ]]; then - source "$VENV_PATH/bin/activate" -fi - -if ! command -v accelerate >/dev/null 2>&1; then - echo "Error: accelerate is not available in PATH." - exit 1 -fi - -if ! command -v trl >/dev/null 2>&1; then - echo "Error: trl CLI is not available in PATH." - exit 1 -fi - -# Resolve accelerate config. -if [[ -f "$ACCELERATOR" ]]; then - ACCEL_CONFIG="$ACCELERATOR" -elif [[ -f "trl/accelerate_configs/${ACCELERATOR}.yaml" ]]; then - ACCEL_CONFIG="trl/accelerate_configs/${ACCELERATOR}.yaml" -elif [[ -f "examples/accelerate_configs/${ACCELERATOR}.yaml" ]]; then - ACCEL_CONFIG="examples/accelerate_configs/${ACCELERATOR}.yaml" -else - echo "Error: could not resolve accelerate config from '$ACCELERATOR'." - exit 1 -fi - -GRAD_ACC_STEPS=$(grep -E '^\s*gradient_accumulation_steps:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') -GRAD_ACC_STEPS=${GRAD_ACC_STEPS:-1} - -# Allow CLI override from --args. -if [[ "$OPTIONAL_ARGS" =~ --gradient_accumulation_steps=([0-9]+) ]]; then - GRAD_ACC_STEPS="${BASH_REMATCH[1]}" -fi -if [[ "$OPTIONAL_ARGS" =~ --gradient_accumulation_steps[[:space:]]+([0-9]+) ]]; then - GRAD_ACC_STEPS="${BASH_REMATCH[1]}" -fi - -STUDENT_MODEL=$(grep -E '^\s*model_name_or_path:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') -REVISION=$(grep -E '^\s*model_revision:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') -if [[ -z "${REVISION:-}" ]]; then - REVISION=$(grep -E '^\s*student_model_revision:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') -fi - -if [[ -z "${SLURM_JOB_NODELIST:-}" ]]; then - echo "Error: this launcher must run inside a SLURM allocation." - exit 1 -fi - -NUM_NODES=${SLURM_NNODES:-1} -WORLD_SIZE=$((NUM_NODES * GPUS_PER_NODE)) -NODELIST=($(scontrol show hostnames "$SLURM_JOB_NODELIST")) -MASTER_ADDR=${NODELIST[0]} -MASTER_PORT=${MASTER_PORT:-6000} -TRAIN_NODES=("${NODELIST[@]}") - -USE_VLLM="false" -if grep -qE '^\s*use_vllm:\s*true' "$CONFIG_FILE" && grep -qE '^\s*vllm_mode:\s*server' "$CONFIG_FILE"; then - USE_VLLM="true" -fi - -if [[ "$USE_VLLM" == "true" ]]; then - if (( NUM_NODES < 2 )); then - echo "Error: vLLM server mode requires at least 2 nodes (one reserved for server)." - exit 1 - fi - - TRAIN_NODES=("${NODELIST[@]:0:$((NUM_NODES - 1))}") - VLLM_NODE=${NODELIST[-1]} - WORLD_SIZE=$((WORLD_SIZE - GPUS_PER_NODE)) - NUM_NODES=$((NUM_NODES - 1)) - - VLLM_PORT=$(grep -E '^\s*vllm_server_port:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') - VLLM_PORT=${VLLM_PORT:-8001} - - srun --nodes=1 --ntasks=1 --nodelist="$VLLM_NODE" \ - trl vllm-serve \ - --model "$STUDENT_MODEL" \ - ${REVISION:+--revision "$REVISION"} \ - --tensor_parallel_size "$TP" \ - --data_parallel_size "$DP" \ - --host "$VLLM_NODE" \ - --port "$VLLM_PORT" & - - OPTIONAL_ARGS="$OPTIONAL_ARGS --vllm_server_host=$VLLM_NODE --vllm_server_port=$VLLM_PORT" -fi - -# For a test launcher, inject short max_steps if caller/config didn't already set one. -if (( MAX_STEPS_DEFAULT > 0 )); then - if ! grep -qE '^\s*max_steps:' "$CONFIG_FILE" && [[ "$OPTIONAL_ARGS" != *"--max_steps"* ]]; then - OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps=$MAX_STEPS_DEFAULT" - fi -fi - -NODELIST_CSV=$(IFS=,; echo "${TRAIN_NODES[*]}") - -export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 - -SCRIPT_PATH="trl/experimental/gold/gold_buffer_test.py" -LAUNCH_CMD="ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \ - --config_file $ACCEL_CONFIG \ - --gradient_accumulation_steps $GRAD_ACC_STEPS \ - --num_machines $NUM_NODES \ - --num_processes $WORLD_SIZE \ - --main_process_ip $MASTER_ADDR \ - --main_process_port $MASTER_PORT \ - --machine_rank \$SLURM_PROCID \ - --rdzv_backend=c10d \ - --max_restarts 1 \ - --tee 3 \ - $SCRIPT_PATH --config $CONFIG_FILE $OPTIONAL_ARGS" - -SRUN_ARGS=" \ - --wait=60 \ - --kill-on-bad-exit=1 \ - --nodes=$NUM_NODES \ - --ntasks=$NUM_NODES \ - --nodelist=$NODELIST_CSV" - -set -x -clear -srun $SRUN_ARGS bash -lc "$LAUNCH_CMD" 2>&1 - -END_TIME=$(date +%s) -echo "END TIME: $(date)" -ELAPSED_SECONDS=$((END_TIME - START_TIME)) -HOURS=$((ELAPSED_SECONDS / 3600)) -MINUTES=$(((ELAPSED_SECONDS % 3600) / 60)) -SECONDS=$((ELAPSED_SECONDS % 60)) -echo "TOTAL JOB TIME: ${HOURS}h ${MINUTES}m ${SECONDS}s (${ELAPSED_SECONDS} seconds)" diff --git a/trl/experimental/gold/gold_buffer_test.py b/trl/experimental/gold/gold_buffer_test.py deleted file mode 100644 index bb80d7aa4a0..00000000000 --- a/trl/experimental/gold/gold_buffer_test.py +++ /dev/null @@ -1,274 +0,0 @@ -# Copyright 2020-2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# /// script -# dependencies = [ -# "trl", -# "peft", -# "trackio", -# ] -# /// - -""" -Buffered GOLD trainer smoke-test script. - -Example (CLI args): -python trl/experimental/gold/gold_buffer_test.py \ - --model_name_or_path HuggingFaceH4/KD-Thinky \ - --teacher_model_name_or_path Qwen/Qwen3-8B \ - --dataset_name HuggingFaceH4/DeepMath-103K \ - --dataset_config trl_all \ - --output_dir data/gold-buffer-test \ - --max_steps 5 \ - --per_device_train_batch_size 1 \ - --gradient_accumulation_steps 4 \ - --steps_per_generation 4 \ - --num_generations 4 \ - --lmbda 1.0 \ - --bf16 - -Example (YAML config inspired by internal recipes): -python trl/experimental/gold/gold_buffer_test.py --config path/to/config.yaml -""" - -import logging -from dataclasses import dataclass, field -from typing import Any - -from datasets import Dataset, DatasetDict, IterableDataset, load_dataset -from transformers import AutoTokenizer, TrainerCallback - -from trl import ( - DatasetMixtureConfig, - ModelConfig, - ScriptArguments, - TrlParser, - get_dataset, - get_kbit_device_map, - get_peft_config, - get_quantization_config, -) -from trl.experimental.gold import GOLDConfig, GOLDTrainer - - -logger = logging.getLogger(__name__) - - -@dataclass -class GoldBufferTestArguments: - dataset_mixture: dict[str, Any] | None = field( - default=None, - metadata={ - "help": ( - "Dataset mixture config. Supports both public format (`datasets`) and internal-like format " - "(`dataset_mixture.datasets` with `id`/`config`)." - ) - }, - ) - max_train_samples: int | None = field( - default=64, - metadata={"help": "Optional cap on train samples for quick smoke tests."}, - ) - max_eval_samples: int | None = field( - default=32, - metadata={"help": "Optional cap on eval samples for quick smoke tests."}, - ) - require_buffer_usage: bool = field( - default=True, - metadata={"help": "Fail if buffered generation path is not observed when steps_per_generation > 1."}, - ) - - -class BufferSanityCallback(TrainerCallback): - def __init__(self, trainer: GOLDTrainer, require_buffer_usage: bool = True): - self.trainer = trainer - self.require_buffer_usage = require_buffer_usage - self.buffer_seen = False - - def on_step_end(self, args, state, control, **kwargs): - steps_per_generation = max(1, int(self.trainer.args.steps_per_generation)) - if steps_per_generation <= 1: - return control - buffered_inputs = getattr(self.trainer, "_buffered_inputs", None) - buffered_flags = getattr(self.trainer, "_buffered_on_policy", None) - if ( - isinstance(buffered_inputs, list) - and isinstance(buffered_flags, list) - and len(buffered_inputs) == steps_per_generation - and len(buffered_flags) == steps_per_generation - ): - self.buffer_seen = True - return control - - def on_train_end(self, args, state, control, **kwargs): - steps_per_generation = max(1, int(self.trainer.args.steps_per_generation)) - if self.require_buffer_usage and steps_per_generation > 1 and not self.buffer_seen: - raise RuntimeError( - "Buffer sanity check failed: trainer did not expose buffered rollout state while " - "steps_per_generation > 1." - ) - return control - - -def _normalize_internal_like_mixture(raw: dict[str, Any]) -> DatasetMixtureConfig: - datasets_raw = raw.get("datasets", []) - normalized_datasets = [] - for entry in datasets_raw: - path = entry.get("path", entry.get("id")) - name = entry.get("name", entry.get("config")) - if path is None: - raise ValueError(f"Each dataset entry must provide `path` or `id`. Got: {entry}") - if "weight" in entry: - logger.warning("Ignoring dataset `weight`=%s for %s in smoke-test script.", entry["weight"], path) - normalized_datasets.append( - { - "path": path, - "name": name, - "data_dir": entry.get("data_dir"), - "data_files": entry.get("data_files"), - "split": entry.get("split", "train"), - "columns": entry.get("columns"), - } - ) - - return DatasetMixtureConfig( - datasets=normalized_datasets, - streaming=raw.get("streaming", False), - test_split_size=raw.get("test_split_size"), - ) - - -def _resolve_dataset( - script_args: ScriptArguments, - test_args: GoldBufferTestArguments, -) -> DatasetDict: - if test_args.dataset_mixture is not None: - mixture = _normalize_internal_like_mixture(test_args.dataset_mixture) - return get_dataset(mixture) - - if script_args.dataset_name is None: - raise ValueError("Either `dataset_name` or `dataset_mixture` must be provided.") - return load_dataset( - script_args.dataset_name, - name=script_args.dataset_config, - streaming=script_args.dataset_streaming, - ) - - -def _cap_dataset_size(dataset: Dataset | IterableDataset, cap: int | None): - if cap is None: - return dataset - if isinstance(dataset, IterableDataset): - return dataset.take(cap) - cap = min(cap, len(dataset)) - return dataset.select(range(cap)) - - -def _pick_split(dataset_dict: DatasetDict, split_name: str, fallbacks: tuple[str, ...]): - if split_name in dataset_dict: - return dataset_dict[split_name] - for candidate in fallbacks: - if candidate in dataset_dict: - logger.warning("Split `%s` not found. Falling back to `%s`.", split_name, candidate) - return dataset_dict[candidate] - raise ValueError(f"Split `{split_name}` not found. Available splits: {list(dataset_dict.keys())}") - - -if __name__ == "__main__": - parser = TrlParser((ScriptArguments, GOLDConfig, ModelConfig, GoldBufferTestArguments)) - script_args, training_args, model_args, test_args, _ = parser.parse_args_and_config(return_remaining_strings=True) - - if training_args.student_model_revision in (None, "main") and model_args.model_revision is not None: - training_args.student_model_revision = model_args.model_revision - - quantization_config = get_quantization_config(model_args) - - model_kwargs = dict(training_args.model_init_kwargs or {}) - model_kwargs.setdefault("revision", training_args.student_model_revision) - model_kwargs.setdefault("trust_remote_code", model_args.trust_remote_code) - model_kwargs.setdefault("attn_implementation", model_args.attn_implementation) - model_kwargs.setdefault("torch_dtype", model_args.dtype) - model_kwargs.setdefault("use_cache", False if training_args.gradient_checkpointing else True) - if quantization_config is not None: - model_kwargs.setdefault("device_map", get_kbit_device_map()) - model_kwargs.setdefault("quantization_config", quantization_config) - training_args.model_init_kwargs = model_kwargs - - if training_args.teacher_model_name_or_path is None: - training_args.teacher_model_name_or_path = model_args.model_name_or_path - if training_args.use_uld_loss and training_args.teacher_tokenizer_name_or_path is None: - training_args.teacher_tokenizer_name_or_path = training_args.teacher_model_name_or_path - - teacher_model_kwargs = dict(training_args.teacher_model_init_kwargs or {}) - teacher_model_kwargs.setdefault("revision", model_args.model_revision) - teacher_model_kwargs.setdefault("trust_remote_code", model_args.trust_remote_code) - teacher_model_kwargs.setdefault("attn_implementation", model_args.attn_implementation) - teacher_model_kwargs.setdefault("torch_dtype", model_args.dtype) - teacher_model_kwargs.setdefault("use_cache", True) - if quantization_config is not None: - teacher_model_kwargs.setdefault("device_map", get_kbit_device_map()) - teacher_model_kwargs.setdefault("quantization_config", quantization_config) - training_args.teacher_model_init_kwargs = teacher_model_kwargs - - tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, - revision=model_args.model_revision, - trust_remote_code=model_args.trust_remote_code, - padding_side="left", - ) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - dataset_dict = _resolve_dataset(script_args, test_args) - train_dataset = _pick_split(dataset_dict, script_args.dataset_train_split, ("train",)) - train_dataset = _cap_dataset_size(train_dataset, test_args.max_train_samples) - - eval_dataset = None - if training_args.eval_strategy != "no": - eval_dataset = _pick_split(dataset_dict, script_args.dataset_test_split, ("validation", "dev", "test")) - eval_dataset = _cap_dataset_size(eval_dataset, test_args.max_eval_samples) - - trainer = GOLDTrainer( - model=model_args.model_name_or_path, - teacher_model=training_args.teacher_model_name_or_path, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - processing_class=tokenizer, - peft_config=get_peft_config(model_args), - ) - - buffer_sanity = BufferSanityCallback(trainer, require_buffer_usage=test_args.require_buffer_usage) - trainer.add_callback(buffer_sanity) - - logger.info( - "Starting GOLD buffer test: steps_per_generation=%s, num_generations=%s, lmbda=%s, use_vllm=%s", - training_args.steps_per_generation, - training_args.num_generations, - training_args.lmbda, - training_args.use_vllm, - ) - - train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) - metrics = dict(train_result.metrics) - metrics["buffer_seen"] = int(buffer_sanity.buffer_seen) - if not isinstance(train_dataset, IterableDataset): - metrics["train_samples"] = len(train_dataset) - trainer.log_metrics("train", metrics) - trainer.save_metrics("train", metrics) - trainer.save_state() - - trainer.save_model(training_args.output_dir) - if training_args.push_to_hub: - trainer.push_to_hub(dataset_name=script_args.dataset_name) From b0669d99fd91f12056262a23c1478bf63c1f4469 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Wed, 25 Feb 2026 18:09:58 +0000 Subject: [PATCH 06/32] Handle config parameters better in gold script --- trl/experimental/gold/gold.py | 6 +++--- trl/experimental/gold/gold_config.py | 6 ------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/trl/experimental/gold/gold.py b/trl/experimental/gold/gold.py index 81954b5c753..561a639cfe9 100644 --- a/trl/experimental/gold/gold.py +++ b/trl/experimental/gold/gold.py @@ -80,7 +80,7 @@ ################ quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=training_args.student_model_revision, + revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, attn_implementation=model_args.attn_implementation, torch_dtype=model_args.dtype, @@ -93,7 +93,6 @@ if training_args.teacher_tokenizer_name_or_path is None and training_args.use_uld_loss: training_args.teacher_tokenizer_name_or_path = training_args.teacher_model_name_or_path teacher_model_kwargs = dict( - revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, attn_implementation=model_args.attn_implementation, torch_dtype=model_args.dtype, @@ -101,13 +100,14 @@ device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) + if training_args.teacher_model_init_kwargs is not None: + teacher_model_kwargs.update(training_args.teacher_model_init_kwargs) training_args.teacher_model_init_kwargs = teacher_model_kwargs tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, - padding_side="left", ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index ea5e621afc8..a4a4d60be96 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -389,12 +389,6 @@ class GOLDConfig(SFTConfig): num_completions_to_print: int = field(default=5, metadata={"help": "Number of completions to print."}) overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."}) push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."}) - trl_project: str = field( - default="smollm3", - metadata={ - "help": "The TRL project to use for evaluation. This is used to determine the path to the evaluation script." - }, - ) def __post_init__(self): super().__post_init__() From b0c4f3e4786c9ad3a9832eada9a40b22f8e34e47 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Wed, 25 Feb 2026 18:10:26 +0000 Subject: [PATCH 07/32] Upload provisional SLURM script for GOLD --- scripts/slurm/gold.slurm | 184 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 184 insertions(+) create mode 100755 scripts/slurm/gold.slurm diff --git a/scripts/slurm/gold.slurm b/scripts/slurm/gold.slurm new file mode 100755 index 00000000000..aa20e7b659b --- /dev/null +++ b/scripts/slurm/gold.slurm @@ -0,0 +1,184 @@ +#!/bin/bash +#SBATCH --job-name=trl-gold-buffer-test +#SBATCH --ntasks-per-node=1 +#SBATCH --exclusive +#SBATCH --gres=gpu:8 +#SBATCH --partition=hopper-prod +#SBATCH --output=logs/%x-%j.out +#SBATCH --error=logs/%x-%j.err +#SBATCH --requeue +#SBATCH --time=0-12:00:00 + +set -euo pipefail + +if [[ "$*" == *"--help"* ]]; then + cat <<'EOF' +Usage: sbatch scripts/slurm/gold_buffer_test.slurm [options] + +Required: + --config PATH YAML config passed to gold_buffer_test.py + +Optional: + --accelerator NAME|PATH Accelerate config name (default: zero3) or explicit YAML path + --dp N vLLM server data parallel size (default: 1) + --tp N vLLM server tensor parallel size (default: 1) + --args "ARGS" Extra args appended to gold_buffer_test.py + +Examples: + sbatch scripts/slurm/gold_buffer_test.slurm \ + --config /path/to/config.yaml \ + --accelerator zero3 \ + --args "--bf16 --logging_steps 1" +EOF + exit 0 +fi + +# Cluster/environment setup (same style as internal launcher) +module load cuda/12.9 || true +source ~/.bashrc || true + +START_TIME=$(date +%s) +echo "START TIME: $(date)" + +CONFIG_FILE="" +ACCELERATOR="zero3" +DP=1 +TP=1 +GPUS_PER_NODE=8 +OPTIONAL_ARGS="" +VENV_PATH="" + +while [[ $# -gt 0 ]]; do + case "$1" in + --config) + CONFIG_FILE="$2" + shift 2 + ;; + --accelerator) + ACCELERATOR="$2" + shift 2 + ;; + --dp) + DP="$2" + shift 2 + ;; + --tp) + TP="$2" + shift 2 + ;; + --args) + OPTIONAL_ARGS="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" + echo "Run with --help for usage." + exit 1 + ;; + esac +done + +source "trl-internal/bin/activate" + +if [[ -z "$CONFIG_FILE" ]]; then + echo "Error: --config is required." + exit 1 +fi + +if [[ ! -f "$CONFIG_FILE" ]]; then + echo "Error: config file not found: $CONFIG_FILE" + exit 1 +fi + +if ! command -v accelerate >/dev/null 2>&1; then + echo "Error: accelerate is not available in PATH." + exit 1 +fi + +if ! command -v trl >/dev/null 2>&1; then + echo "Error: trl CLI is not available in PATH." + exit 1 +fi + +# Resolve accelerate config. +if [[ -f "$ACCELERATOR" ]]; then + ACCEL_CONFIG="$ACCELERATOR" +elif [[ -f "trl/accelerate_configs/${ACCELERATOR}.yaml" ]]; then + ACCEL_CONFIG="trl/accelerate_configs/${ACCELERATOR}.yaml" +elif [[ -f "examples/accelerate_configs/${ACCELERATOR}.yaml" ]]; then + ACCEL_CONFIG="examples/accelerate_configs/${ACCELERATOR}.yaml" +else + echo "Error: could not resolve accelerate config from '$ACCELERATOR'." + exit 1 +fi + +GRAD_ACC_STEPS=$(grep -E '^\s*gradient_accumulation_steps:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') +GRAD_ACC_STEPS=${GRAD_ACC_STEPS:-1} + +# Allow CLI override from --args. +if [[ "$OPTIONAL_ARGS" =~ --gradient_accumulation_steps=([0-9]+) ]]; then + GRAD_ACC_STEPS="${BASH_REMATCH[1]}" +fi +if [[ "$OPTIONAL_ARGS" =~ --gradient_accumulation_steps[[:space:]]+([0-9]+) ]]; then + GRAD_ACC_STEPS="${BASH_REMATCH[1]}" +fi + +STUDENT_MODEL=$(grep -E '^\s*model_name_or_path:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') +REVISION=$(grep -E '^\s*model_revision:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') +if [[ -z "${REVISION:-}" ]]; then + REVISION=$(grep -E '^\s*student_model_revision:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') +fi + +if [[ -z "${SLURM_JOB_NODELIST:-}" ]]; then + echo "Error: this launcher must run inside a SLURM allocation." + exit 1 +fi + +NUM_NODES=${SLURM_NNODES:-1} +WORLD_SIZE=$((NUM_NODES * GPUS_PER_NODE)) +NODELIST=($(scontrol show hostnames "$SLURM_JOB_NODELIST")) +MASTER_ADDR=${NODELIST[0]} +MASTER_PORT=${MASTER_PORT:-6000} +TRAIN_NODES=("${NODELIST[@]}") + +USE_VLLM="false" +if grep -qE '^\s*use_vllm:\s*true' "$CONFIG_FILE" && grep -qE '^\s*vllm_mode:\s*server' "$CONFIG_FILE"; then + USE_VLLM="true" +fi + +NODELIST_CSV=$(IFS=,; echo "${TRAIN_NODES[*]}") + +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 + +SCRIPT_PATH="trl/experimental/gold/gold.py" +LAUNCH_CMD="ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \ + --config_file $ACCEL_CONFIG \ + --gradient_accumulation_steps $GRAD_ACC_STEPS \ + --num_machines $NUM_NODES \ + --num_processes $WORLD_SIZE \ + --main_process_ip $MASTER_ADDR \ + --main_process_port $MASTER_PORT \ + --machine_rank \$SLURM_PROCID \ + --rdzv_backend=c10d \ + --max_restarts 1 \ + --tee 3 \ + $SCRIPT_PATH --config $CONFIG_FILE $OPTIONAL_ARGS" + +SRUN_ARGS=" \ + --wait=60 \ + --kill-on-bad-exit=1 \ + --nodes=$NUM_NODES \ + --ntasks=$NUM_NODES \ + --nodelist=$NODELIST_CSV" + +set -x +clear +srun $SRUN_ARGS bash -lc "$LAUNCH_CMD" 2>&1 + +END_TIME=$(date +%s) +echo "END TIME: $(date)" +ELAPSED_SECONDS=$((END_TIME - START_TIME)) +HOURS=$((ELAPSED_SECONDS / 3600)) +MINUTES=$(((ELAPSED_SECONDS % 3600) / 60)) +SECONDS=$((ELAPSED_SECONDS % 60)) +echo "TOTAL JOB TIME: ${HOURS}h ${MINUTES}m ${SECONDS}s (${ELAPSED_SECONDS} seconds)" From 602e56462534512f0fa4da071f6005c17cc4875c Mon Sep 17 00:00:00 2001 From: cmpatino Date: Thu, 26 Feb 2026 10:06:10 +0000 Subject: [PATCH 08/32] Refine logic and comments --- trl/experimental/gold/gold.py | 1 - trl/experimental/gold/gold_trainer.py | 57 ++------------------------- 2 files changed, 4 insertions(+), 54 deletions(-) diff --git a/trl/experimental/gold/gold.py b/trl/experimental/gold/gold.py index 561a639cfe9..2cfa64a329b 100644 --- a/trl/experimental/gold/gold.py +++ b/trl/experimental/gold/gold.py @@ -120,7 +120,6 @@ ################ # Training ################ - # Handle eval dataset - check if test split exists, fallback to validation or None eval_dataset = None if training_args.eval_strategy != "no": if script_args.dataset_test_split in dataset: diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 676130dd62a..ca6c66806d3 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -1131,9 +1131,7 @@ def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> di return inputs def _decode_completion_texts_from_labels(self, slice_inputs: dict[str, torch.Tensor | Any]) -> list[str] | None: - """ - Decode completion-only text from labels for cross-tokenizer ULD when raw text is not available. - """ + """Decode completion text from labels when raw text is absent.""" labels = slice_inputs.get("labels") if labels is None or not isinstance(labels, torch.Tensor): return None @@ -1153,9 +1151,7 @@ def _decode_completion_texts_from_labels(self, slice_inputs: dict[str, torch.Ten ) def _ensure_original_text_fields(self, slice_inputs: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: - """ - Ensure original prompt/completion text fields are available for ULD loss. - """ + """Populate original prompt/completion text fields when missing.""" if "original_prompt_text" in slice_inputs and "original_completion_text" in slice_inputs: return slice_inputs @@ -1562,24 +1558,15 @@ def _prepare_dataset( formatting_func: Callable[[dict], str] | None, dataset_name: str, ) -> Dataset | IterableDataset: - """ - Override dataset preparation to preserve original text for cross-tokenizer distillation and ensure - attention_mask is always added for DataCollatorForChatML compatibility. - """ - # Check if dataset is already processed + """Preserve original text fields for ULD when needed.""" column_names = list(next(iter(dataset)).keys()) is_processed = "input_ids" in column_names - # Use our enhanced dataset preparation for: - # 1. ULD loss with cross-tokenizer (need original text preservation) - # 2. Any unprocessed dataset (need attention_mask for DataCollatorForChatML) if not is_processed or (self.use_uld_loss and self.teacher_tokenizer is not None): - # For unprocessed datasets, use our enhanced tokenization return self._prepare_dataset_with_original_text( dataset, processing_class, args, packing, formatting_func, dataset_name ) - # Use parent implementation for all other cases return super()._prepare_dataset(dataset, processing_class, args, packing, formatting_func, dataset_name) def _prepare_dataset_with_original_text( @@ -1978,32 +1965,23 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N use_cache=False, ) - # hidden states (shifted) student_hidden = student_outputs.last_hidden_state[:, :-1] teacher_hidden = teacher_outputs.last_hidden_state[:, :-1] - # Release full outputs to free memory del student_outputs, teacher_outputs - # Flatten to (batch_size * seq_len, hidden_size) so liger_jsd_loss chunks - # on the token dimension (shape[0]) rather than the batch dimension. - # Without this, num_chunks = max(1, batch_size // 1024) = 1, meaning the - # full [batch_size, seq_len, vocab_size] logit tensor is materialised at once. student_hidden = student_hidden.reshape(-1, student_hidden.shape[-1]) teacher_hidden = teacher_hidden.reshape(-1, teacher_hidden.shape[-1]) - # labels mask and labels (shifted) labels_mask = inputs["labels"] != -100 masked_input_ids = torch.where( labels_mask, inputs["input_ids"], torch.full_like(inputs["input_ids"], -100) ) true_labels = masked_input_ids[:, 1:].contiguous().reshape(-1) - # heads student_head = unwrapped_student.get_output_embeddings() teacher_head = unwrapped_teacher.get_output_embeddings() - # liger fused jsd loss loss = self.liger_jsd_loss( student_input=student_hidden, student_weight=student_head.weight, @@ -2014,10 +1992,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N teacher_bias=getattr(teacher_head, "bias", None), ) - # Release hidden states after loss computation del student_hidden, teacher_hidden, true_labels else: - # Original behavior for same tokenizer or when teacher_tokenizer is not provided outputs_student = model( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], @@ -2045,16 +2021,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N if self.use_uld_loss: student_input_ids = inputs["input_ids"] - # Use the *teacher* labels created above, not the student's. teacher_labels_for_loss = teacher_labels if "teacher_labels" in locals() else inputs["labels"] teacher_input_ids_for_loss = teacher_input_ids if "teacher_input_ids" in locals() else inputs["input_ids"] - # Create properly masked student labels (fixing batch size > 1 issue) student_labels = inputs["labels"].clone() if hasattr(self.processing_class, "pad_token_id") and self.processing_class.pad_token_id is not None: student_labels[student_labels == self.processing_class.pad_token_id] = -100 - # Also mask pad tokens in teacher labels for consistency if ( hasattr(self, "teacher_tokenizer") and hasattr(self.teacher_tokenizer, "pad_token_id") @@ -2071,15 +2044,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N teacher_input_ids=teacher_input_ids_for_loss, ) - # If ULD hybrid mode produced per-step matched/unmatched components, accumulate them for logging. - # Use gradient_accumulation_steps to mirror Trainer's windowing behavior. if hasattr(self.uld_loss_fn, "last_matched_loss") and hasattr(self.uld_loss_fn, "last_unmatched_loss"): - try: - ga = max(1, int(self.args.gradient_accumulation_steps)) - except Exception: - ga = 1 + ga = max(1, int(self.args.gradient_accumulation_steps)) step_eq = 1.0 / ga - # read scalar values for logging matched_val = ( self.uld_loss_fn.last_matched_loss.item() if self.uld_loss_fn.last_matched_loss is not None @@ -2290,11 +2257,6 @@ def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_ else: raise ValueError(f"Unknown vllm_mode: {self.vllm_mode}") - # We need to combine prompt and completion for new_input_ids - # Tokenize prompts again to get prompt_ids on the correct device and format - # Use prompts_text_for_vllm (without special tokens) for tokenization since vLLM expects clean text - # Ensure add_special_tokens=False as vLLM typically handles prompts as raw text - # Calculate max_length for prompts, ensuring it's positive prompt_max_length = max(1, self.args.max_length - max_completion_length) if self.args.max_length else None prompt_tokenized = self.processing_class( prompts_text_for_vllm, @@ -2307,14 +2269,11 @@ def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_ prompt_ids = prompt_tokenized.input_ids completion_ids_tensors = [torch.tensor(ids, device=device) for ids in completion_ids] - # Manually pad/truncate completions to max_completion_length length before using pad function padded_completion_ids_list = [] for completion_tensor in completion_ids_tensors: if len(completion_tensor) > max_completion_length: - # Truncate if longer than max_completion_length padded_completion_ids_list.append(completion_tensor[:max_completion_length]) elif len(completion_tensor) < max_completion_length: - # Pad if shorter than max_completion_length padding_needed = max_completion_length - len(completion_tensor) padded_tensor = torch.cat( [ @@ -2324,10 +2283,8 @@ def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_ ) padded_completion_ids_list.append(padded_tensor) else: - # Already the right length padded_completion_ids_list.append(completion_tensor) - # Now all tensors are the same length, so we can stack them padded_completion_ids = torch.stack(padded_completion_ids_list) # Ensure prompt_ids and padded_completion_ids are 2D @@ -2505,7 +2462,6 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: if mode == "train": device = self.accelerator.device if hasattr(self.accelerator, "device") else torch.device("cpu") - # include matched/unmatched accumulators for distributed reduction vec = torch.tensor( [ self._on_policy_loss_total, @@ -2521,7 +2477,6 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: device=device, ) - # Sum across processes so we mirror Trainer's distributed reduction if ( getattr(self.accelerator, "distributed_type", DistributedType.NO) != DistributedType.NO and dist.is_available() @@ -2540,20 +2495,16 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: unmatched_eq, ) = vec.tolist() - # Compute category averages over the *same window* as Trainer's logs - # (avoid div-by-zero if, e.g., no on-policy steps in the window) if on_eq > 0: logs["on_policy_loss"] = round(on_sum / on_eq, 4) if off_eq > 0: logs["off_policy_loss"] = round(off_sum / off_eq, 4) - # matched/unmatched averaged over same logging window (if present) if matched_eq > 0: logs["matched_loss"] = round(matched_sum / matched_eq, 4) if unmatched_eq > 0: logs["unmatched_loss"] = round(unmatched_sum / unmatched_eq, 4) - # Reset window accumulators after logging (just like Trainer resets its window) self._on_policy_loss_total = self._off_policy_loss_total = 0.0 self._on_policy_step_equiv = self._off_policy_step_equiv = 0.0 self._matched_sum = self._unmatched_sum = 0.0 From c4f9a642bd775e0512c733bcd529af1ac4bd482d Mon Sep 17 00:00:00 2001 From: cmpatino Date: Sat, 28 Feb 2026 13:24:01 +0100 Subject: [PATCH 09/32] Improve clarity of buffer implementation --- docs/source/gold_trainer.md | 4 +- trl/experimental/gold/gold_config.py | 39 ++++++--------- trl/experimental/gold/gold_trainer.py | 70 ++++++++++++++++++--------- 3 files changed, 64 insertions(+), 49 deletions(-) diff --git a/docs/source/gold_trainer.md b/docs/source/gold_trainer.md index a95e7d084a9..98af2eaa240 100644 --- a/docs/source/gold_trainer.md +++ b/docs/source/gold_trainer.md @@ -29,8 +29,8 @@ messages). Important configuration flags on [`GOLDConfig`] include: matched/unmatched loss. * `beta`, `lmbda`, `seq_kd` – inherited from [`experimental.gkd.GKDConfig`], controlling the generalized JSD interpolation and on-policy sampling ratio. -* `steps_per_generation`, `num_generations`, `generation_batch_size` – control buffered rollout generation across - gradient accumulation windows, including multi-generation sampling per prompt. +* `num_generations`, `generation_batch_size` – control buffered rollout generation across gradient accumulation windows. + `generation_batch_size` is the number of unique prompts per worker per optimizer step. A minimal end-to-end example: diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index a4a4d60be96..dfef6e735e2 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -54,14 +54,11 @@ class GOLDConfig(SFTConfig): seq_kd (`bool`, *optional*, defaults to `False`): Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated output). - steps_per_generation (`int` or `None`, *optional*, defaults to `None`): - Number of optimization steps per generation. If `None`, it defaults to - `gradient_accumulation_steps`. num_generations (`int`, *optional*, defaults to `1`): Number of generations per prompt. Each prompt is repeated this many times in the generation batch. generation_batch_size (`int` or `None`, *optional*, defaults to `None`): - Number of prompts per generation batch (global, across all processes). If `None`, it is computed from - `per_device_train_batch_size * world_size * steps_per_generation`. + Number of unique prompts per worker per optimizer step. If `None`, it is computed from + `(per_device_train_batch_size * gradient_accumulation_steps) // num_generations`. use_uld_loss (`bool`, *optional*, defaults to `False`): Whether to use Universal Logit Distillation (ULD) loss instead of Generalized Jensen-Shannon Divergence loss. @@ -187,12 +184,6 @@ class GOLDConfig(SFTConfig): "FT on teacher-generated output)." }, ) - steps_per_generation: int | None = field( - default=None, - metadata={ - "help": "Number of optimization steps per generation. If `None`, it defaults to gradient_accumulation_steps." - }, - ) num_generations: int = field( default=1, metadata={ @@ -202,8 +193,8 @@ class GOLDConfig(SFTConfig): generation_batch_size: int | None = field( default=None, metadata={ - "help": "Number of prompts per generation batch (global, across all processes). " - "If None, computed from per_device_train_batch_size * num_processes * steps_per_generation." + "help": "Number of unique prompts per worker per optimizer step. " + "If None, computed from (per_device_train_batch_size * gradient_accumulation_steps) // num_generations." }, ) @@ -405,25 +396,23 @@ def __post_init__(self): f"to leave room for the prompt. Consider increasing max_length or reducing max_completion_length." ) - if self.steps_per_generation is None: - self.steps_per_generation = self.gradient_accumulation_steps - + local_sequence_batch_size = self.per_device_train_batch_size * self.gradient_accumulation_steps if self.generation_batch_size is None: - self.generation_batch_size = ( - self.per_device_train_batch_size * self.world_size * self.steps_per_generation - ) + self.generation_batch_size = local_sequence_batch_size // self.num_generations if self.num_generations < 1: raise ValueError(f"num_generations must be at least 1, got {self.num_generations}.") - if self.generation_batch_size % self.num_generations != 0: + if self.generation_batch_size < 1: raise ValueError( - f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations " - f"({self.num_generations})." + "generation_batch_size must be at least 1. " + f"Got generation_batch_size={self.generation_batch_size}." ) - if self.generation_batch_size // self.num_generations < 1: + if self.generation_batch_size * self.num_generations != local_sequence_batch_size: raise ValueError( - f"generation_batch_size ({self.generation_batch_size}) must be at least num_generations " - f"({self.num_generations}) so that each generation batch contains at least one unique prompt." + "generation_batch_size and num_generations must exactly partition the local optimizer-step batch. " + f"Expected generation_batch_size * num_generations == per_device_train_batch_size * " + f"gradient_accumulation_steps, but got {self.generation_batch_size} * {self.num_generations} != " + f"{self.per_device_train_batch_size} * {self.gradient_accumulation_steps} ({local_sequence_batch_size})." ) if self.num_generations > 1 and self.lmbda < 1.0: warnings.warn( diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 676130dd62a..268a1866dd0 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -234,6 +234,31 @@ def build_teacher_inputs_from_texts( return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length +class _RepeatEachBatchDataLoader: + """Repeats each dataloader batch `repeat_count` times without re-sampling.""" + + def __init__(self, dataloader, repeat_count: int): + if repeat_count < 1: + raise ValueError(f"repeat_count must be at least 1, got {repeat_count}.") + self.dataloader = dataloader + self.repeat_count = repeat_count + + def __iter__(self): + for batch in self.dataloader: + for _ in range(self.repeat_count): + yield batch + + def __len__(self): + return len(self.dataloader) * self.repeat_count + + def set_epoch(self, epoch: int): + if hasattr(self.dataloader, "set_epoch"): + self.dataloader.set_epoch(epoch) + + def __getattr__(self, attr): + return getattr(self.dataloader, attr) + + class ULDLoss(nn.Module): """ Universal Logit Distillation Loss. @@ -951,7 +976,7 @@ def __init__( self.num_completions_to_print = args.num_completions_to_print # maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the # final optimization step. - maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.steps_per_generation + maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.gradient_accumulation_steps self._textual_logs = { "prompt": deque(maxlen=maxlen), "completion": deque(maxlen=maxlen), @@ -1070,20 +1095,20 @@ def _get_train_sampler(self, dataset=None): return RepeatSampler( data_source=dataset, mini_repeat_count=self.num_generations, - batch_size=self.args.generation_batch_size // self.num_generations, - repeat_count=self.args.steps_per_generation, + batch_size=self.args.generation_batch_size * self.accelerator.num_processes, + repeat_count=1, shuffle=True, seed=self.args.seed, ) def get_train_dataloader(self): """ - Override Trainer.get_train_dataloader to load a generation batch covering one optimizer window. + Override Trainer.get_train_dataloader to load one generation batch per optimizer window. - Instead of returning a standard per-step batch (i.e., `per_device_batch_size`), this dataloader loads - a batch of size `per_device_batch_size * steps_per_generation`. Combined with the `RepeatSampler` - (which inflates the sampler length by `steps_per_generation`), this prevents the Trainer from - double-dividing by `gradient_accumulation_steps` when computing optimizer steps per epoch. + The base dataloader yields local batches of size + `per_device_train_batch_size * gradient_accumulation_steps`. Each base batch is then repeated + `gradient_accumulation_steps` times so Trainer still executes one mini-step per accumulation slot without + re-sampling prompts. """ if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") @@ -1096,7 +1121,7 @@ def get_train_dataloader(self): data_collator = self._get_collator_with_removed_columns(data_collator, description="training") dataloader_params = { - "batch_size": self._train_batch_size * self.args.steps_per_generation, + "batch_size": self._train_batch_size * self.args.gradient_accumulation_steps, "collate_fn": data_collator, "num_workers": self.args.dataloader_num_workers, "pin_memory": self.args.dataloader_pin_memory, @@ -1114,18 +1139,19 @@ def get_train_dataloader(self): if self.args.dataloader_num_workers > 0: dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + base_dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + return _RepeatEachBatchDataLoader(base_dataloader, repeat_count=self.args.gradient_accumulation_steps) @profiling_decorator def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: if not self.model.training: return generation_batch - spg = self.args.steps_per_generation - if self._step % spg == 0 or self._buffered_inputs is None: - self._fill_buffer(generation_batch, spg) + buffer_steps = max(1, int(self.args.gradient_accumulation_steps)) + if self._step % buffer_steps == 0 or self._buffered_inputs is None: + self._fill_buffer(generation_batch, buffer_steps) - slice_idx = self._step % spg + slice_idx = self._step % buffer_steps inputs = self._buffered_inputs[slice_idx] self._step += 1 return inputs @@ -1178,20 +1204,20 @@ def _ensure_original_text_fields(self, slice_inputs: dict[str, torch.Tensor | An return updated_slice @profiling_decorator - def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], spg: int): - slices = split_tensor_dict(generation_batch, spg) + def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_steps: int): + slices = split_tensor_dict(generation_batch, buffer_steps) if self.accelerator.is_main_process: - on_policy_flags = [random.random() <= self.lmbda for _ in range(spg)] + on_policy_flags = [random.random() <= self.lmbda for _ in range(buffer_steps)] else: - on_policy_flags = [False] * spg + on_policy_flags = [False] * buffer_steps on_policy_flags = broadcast_object_list(on_policy_flags, from_process=0) on_policy_indices = [i for i, flag in enumerate(on_policy_flags) if flag] - self._buffered_inputs = [None] * spg + self._buffered_inputs = [None] * buffer_steps self._buffered_on_policy = on_policy_flags - self._buffered_text_logs = [None] * spg + self._buffered_text_logs = [None] * buffer_steps for i, flag in enumerate(on_policy_flags): if not flag: @@ -2472,12 +2498,12 @@ def training_step( `self.lmbda`, it generates new responses using the student model, which are then used for training instead of the offline original inputs. """ - spg = self.args.steps_per_generation + buffer_steps = max(1, int(self.args.gradient_accumulation_steps)) ga = max(1, int(self.args.gradient_accumulation_steps)) loss = super().training_step(model, inputs, num_items_in_batch) - slice_idx = (self._step - 1) % spg + slice_idx = (self._step - 1) % buffer_steps on_policy = False if self._buffered_on_policy is not None and slice_idx < len(self._buffered_on_policy): From 111b85e72d8c20e3c29af0c88f087fd887786b4e Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 2 Mar 2026 13:52:43 +0000 Subject: [PATCH 10/32] Add validation for num_generations --- trl/experimental/gold/gold_config.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index dfef6e735e2..828f2f3595b 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -396,12 +396,11 @@ def __post_init__(self): f"to leave room for the prompt. Consider increasing max_length or reducing max_completion_length." ) + if self.num_generations < 1: + raise ValueError(f"num_generations must be at least 1, got {self.num_generations}.") local_sequence_batch_size = self.per_device_train_batch_size * self.gradient_accumulation_steps if self.generation_batch_size is None: self.generation_batch_size = local_sequence_batch_size // self.num_generations - - if self.num_generations < 1: - raise ValueError(f"num_generations must be at least 1, got {self.num_generations}.") if self.generation_batch_size < 1: raise ValueError( "generation_batch_size must be at least 1. " From 022af620b915f168bf4ec6791958a84e89193be8 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 2 Mar 2026 15:32:41 +0000 Subject: [PATCH 11/32] Add clarifying comment to num_generations --- trl/experimental/gold/gold_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index 828f2f3595b..c0c0b0731d5 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -187,7 +187,7 @@ class GOLDConfig(SFTConfig): num_generations: int = field( default=1, metadata={ - "help": "Number of generations per prompt. Each prompt is repeated this many times in the batch." + "help": "Number of generations per prompt. Increasing this will decrease the number of unique prompts per optimization step." }, ) generation_batch_size: int | None = field( From 33e0a82f3799cdcdd2827a10dcab9f8e0c213a27 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 2 Mar 2026 15:45:41 +0000 Subject: [PATCH 12/32] Patch issue with ZeRO-3 --- trl/experimental/gold/gold_trainer.py | 28 ++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 268a1866dd0..020989d8479 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -2030,6 +2030,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N teacher_head = unwrapped_teacher.get_output_embeddings() # liger fused jsd loss + # Note: with ZeRO-3 the lm_head weights are partitioned. The + # gathering is handled in training_step() which wraps both + # forward and backward in a GatheredParameters context. loss = self.liger_jsd_loss( student_input=student_hidden, student_weight=student_head.weight, @@ -2501,7 +2504,30 @@ def training_step( buffer_steps = max(1, int(self.args.gradient_accumulation_steps)) ga = max(1, int(self.args.gradient_accumulation_steps)) - loss = super().training_step(model, inputs, num_items_in_batch) + # With Liger + ZeRO-3, the lm_head weights are partitioned across + # ranks. The Liger fused kernel saves weight references during forward + # and reads them during backward; both must see the full (gathered) + # weight. Wrapping super().training_step() (which runs compute_loss + + # accelerator.backward) keeps the weights gathered for both passes. + _gather_ctx = nullcontext() + if self.use_liger_gkd_loss: + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3: + import deepspeed + + unwrapped = self.accelerator.unwrap_model(model) + unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) + student_head = unwrapped.get_output_embeddings() + teacher_head = unwrapped_teacher.get_output_embeddings() + params = [student_head.weight, teacher_head.weight] + if getattr(student_head, "bias", None) is not None: + params.append(student_head.bias) + if getattr(teacher_head, "bias", None) is not None: + params.append(teacher_head.bias) + _gather_ctx = deepspeed.zero.GatheredParameters(params, modifier_rank=None) + + with _gather_ctx: + loss = super().training_step(model, inputs, num_items_in_batch) slice_idx = (self._step - 1) % buffer_steps From dbb6e7059c4599aa0a53ce8c1813383f06ddb788 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 2 Mar 2026 16:55:24 +0100 Subject: [PATCH 13/32] Refactor context for ZeRO-3 + Liger --- trl/experimental/gold/gold_trainer.py | 46 +++++++++++++-------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 020989d8479..452cc7e32af 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -2490,6 +2490,27 @@ def _wake_vllm_if_needed(self): empty_cache() self.vllm_engine.wake_up(tags=["kv_cache"]) + def _get_liger_zero3_lm_head_gather_ctx(self, model: nn.Module): + if not self.use_liger_gkd_loss: + return nullcontext() + + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + if deepspeed_plugin is None or deepspeed_plugin.zero_stage != 3: + return nullcontext() + + import deepspeed + + unwrapped_student = self.accelerator.unwrap_model(model) + unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) + student_head = unwrapped_student.get_output_embeddings() + teacher_head = unwrapped_teacher.get_output_embeddings() + params = [student_head.weight, teacher_head.weight] + if student_head.bias is not None: + params.append(student_head.bias) + if teacher_head.bias is not None: + params.append(teacher_head.bias) + return deepspeed.zero.GatheredParameters(params, modifier_rank=None) + @profiling_decorator def training_step( self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None @@ -2504,29 +2525,8 @@ def training_step( buffer_steps = max(1, int(self.args.gradient_accumulation_steps)) ga = max(1, int(self.args.gradient_accumulation_steps)) - # With Liger + ZeRO-3, the lm_head weights are partitioned across - # ranks. The Liger fused kernel saves weight references during forward - # and reads them during backward; both must see the full (gathered) - # weight. Wrapping super().training_step() (which runs compute_loss + - # accelerator.backward) keeps the weights gathered for both passes. - _gather_ctx = nullcontext() - if self.use_liger_gkd_loss: - deepspeed_plugin = self.accelerator.state.deepspeed_plugin - if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3: - import deepspeed - - unwrapped = self.accelerator.unwrap_model(model) - unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) - student_head = unwrapped.get_output_embeddings() - teacher_head = unwrapped_teacher.get_output_embeddings() - params = [student_head.weight, teacher_head.weight] - if getattr(student_head, "bias", None) is not None: - params.append(student_head.bias) - if getattr(teacher_head, "bias", None) is not None: - params.append(teacher_head.bias) - _gather_ctx = deepspeed.zero.GatheredParameters(params, modifier_rank=None) - - with _gather_ctx: + # Keep lm_head gathered across forward+backward for Liger + ZeRO-3. + with self._get_liger_zero3_lm_head_gather_ctx(model): loss = super().training_step(model, inputs, num_items_in_batch) slice_idx = (self._step - 1) % buffer_steps From 9da54b3a67b5defcc7467ab7ac1aed631476b418 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 2 Mar 2026 16:05:04 +0000 Subject: [PATCH 14/32] Simplify comments and code logic --- trl/experimental/gold/gold_config.py | 11 +++++------ trl/experimental/gold/gold_trainer.py | 12 +++++------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index c0c0b0731d5..8fbb42359d8 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -409,15 +409,14 @@ def __post_init__(self): if self.generation_batch_size * self.num_generations != local_sequence_batch_size: raise ValueError( "generation_batch_size and num_generations must exactly partition the local optimizer-step batch. " - f"Expected generation_batch_size * num_generations == per_device_train_batch_size * " - f"gradient_accumulation_steps, but got {self.generation_batch_size} * {self.num_generations} != " - f"{self.per_device_train_batch_size} * {self.gradient_accumulation_steps} ({local_sequence_batch_size})." + "Expected generation_batch_size * num_generations == per_device_train_batch_size * " + f"gradient_accumulation_steps, got {self.generation_batch_size} * {self.num_generations} != " + f"{self.per_device_train_batch_size} * {self.gradient_accumulation_steps}." ) if self.num_generations > 1 and self.lmbda < 1.0: warnings.warn( - f"num_generations={self.num_generations} with lmbda={self.lmbda} means off-policy batches will " - f"contain {self.num_generations} identical copies of each dataset sample. Consider setting " - f"lmbda=1.0 (fully on-policy) when using num_generations > 1.", + f"num_generations={self.num_generations} with lmbda={self.lmbda} means off-policy batches include " + f"{self.num_generations} copies of each sample; consider lmbda=1.0 when num_generations > 1.", UserWarning, stacklevel=2, ) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 452cc7e32af..0a83121efe6 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -1106,9 +1106,8 @@ def get_train_dataloader(self): Override Trainer.get_train_dataloader to load one generation batch per optimizer window. The base dataloader yields local batches of size - `per_device_train_batch_size * gradient_accumulation_steps`. Each base batch is then repeated - `gradient_accumulation_steps` times so Trainer still executes one mini-step per accumulation slot without - re-sampling prompts. + `per_device_train_batch_size * gradient_accumulation_steps`, then repeats each batch + `gradient_accumulation_steps` times so Trainer can run accumulation mini-steps without re-sampling prompts. """ if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") @@ -1147,7 +1146,7 @@ def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> di if not self.model.training: return generation_batch - buffer_steps = max(1, int(self.args.gradient_accumulation_steps)) + buffer_steps = self.args.gradient_accumulation_steps if self._step % buffer_steps == 0 or self._buffered_inputs is None: self._fill_buffer(generation_batch, buffer_steps) @@ -2522,8 +2521,7 @@ def training_step( `self.lmbda`, it generates new responses using the student model, which are then used for training instead of the offline original inputs. """ - buffer_steps = max(1, int(self.args.gradient_accumulation_steps)) - ga = max(1, int(self.args.gradient_accumulation_steps)) + buffer_steps = self.args.gradient_accumulation_steps # Keep lm_head gathered across forward+backward for Liger + ZeRO-3. with self._get_liger_zero3_lm_head_gather_ctx(model): @@ -2541,7 +2539,7 @@ def training_step( self._textual_logs["completion"].extend(gather_object(completion_texts)) loss_scalar = float(loss.detach()) - step_equiv = 1.0 / ga + step_equiv = 1.0 / self.args.gradient_accumulation_steps if on_policy: self._on_policy_loss_total += loss_scalar From 44354099e8deec523d3e9c2609aadd48ff18a437 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 2 Mar 2026 16:07:56 +0000 Subject: [PATCH 15/32] Add scripts to run GOLD --- scripts/slurm/gold.slurm | 35 ++++------------------------------- trl/experimental/gold/gold.py | 16 +++++++++++++++- 2 files changed, 19 insertions(+), 32 deletions(-) diff --git a/scripts/slurm/gold.slurm b/scripts/slurm/gold.slurm index aa20e7b659b..22713c52e9d 100755 --- a/scripts/slurm/gold.slurm +++ b/scripts/slurm/gold.slurm @@ -13,19 +13,17 @@ set -euo pipefail if [[ "$*" == *"--help"* ]]; then cat <<'EOF' -Usage: sbatch scripts/slurm/gold_buffer_test.slurm [options] +Usage: sbatch scripts/slurm/gold.slurm [options] Required: - --config PATH YAML config passed to gold_buffer_test.py + --config PATH YAML config passed to gold.py Optional: --accelerator NAME|PATH Accelerate config name (default: zero3) or explicit YAML path - --dp N vLLM server data parallel size (default: 1) - --tp N vLLM server tensor parallel size (default: 1) - --args "ARGS" Extra args appended to gold_buffer_test.py + --args "ARGS" Extra args appended to gold.py Examples: - sbatch scripts/slurm/gold_buffer_test.slurm \ + sbatch scripts/slurm/gold.slurm \ --config /path/to/config.yaml \ --accelerator zero3 \ --args "--bf16 --logging_steps 1" @@ -33,7 +31,6 @@ EOF exit 0 fi -# Cluster/environment setup (same style as internal launcher) module load cuda/12.9 || true source ~/.bashrc || true @@ -42,11 +39,8 @@ echo "START TIME: $(date)" CONFIG_FILE="" ACCELERATOR="zero3" -DP=1 -TP=1 GPUS_PER_NODE=8 OPTIONAL_ARGS="" -VENV_PATH="" while [[ $# -gt 0 ]]; do case "$1" in @@ -58,14 +52,6 @@ while [[ $# -gt 0 ]]; do ACCELERATOR="$2" shift 2 ;; - --dp) - DP="$2" - shift 2 - ;; - --tp) - TP="$2" - shift 2 - ;; --args) OPTIONAL_ARGS="$2" shift 2 @@ -100,7 +86,6 @@ if ! command -v trl >/dev/null 2>&1; then exit 1 fi -# Resolve accelerate config. if [[ -f "$ACCELERATOR" ]]; then ACCEL_CONFIG="$ACCELERATOR" elif [[ -f "trl/accelerate_configs/${ACCELERATOR}.yaml" ]]; then @@ -115,7 +100,6 @@ fi GRAD_ACC_STEPS=$(grep -E '^\s*gradient_accumulation_steps:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') GRAD_ACC_STEPS=${GRAD_ACC_STEPS:-1} -# Allow CLI override from --args. if [[ "$OPTIONAL_ARGS" =~ --gradient_accumulation_steps=([0-9]+) ]]; then GRAD_ACC_STEPS="${BASH_REMATCH[1]}" fi @@ -123,12 +107,6 @@ if [[ "$OPTIONAL_ARGS" =~ --gradient_accumulation_steps[[:space:]]+([0-9]+) ]]; GRAD_ACC_STEPS="${BASH_REMATCH[1]}" fi -STUDENT_MODEL=$(grep -E '^\s*model_name_or_path:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') -REVISION=$(grep -E '^\s*model_revision:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') -if [[ -z "${REVISION:-}" ]]; then - REVISION=$(grep -E '^\s*student_model_revision:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') -fi - if [[ -z "${SLURM_JOB_NODELIST:-}" ]]; then echo "Error: this launcher must run inside a SLURM allocation." exit 1 @@ -141,11 +119,6 @@ MASTER_ADDR=${NODELIST[0]} MASTER_PORT=${MASTER_PORT:-6000} TRAIN_NODES=("${NODELIST[@]}") -USE_VLLM="false" -if grep -qE '^\s*use_vllm:\s*true' "$CONFIG_FILE" && grep -qE '^\s*vllm_mode:\s*server' "$CONFIG_FILE"; then - USE_VLLM="true" -fi - NODELIST_CSV=$(IFS=,; echo "${TRAIN_NODES[*]}") export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 diff --git a/trl/experimental/gold/gold.py b/trl/experimental/gold/gold.py index 561a639cfe9..8787a2e40c5 100644 --- a/trl/experimental/gold/gold.py +++ b/trl/experimental/gold/gold.py @@ -51,6 +51,7 @@ """ import logging +import os from datasets import load_dataset from transformers import AutoTokenizer, GenerationConfig @@ -120,7 +121,6 @@ ################ # Training ################ - # Handle eval dataset - check if test split exists, fallback to validation or None eval_dataset = None if training_args.eval_strategy != "no": if script_args.dataset_test_split in dataset: @@ -130,6 +130,20 @@ elif "dev" in dataset: eval_dataset = dataset["dev"] + if isinstance(training_args.report_to, str): + report_to = {training_args.report_to} + else: + report_to = set(training_args.report_to or []) + if "wandb" in report_to: + if training_args.wandb_project is not None: + os.environ.setdefault("WANDB_PROJECT", training_args.wandb_project) + if training_args.wandb_entity is not None: + os.environ.setdefault("WANDB_ENTITY", training_args.wandb_entity) + if training_args.wandb_run_group is not None: + os.environ.setdefault("WANDB_RUN_GROUP", training_args.wandb_run_group) + if training_args.run_name is not None: + os.environ.setdefault("WANDB_NAME", training_args.run_name) + trainer = GOLDTrainer( model=model_args.model_name_or_path, teacher_model=training_args.teacher_model_name_or_path, From 31161a0ce2533055d08a40ead4bc50342f3e385c Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 2 Mar 2026 18:39:29 +0100 Subject: [PATCH 16/32] Refactor to simplify logic --- trl/experimental/gold/gold_trainer.py | 247 +++++--------------------- 1 file changed, 46 insertions(+), 201 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 4fd34d1cef7..989dfdb83c1 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -1197,6 +1197,26 @@ def _ensure_original_text_fields(self, slice_inputs: dict[str, torch.Tensor | An updated_slice["original_completion_text"] = completion_texts return updated_slice + @staticmethod + def _build_sequence_batch( + new_input_ids: torch.Tensor, prompt_lengths: torch.Tensor, pad_token_id: int | None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Build attention mask and labels from full sequences and prompt lengths.""" + prompt_lengths = prompt_lengths.to(device=new_input_ids.device, dtype=torch.long) + positions = torch.arange(new_input_ids.shape[1], device=new_input_ids.device).unsqueeze(0) + completion_mask = positions >= prompt_lengths.unsqueeze(1) + + new_attention_mask = torch.ones_like(new_input_ids) + if pad_token_id is not None: + new_attention_mask[new_input_ids == pad_token_id] = 0 + + new_labels = torch.full_like(new_input_ids, -100) + new_labels[completion_mask] = new_input_ids[completion_mask] + if pad_token_id is not None: + new_labels[new_input_ids == pad_token_id] = -100 + + return new_attention_mask, new_labels + @profiling_decorator def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_steps: int): slices = split_tensor_dict(generation_batch, buffer_steps) @@ -1236,12 +1256,12 @@ def _generate_on_policy_for_slices( self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int] ): local_prompts = [] - local_slice_info = [] + local_slice_indices = [] for slice_idx in on_policy_indices: slice_inputs = slices[slice_idx] - for j in range(slice_inputs["prompts"].shape[0]): - local_prompts.append(slice_inputs["prompts"][j]) - local_slice_info.append((slice_idx, j)) + for prompt in slice_inputs["prompts"]: + local_prompts.append(prompt) + local_slice_indices.append(slice_idx) prompts_text_for_vllm = self.processing_class.batch_decode( torch.stack(local_prompts) if local_prompts else torch.empty(0, dtype=torch.long), @@ -1294,7 +1314,7 @@ def _generate_on_policy_for_slices( self._process_completions_to_buffer( slices, on_policy_indices, - local_slice_info, + local_slice_indices, completion_ids, prompts_text_for_vllm, prompts_text_with_special, @@ -1448,13 +1468,13 @@ def _generate_non_vllm_for_slices( self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int] ): """Fallback generation without vLLM (uses model.generate per slice).""" - for slice_idx in on_policy_indices: - slice_inputs = slices[slice_idx] - with unwrap_model_for_generation( - self.model, - self.accelerator, - generation_kwargs=self.generation_kwargs, - ) as unwrapped_model: + with unwrap_model_for_generation( + self.model, + self.accelerator, + generation_kwargs=self.generation_kwargs, + ) as unwrapped_model: + for slice_idx in on_policy_indices: + slice_inputs = slices[slice_idx] result = self.generate_on_policy_outputs( unwrapped_model, slice_inputs, @@ -1463,21 +1483,21 @@ def _generate_non_vllm_for_slices( ) new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts = result - updated_slice = dict(slice_inputs) - updated_slice["input_ids"] = new_input_ids - updated_slice["attention_mask"] = new_attention_mask - updated_slice["labels"] = new_labels - updated_slice["original_prompt_text"] = prompt_texts - updated_slice["original_completion_text"] = completion_texts + updated_slice = dict(slice_inputs) + updated_slice["input_ids"] = new_input_ids + updated_slice["attention_mask"] = new_attention_mask + updated_slice["labels"] = new_labels + updated_slice["original_prompt_text"] = prompt_texts + updated_slice["original_completion_text"] = completion_texts - self._buffered_inputs[slice_idx] = updated_slice - self._buffered_text_logs[slice_idx] = (prompt_texts, completion_texts) + self._buffered_inputs[slice_idx] = updated_slice + self._buffered_text_logs[slice_idx] = (prompt_texts, completion_texts) def _process_completions_to_buffer( self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int], - local_slice_info: list[tuple[int, int]], + local_slice_indices: list[int], completion_ids: list, prompts_text: list[str], prompts_text_with_special: list[str], @@ -1493,7 +1513,7 @@ def _process_completions_to_buffer( slice_prompts = {idx: [] for idx in on_policy_indices} slice_prompts_special = {idx: [] for idx in on_policy_indices} - for i, (slice_idx, _) in enumerate(local_slice_info): + for i, slice_idx in enumerate(local_slice_indices): slice_completions[slice_idx].append(completion_ids[i]) slice_prompts[slice_idx].append(prompts_text[i]) slice_prompts_special[slice_idx].append(prompts_text_with_special[i]) @@ -1545,17 +1565,10 @@ def _process_completions_to_buffer( completion_ids_padded = torch.stack(padded_completion_ids_list) new_input_ids = torch.cat([prompt_ids, completion_ids_padded], dim=1) - new_attention_mask = torch.ones_like(new_input_ids) - if self.processing_class.pad_token_id is not None: - new_attention_mask[new_input_ids == self.processing_class.pad_token_id] = 0 - prompt_lengths = (prompt_ids != pad_token_id).sum(dim=1) - new_labels = torch.full_like(new_input_ids, -100) - for idx in range(new_input_ids.shape[0]): - length = int(prompt_lengths[idx].item()) - new_labels[idx, length:] = new_input_ids[idx, length:] - if self.processing_class.pad_token_id is not None: - new_labels[new_input_ids == self.processing_class.pad_token_id] = -100 + new_attention_mask, new_labels = self._build_sequence_batch( + new_input_ids, prompt_lengths, self.processing_class.pad_token_id + ) completion_texts = self.processing_class.batch_decode( completion_ids_for_text, @@ -2143,17 +2156,7 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token ) new_input_ids = generated_tokens - new_attention_mask = torch.ones_like(new_input_ids) - if pad_token_id is not None: - new_attention_mask[new_input_ids == pad_token_id] = 0 - - new_labels = torch.full_like(new_input_ids, -100) - for idx in range(batch_size): - length = int(prompt_lengths[idx].item()) - new_labels[idx, length:] = new_input_ids[idx, length:] - - if pad_token_id is not None: - new_labels[new_input_ids == pad_token_id] = -100 + new_attention_mask, new_labels = self._build_sequence_batch(new_input_ids, prompt_lengths, pad_token_id) prompt_texts = [] completion_texts = [] @@ -2182,164 +2185,6 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token return new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts - @profiling_decorator - def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_id=None): - device = self.accelerator.device - - # Decode prompts for vLLM (without special tokens - vLLM expects clean text) - prompts_text_for_vllm = self.processing_class.batch_decode( - inputs["prompts"], - skip_special_tokens=True, - # clean_up_tokenization_spaces=False # Keep this commented unless specific issues arise - ) - # Remove padding token text if it appears, as vLLM expects clean prompts - if self.processing_class.pad_token: - prompts_text_for_vllm = [p.replace(self.processing_class.pad_token, "") for p in prompts_text_for_vllm] - - # Also decode prompts WITH special tokens for ULD loss computation - prompts_text_with_special = self.processing_class.batch_decode( - inputs["prompts"], - skip_special_tokens=False, - ) - - # system_prompt = "Please reason step by step, and put your final answer within \\boxed{}." - # target_system_prompt = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant." - # prompts_text = [p.replace(target_system_prompt, system_prompt) for p in prompts_text] - # Add system prompt to prompts - - max_completion_length = generation_config.max_new_tokens - temperature = generation_config.temperature - # vLLM uses top_k=-1 for no top_k, transformers uses 0 or None. - top_k = generation_config.top_k if generation_config.top_k and generation_config.top_k > 0 else -1 - # top_p, repetition_penalty, min_p are not directly in generation_config, get from trainer args - top_p = self.args.top_p if hasattr(self.args, "top_p") else 1.0 - repetition_penalty = self.args.repetition_penalty if hasattr(self.args, "repetition_penalty") else 1.0 - min_p = self.args.min_p if hasattr(self.args, "min_p") else 0.0 - - if self.vllm_mode == "server": - all_prompts_text = gather_object(prompts_text_for_vllm) - if self.accelerator.is_main_process: - completion_ids = self.vllm_client.generate( - prompts=all_prompts_text, - n=1, # In GKD, we generate 1 completion per prompt from student - repetition_penalty=repetition_penalty, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - max_tokens=max_completion_length, - structured_outputs_regex=self.vllm_structured_outputs_regex, - )["completion_ids"] - else: - completion_ids = [None] * len(all_prompts_text) - completion_ids = broadcast_object_list(completion_ids, from_process=0) - process_slice = slice( - self.accelerator.process_index * len(prompts_text_for_vllm), - (self.accelerator.process_index + 1) * len(prompts_text_for_vllm), - ) - completion_ids = completion_ids[process_slice] - elif self.vllm_mode == "colocate": - if self.vllm_structured_outputs_regex is not None: - structured_outputs = StructuredOutputsParams( - backend="outlines", regex=self.vllm_structured_outputs_regex - ) - else: - structured_outputs = None - sampling_params = SamplingParams( - n=1, - repetition_penalty=repetition_penalty, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - max_tokens=max_completion_length, - structured_outputs=structured_outputs, - ) - - if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1: - # Gather prompts from all ranks in the TP group and flatten. - # Each rank starts with its own prompts; after gathering, all ranks see the full group set. - orig_size = len(prompts_text_for_vllm) - gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts_text_for_vllm, group=self.vllm_tp_group) - all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - else: - all_prompts_text = prompts_text_for_vllm - - all_outputs = self.vllm_engine.generate(all_prompts_text, sampling_params=sampling_params, use_tqdm=False) - completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - - if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1: - # Slice completions for this rank within its TP group. - # Each rank generates all outputs — we keep only our share. - local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) - tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - completion_ids = completion_ids[tp_slice] - - if self.vllm_enable_sleep_mode: - self.vllm_engine.sleep(level=2) - else: - raise ValueError(f"Unknown vllm_mode: {self.vllm_mode}") - - prompt_max_length = max(1, self.args.max_length - max_completion_length) if self.args.max_length else None - prompt_tokenized = self.processing_class( - prompts_text_for_vllm, - return_tensors="pt", - padding="longest", - truncation=True if prompt_max_length else False, - max_length=prompt_max_length, - add_special_tokens=False, - ).to(device) - prompt_ids = prompt_tokenized.input_ids - - completion_ids_tensors = [torch.tensor(ids, device=device) for ids in completion_ids] - padded_completion_ids_list = [] - for completion_tensor in completion_ids_tensors: - if len(completion_tensor) > max_completion_length: - padded_completion_ids_list.append(completion_tensor[:max_completion_length]) - elif len(completion_tensor) < max_completion_length: - padding_needed = max_completion_length - len(completion_tensor) - padded_tensor = torch.cat( - [ - completion_tensor, - torch.full((padding_needed,), pad_token_id, device=device, dtype=completion_tensor.dtype), - ] - ) - padded_completion_ids_list.append(padded_tensor) - else: - padded_completion_ids_list.append(completion_tensor) - - padded_completion_ids = torch.stack(padded_completion_ids_list) - - # Ensure prompt_ids and padded_completion_ids are 2D - if prompt_ids.ndim == 1: - prompt_ids = prompt_ids.unsqueeze(0) - if padded_completion_ids.ndim == 1: - padded_completion_ids = padded_completion_ids.unsqueeze(0) - - new_input_ids = torch.cat([prompt_ids, padded_completion_ids], dim=1) - - new_attention_mask = torch.ones_like(new_input_ids, device=device) - new_labels = new_input_ids.clone() - - if pad_token_id is not None: - new_labels[new_labels == pad_token_id] = -100 - new_attention_mask[new_input_ids == pad_token_id] = 0 - - # Mask prompt tokens in labels - prompt_lengths = prompt_ids.shape[1] - new_labels[:, :prompt_lengths] = -100 - - # IMPORTANT: Preserve original text for cross-tokenizer ULD loss - # Use prompts_text_with_special (with special tokens) for ULD loss computation - # Extract completion texts from the generated completion IDs - completion_texts = [] - for comp_ids in completion_ids: - completion_text = self.processing_class.decode(comp_ids, skip_special_tokens=False) - completion_texts.append(completion_text) - - return new_input_ids, new_attention_mask, new_labels, prompts_text_with_special, completion_texts - def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None): """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with student vLLM.""" if visited is None: From da7ef5034e52cb72228f30616c166da5ec95daa9 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 3 Mar 2026 11:02:42 +0100 Subject: [PATCH 17/32] Handle student versioning params --- trl/experimental/gold/gold.py | 15 ++++++++++++++- trl/experimental/gold/gold_config.py | 6 +++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/trl/experimental/gold/gold.py b/trl/experimental/gold/gold.py index 8787a2e40c5..bd6988c6046 100644 --- a/trl/experimental/gold/gold.py +++ b/trl/experimental/gold/gold.py @@ -79,9 +79,22 @@ ################ # Model & Tokenizer ################ + if training_args.student_model_revision is None: + training_args.student_model_revision = model_args.model_revision + elif ( + model_args.model_revision is not None + and training_args.student_model_revision != model_args.model_revision + ): + raise ValueError( + "Conflicting revisions for student model: " + f"student_model_revision={training_args.student_model_revision!r} and " + f"model_revision={model_args.model_revision!r}. " + "Set only one revision, or set both to the same value." + ) + quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=model_args.model_revision, + revision=training_args.student_model_revision, trust_remote_code=model_args.trust_remote_code, attn_implementation=model_args.attn_implementation, torch_dtype=model_args.dtype, diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index 8fbb42359d8..609d58ba853 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -146,10 +146,10 @@ class GOLDConfig(SFTConfig): default=128, metadata={"help": "Maximum number of tokens to generate per completion."}, ) - student_model_revision: str = field( - default="main", + student_model_revision: str | None = field( + default=None, metadata={ - "help": "Revision of the student model to use. If not specified, the default revision of the model will be used." + "help": "Revision of the student model to use. If not specified, `model_revision` is used." }, ) teacher_model_name_or_path: str | None = field( From e24e68178cdf9d31b7cbf4374c9a46476e1bb561 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 3 Mar 2026 11:03:08 +0100 Subject: [PATCH 18/32] Add warning when dropping incomplete batches --- trl/experimental/gold/gold_trainer.py | 36 +++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 989dfdb83c1..8df662c5c85 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -236,14 +236,42 @@ def build_teacher_inputs_from_texts( class _RepeatEachBatchDataLoader: """Repeats each dataloader batch `repeat_count` times without re-sampling.""" - def __init__(self, dataloader, repeat_count: int): + @staticmethod + def _infer_batch_size(batch: dict[str, torch.Tensor | Any]) -> int | None: + for value in batch.values(): + if value is None: + continue + if isinstance(value, torch.Tensor): + if value.ndim == 0: + continue + return int(value.shape[0]) + if isinstance(value, (list, tuple)): + return len(value) + return None + + def __init__(self, dataloader, repeat_count: int, expected_batch_size: int | None = None): if repeat_count < 1: raise ValueError(f"repeat_count must be at least 1, got {repeat_count}.") self.dataloader = dataloader self.repeat_count = repeat_count + self.expected_batch_size = expected_batch_size + self._dropped_partial_batch = False def __iter__(self): + self._dropped_partial_batch = False for batch in self.dataloader: + if self.expected_batch_size is not None: + batch_size = self._infer_batch_size(batch) + if batch_size is not None and batch_size != self.expected_batch_size: + if not self._dropped_partial_batch: + warnings.warn( + "Dropping last batch due to unexpected batch size: " + f"got {batch_size}, expected {self.expected_batch_size}.", + UserWarning, + stacklevel=2, + ) + self._dropped_partial_batch = True + break for _ in range(self.repeat_count): yield batch @@ -1138,7 +1166,11 @@ def get_train_dataloader(self): dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor base_dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) - return _RepeatEachBatchDataLoader(base_dataloader, repeat_count=self.args.gradient_accumulation_steps) + return _RepeatEachBatchDataLoader( + base_dataloader, + repeat_count=self.args.gradient_accumulation_steps, + expected_batch_size=self._train_batch_size * self.args.gradient_accumulation_steps, + ) @profiling_decorator def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: From 8d31b7aba124bd9bc84ffed4b29b7b5b9c009050 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 3 Mar 2026 15:11:46 +0100 Subject: [PATCH 19/32] Add clarifying note in docs --- docs/source/gold_trainer.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/source/gold_trainer.md b/docs/source/gold_trainer.md index 98af2eaa240..e72a254545a 100644 --- a/docs/source/gold_trainer.md +++ b/docs/source/gold_trainer.md @@ -31,6 +31,8 @@ messages). Important configuration flags on [`GOLDConfig`] include: sampling ratio. * `num_generations`, `generation_batch_size` – control buffered rollout generation across gradient accumulation windows. `generation_batch_size` is the number of unique prompts per worker per optimizer step. +* `student_model_revision` and `model_revision` – if `student_model_revision` is unset, GOLD uses `model_revision`. + If both are set and differ, GOLD raises an error to avoid loading different revisions for training vs generation. A minimal end-to-end example: @@ -81,7 +83,7 @@ train_dataset = load_dataset( training_args = GOLDConfig( output_dir="gold-model", per_device_train_batch_size=1, - teacher_model=teacher_name, + teacher_model_name_or_path=teacher_name, teacher_tokenizer_name_or_path=teacher_name, use_uld_loss=True, uld_use_hybrid_loss=True, @@ -97,6 +99,11 @@ trainer = GOLDTrainer( trainer.train() ``` +> [!NOTE] +> GOLD buffers one full optimizer-window generation batch (`per_device_train_batch_size * gradient_accumulation_steps`) +> and reuses it across accumulation steps. If the final batch is undersized, GOLD warns and drops that last batch +> (`Dropping last batch due to unexpected batch size`). Set `dataloader_drop_last=True` to avoid this warning. + ### Expected dataset type GOLD requires a [conversational](dataset_formats#conversational) [language modeling](dataset_formats#language_modeling) dataset, e.g.: From 1ef205b74aad4a4b32e07fb6af04aa2abd2bce10 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 3 Mar 2026 14:17:02 +0000 Subject: [PATCH 20/32] Remove SLURM script used for testing --- scripts/slurm/gold.slurm | 157 --------------------------------------- 1 file changed, 157 deletions(-) delete mode 100755 scripts/slurm/gold.slurm diff --git a/scripts/slurm/gold.slurm b/scripts/slurm/gold.slurm deleted file mode 100755 index 22713c52e9d..00000000000 --- a/scripts/slurm/gold.slurm +++ /dev/null @@ -1,157 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=trl-gold-buffer-test -#SBATCH --ntasks-per-node=1 -#SBATCH --exclusive -#SBATCH --gres=gpu:8 -#SBATCH --partition=hopper-prod -#SBATCH --output=logs/%x-%j.out -#SBATCH --error=logs/%x-%j.err -#SBATCH --requeue -#SBATCH --time=0-12:00:00 - -set -euo pipefail - -if [[ "$*" == *"--help"* ]]; then - cat <<'EOF' -Usage: sbatch scripts/slurm/gold.slurm [options] - -Required: - --config PATH YAML config passed to gold.py - -Optional: - --accelerator NAME|PATH Accelerate config name (default: zero3) or explicit YAML path - --args "ARGS" Extra args appended to gold.py - -Examples: - sbatch scripts/slurm/gold.slurm \ - --config /path/to/config.yaml \ - --accelerator zero3 \ - --args "--bf16 --logging_steps 1" -EOF - exit 0 -fi - -module load cuda/12.9 || true -source ~/.bashrc || true - -START_TIME=$(date +%s) -echo "START TIME: $(date)" - -CONFIG_FILE="" -ACCELERATOR="zero3" -GPUS_PER_NODE=8 -OPTIONAL_ARGS="" - -while [[ $# -gt 0 ]]; do - case "$1" in - --config) - CONFIG_FILE="$2" - shift 2 - ;; - --accelerator) - ACCELERATOR="$2" - shift 2 - ;; - --args) - OPTIONAL_ARGS="$2" - shift 2 - ;; - *) - echo "Unknown option: $1" - echo "Run with --help for usage." - exit 1 - ;; - esac -done - -source "trl-internal/bin/activate" - -if [[ -z "$CONFIG_FILE" ]]; then - echo "Error: --config is required." - exit 1 -fi - -if [[ ! -f "$CONFIG_FILE" ]]; then - echo "Error: config file not found: $CONFIG_FILE" - exit 1 -fi - -if ! command -v accelerate >/dev/null 2>&1; then - echo "Error: accelerate is not available in PATH." - exit 1 -fi - -if ! command -v trl >/dev/null 2>&1; then - echo "Error: trl CLI is not available in PATH." - exit 1 -fi - -if [[ -f "$ACCELERATOR" ]]; then - ACCEL_CONFIG="$ACCELERATOR" -elif [[ -f "trl/accelerate_configs/${ACCELERATOR}.yaml" ]]; then - ACCEL_CONFIG="trl/accelerate_configs/${ACCELERATOR}.yaml" -elif [[ -f "examples/accelerate_configs/${ACCELERATOR}.yaml" ]]; then - ACCEL_CONFIG="examples/accelerate_configs/${ACCELERATOR}.yaml" -else - echo "Error: could not resolve accelerate config from '$ACCELERATOR'." - exit 1 -fi - -GRAD_ACC_STEPS=$(grep -E '^\s*gradient_accumulation_steps:' "$CONFIG_FILE" | head -n 1 | awk '{print $2}') -GRAD_ACC_STEPS=${GRAD_ACC_STEPS:-1} - -if [[ "$OPTIONAL_ARGS" =~ --gradient_accumulation_steps=([0-9]+) ]]; then - GRAD_ACC_STEPS="${BASH_REMATCH[1]}" -fi -if [[ "$OPTIONAL_ARGS" =~ --gradient_accumulation_steps[[:space:]]+([0-9]+) ]]; then - GRAD_ACC_STEPS="${BASH_REMATCH[1]}" -fi - -if [[ -z "${SLURM_JOB_NODELIST:-}" ]]; then - echo "Error: this launcher must run inside a SLURM allocation." - exit 1 -fi - -NUM_NODES=${SLURM_NNODES:-1} -WORLD_SIZE=$((NUM_NODES * GPUS_PER_NODE)) -NODELIST=($(scontrol show hostnames "$SLURM_JOB_NODELIST")) -MASTER_ADDR=${NODELIST[0]} -MASTER_PORT=${MASTER_PORT:-6000} -TRAIN_NODES=("${NODELIST[@]}") - -NODELIST_CSV=$(IFS=,; echo "${TRAIN_NODES[*]}") - -export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 - -SCRIPT_PATH="trl/experimental/gold/gold.py" -LAUNCH_CMD="ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \ - --config_file $ACCEL_CONFIG \ - --gradient_accumulation_steps $GRAD_ACC_STEPS \ - --num_machines $NUM_NODES \ - --num_processes $WORLD_SIZE \ - --main_process_ip $MASTER_ADDR \ - --main_process_port $MASTER_PORT \ - --machine_rank \$SLURM_PROCID \ - --rdzv_backend=c10d \ - --max_restarts 1 \ - --tee 3 \ - $SCRIPT_PATH --config $CONFIG_FILE $OPTIONAL_ARGS" - -SRUN_ARGS=" \ - --wait=60 \ - --kill-on-bad-exit=1 \ - --nodes=$NUM_NODES \ - --ntasks=$NUM_NODES \ - --nodelist=$NODELIST_CSV" - -set -x -clear -srun $SRUN_ARGS bash -lc "$LAUNCH_CMD" 2>&1 - -END_TIME=$(date +%s) -echo "END TIME: $(date)" -ELAPSED_SECONDS=$((END_TIME - START_TIME)) -HOURS=$((ELAPSED_SECONDS / 3600)) -MINUTES=$(((ELAPSED_SECONDS % 3600) / 60)) -SECONDS=$((ELAPSED_SECONDS % 60)) -echo "TOTAL JOB TIME: ${HOURS}h ${MINUTES}m ${SECONDS}s (${ELAPSED_SECONDS} seconds)" From 506afc16c1d77f13d4bbfd264ca0de4defd95d08 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 3 Mar 2026 15:16:00 +0000 Subject: [PATCH 21/32] Remove reference to wandb --- trl/experimental/gold/gold.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/trl/experimental/gold/gold.py b/trl/experimental/gold/gold.py index bd6988c6046..672ae82560f 100644 --- a/trl/experimental/gold/gold.py +++ b/trl/experimental/gold/gold.py @@ -143,20 +143,6 @@ elif "dev" in dataset: eval_dataset = dataset["dev"] - if isinstance(training_args.report_to, str): - report_to = {training_args.report_to} - else: - report_to = set(training_args.report_to or []) - if "wandb" in report_to: - if training_args.wandb_project is not None: - os.environ.setdefault("WANDB_PROJECT", training_args.wandb_project) - if training_args.wandb_entity is not None: - os.environ.setdefault("WANDB_ENTITY", training_args.wandb_entity) - if training_args.wandb_run_group is not None: - os.environ.setdefault("WANDB_RUN_GROUP", training_args.wandb_run_group) - if training_args.run_name is not None: - os.environ.setdefault("WANDB_NAME", training_args.run_name) - trainer = GOLDTrainer( model=model_args.model_name_or_path, teacher_model=training_args.teacher_model_name_or_path, From 98ec20cfde20a0396ce445d3e9d2172bc2765539 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Thu, 5 Mar 2026 18:57:32 +0100 Subject: [PATCH 22/32] Remove `_RepeatEachBatchDataLoader` to simplify codebase --- trl/experimental/gold/gold_trainer.py | 68 +++------------------------ 1 file changed, 6 insertions(+), 62 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 8df662c5c85..f6c8c171848 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -233,58 +233,6 @@ def build_teacher_inputs_from_texts( return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length -class _RepeatEachBatchDataLoader: - """Repeats each dataloader batch `repeat_count` times without re-sampling.""" - - @staticmethod - def _infer_batch_size(batch: dict[str, torch.Tensor | Any]) -> int | None: - for value in batch.values(): - if value is None: - continue - if isinstance(value, torch.Tensor): - if value.ndim == 0: - continue - return int(value.shape[0]) - if isinstance(value, (list, tuple)): - return len(value) - return None - - def __init__(self, dataloader, repeat_count: int, expected_batch_size: int | None = None): - if repeat_count < 1: - raise ValueError(f"repeat_count must be at least 1, got {repeat_count}.") - self.dataloader = dataloader - self.repeat_count = repeat_count - self.expected_batch_size = expected_batch_size - self._dropped_partial_batch = False - - def __iter__(self): - self._dropped_partial_batch = False - for batch in self.dataloader: - if self.expected_batch_size is not None: - batch_size = self._infer_batch_size(batch) - if batch_size is not None and batch_size != self.expected_batch_size: - if not self._dropped_partial_batch: - warnings.warn( - "Dropping last batch due to unexpected batch size: " - f"got {batch_size}, expected {self.expected_batch_size}.", - UserWarning, - stacklevel=2, - ) - self._dropped_partial_batch = True - break - for _ in range(self.repeat_count): - yield batch - - def __len__(self): - return len(self.dataloader) * self.repeat_count - - def set_epoch(self, epoch: int): - if hasattr(self.dataloader, "set_epoch"): - self.dataloader.set_epoch(epoch) - - def __getattr__(self, attr): - return getattr(self.dataloader, attr) - class ULDLoss(nn.Module): """ @@ -1123,7 +1071,7 @@ def _get_train_sampler(self, dataset=None): data_source=dataset, mini_repeat_count=self.num_generations, batch_size=self.args.generation_batch_size * self.accelerator.num_processes, - repeat_count=1, + repeat_count=self.args.gradient_accumulation_steps, shuffle=True, seed=self.args.seed, ) @@ -1132,9 +1080,10 @@ def get_train_dataloader(self): """ Override Trainer.get_train_dataloader to load one generation batch per optimizer window. - The base dataloader yields local batches of size - `per_device_train_batch_size * gradient_accumulation_steps`, then repeats each batch - `gradient_accumulation_steps` times so Trainer can run accumulation mini-steps without re-sampling prompts. + The dataloader yields local batches of size `per_device_train_batch_size * gradient_accumulation_steps`. + The `RepeatSampler` (with `repeat_count=gradient_accumulation_steps`) ensures each generation batch is + sampled `gradient_accumulation_steps` times so Trainer's loop iterates the correct number of times. + Only the first batch in each window triggers `_fill_buffer`; the rest are ignored by `_prepare_inputs`. """ if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") @@ -1165,12 +1114,7 @@ def get_train_dataloader(self): if self.args.dataloader_num_workers > 0: dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - base_dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) - return _RepeatEachBatchDataLoader( - base_dataloader, - repeat_count=self.args.gradient_accumulation_steps, - expected_batch_size=self._train_batch_size * self.args.gradient_accumulation_steps, - ) + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) @profiling_decorator def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: From da57e470f765788555d34ecc59986ecb03a93e18 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Thu, 12 Mar 2026 13:35:45 +0000 Subject: [PATCH 23/32] Remove support for student_model_revision arg --- trl/experimental/gold/gold.py | 16 +--------------- trl/experimental/gold/gold_config.py | 6 ------ trl/experimental/gold/gold_trainer.py | 5 +---- 3 files changed, 2 insertions(+), 25 deletions(-) diff --git a/trl/experimental/gold/gold.py b/trl/experimental/gold/gold.py index 672ae82560f..2cfa64a329b 100644 --- a/trl/experimental/gold/gold.py +++ b/trl/experimental/gold/gold.py @@ -51,7 +51,6 @@ """ import logging -import os from datasets import load_dataset from transformers import AutoTokenizer, GenerationConfig @@ -79,22 +78,9 @@ ################ # Model & Tokenizer ################ - if training_args.student_model_revision is None: - training_args.student_model_revision = model_args.model_revision - elif ( - model_args.model_revision is not None - and training_args.student_model_revision != model_args.model_revision - ): - raise ValueError( - "Conflicting revisions for student model: " - f"student_model_revision={training_args.student_model_revision!r} and " - f"model_revision={model_args.model_revision!r}. " - "Set only one revision, or set both to the same value." - ) - quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=training_args.student_model_revision, + revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, attn_implementation=model_args.attn_implementation, torch_dtype=model_args.dtype, diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index 609d58ba853..676ab1e2e85 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -146,12 +146,6 @@ class GOLDConfig(SFTConfig): default=128, metadata={"help": "Maximum number of tokens to generate per completion."}, ) - student_model_revision: str | None = field( - default=None, - metadata={ - "help": "Revision of the student model to use. If not specified, `model_revision` is used." - }, - ) teacher_model_name_or_path: str | None = field( default=None, metadata={ diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index f6c8c171848..02d5db8908d 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -807,10 +807,7 @@ def __init__( peft_config: Optional["PeftConfig"] = None, ): self.model_name_or_path = model if isinstance(model, str) else model.config._name_or_path - self.model_revision = getattr(args, "student_model_revision", None) - if isinstance(model, str) and self.model_revision is not None: - args.model_init_kwargs = args.model_init_kwargs or {} - args.model_init_kwargs.setdefault("revision", self.model_revision) + self.model_revision = (args.model_init_kwargs or {}).get("revision") # Respect a user-provided data_collator; otherwise, provide a ChatML collator that if data_collator is None: From c3a8d73f45a28f347bac8dcc17ec4b35dc85de2d Mon Sep 17 00:00:00 2001 From: cmpatino Date: Thu, 12 Mar 2026 14:43:16 +0100 Subject: [PATCH 24/32] Fix prompt length calculation --- trl/experimental/gold/gold_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 02d5db8908d..6366ca13b7c 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -1538,7 +1538,7 @@ def _process_completions_to_buffer( completion_ids_padded = torch.stack(padded_completion_ids_list) new_input_ids = torch.cat([prompt_ids, completion_ids_padded], dim=1) - prompt_lengths = (prompt_ids != pad_token_id).sum(dim=1) + prompt_lengths = torch.full((prompt_ids.shape[0],), prompt_ids.shape[1], device=device) new_attention_mask, new_labels = self._build_sequence_batch( new_input_ids, prompt_lengths, self.processing_class.pad_token_id ) From d1857160f80f69ad5047de980671fa2f02114358 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Thu, 12 Mar 2026 14:52:48 +0100 Subject: [PATCH 25/32] Fix logic of padding tokens and prompt lengths --- trl/experimental/gold/gold_trainer.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 6366ca13b7c..ded6765748e 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -1540,7 +1540,7 @@ def _process_completions_to_buffer( new_input_ids = torch.cat([prompt_ids, completion_ids_padded], dim=1) prompt_lengths = torch.full((prompt_ids.shape[0],), prompt_ids.shape[1], device=device) new_attention_mask, new_labels = self._build_sequence_batch( - new_input_ids, prompt_lengths, self.processing_class.pad_token_id + new_input_ids, prompt_lengths, pad_token_id ) completion_texts = self.processing_class.batch_decode( @@ -2115,18 +2115,14 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token prompt_mask = inputs.get("prompt_attention_mask") pad_token_id = pad_token_id if pad_token_id is not None else self.processing_class.pad_token_id - if prompt_mask is not None: - prompt_lengths = prompt_mask.sum(dim=1).to(torch.long) - else: - if pad_token_id is not None: - prompt_lengths = (inputs["prompts"] != pad_token_id).sum(dim=1).to(torch.long) - else: - prompt_lengths = torch.full( - (batch_size,), - inputs["prompts"].shape[1], - dtype=torch.long, - device=device, - ) + # Use the full padded prompt width for label masking, since model.generate() returns + # sequences where completions start after the full prompt tensor (including padding). + prompt_lengths = torch.full( + (batch_size,), + inputs["prompts"].shape[1], + dtype=torch.long, + device=device, + ) new_input_ids = generated_tokens new_attention_mask, new_labels = self._build_sequence_batch(new_input_ids, prompt_lengths, pad_token_id) From 30a0fd564d9b341061491b64aa32608df2f4e9eb Mon Sep 17 00:00:00 2001 From: cmpatino Date: Thu, 12 Mar 2026 15:07:37 +0100 Subject: [PATCH 26/32] Add `teacher_model_revision` arg --- trl/experimental/gold/gold.py | 1 + trl/experimental/gold/gold_config.py | 10 ++++++++++ trl/experimental/gold/gold_trainer.py | 2 ++ 3 files changed, 13 insertions(+) diff --git a/trl/experimental/gold/gold.py b/trl/experimental/gold/gold.py index 2cfa64a329b..399344eab20 100644 --- a/trl/experimental/gold/gold.py +++ b/trl/experimental/gold/gold.py @@ -93,6 +93,7 @@ if training_args.teacher_tokenizer_name_or_path is None and training_args.use_uld_loss: training_args.teacher_tokenizer_name_or_path = training_args.teacher_model_name_or_path teacher_model_kwargs = dict( + revision=training_args.teacher_model_revision, trust_remote_code=model_args.trust_remote_code, attn_implementation=model_args.attn_implementation, torch_dtype=model_args.dtype, diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index 676ab1e2e85..38b15be66be 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -43,6 +43,9 @@ class GOLDConfig(SFTConfig): teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`): Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being trained. + teacher_model_revision (`str` or `None`, *optional*, defaults to `None`): + Model revision of the teacher model (e.g., branch name, tag, or commit hash). If `None`, the default + revision is used. teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`): Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model from a string. @@ -153,6 +156,13 @@ class GOLDConfig(SFTConfig): "model being trained." }, ) + teacher_model_revision: str | None = field( + default=None, + metadata={ + "help": "Model revision of the teacher model (e.g., branch name, tag, or commit hash). If `None`, the " + "default revision is used." + }, + ) teacher_model_init_kwargs: dict[str, Any] | None = field( default=None, metadata={ diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index ded6765748e..e2496de3281 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -850,6 +850,8 @@ def __init__( if isinstance(teacher_model, str): init_kwargs = dict(teacher_model_init_kwargs) + if args.teacher_model_revision is not None: + init_kwargs.setdefault("revision", args.teacher_model_revision) if "torch_dtype" in init_kwargs and "dtype" not in init_kwargs: init_kwargs["dtype"] = init_kwargs.pop("torch_dtype") teacher_model = create_model_from_path(teacher_model, **init_kwargs) From 58b9f741764c736b4dc852635ca7f741c1c7a63d Mon Sep 17 00:00:00 2001 From: cmpatino Date: Thu, 12 Mar 2026 16:12:10 +0100 Subject: [PATCH 27/32] Avoid creating padding gaps --- tests/experimental/test_gold_trainer.py | 82 +++++++++++++++++++++++++ trl/experimental/gold/gold_trainer.py | 1 + 2 files changed, 83 insertions(+) diff --git a/tests/experimental/test_gold_trainer.py b/tests/experimental/test_gold_trainer.py index f48f3464654..a6940e32b1e 100644 --- a/tests/experimental/test_gold_trainer.py +++ b/tests/experimental/test_gold_trainer.py @@ -289,6 +289,88 @@ def pad_labels(labels, target_length): return labels + [-100] * (target_length - len(labels)) +def test_process_completions_to_buffer_left_pads_prompt_retokenization(): + class DummyBatch: + def __init__(self, input_ids): + self.input_ids = input_ids + + def to(self, device): + self.input_ids = self.input_ids.to(device) + return self + + class RecordingTokenizer: + pad_token_id = 0 + pad_token = "" + + def __init__(self): + self.padding_side = "right" + self.calls = [] + self._prompt_ids = { + "short": [11], + "longer": [21, 22], + } + + def __call__( + self, + texts, + return_tensors, + padding, + truncation, + max_length, + add_special_tokens, + padding_side=None, + ): + assert return_tensors == "pt" + assert padding == "longest" + assert not truncation + assert max_length is None + assert not add_special_tokens + self.calls.append(padding_side) + + side = padding_side or self.padding_side + encoded = [torch.tensor(self._prompt_ids[text], dtype=torch.long) for text in texts] + max_len = max(len(ids) for ids in encoded) + + padded = [] + for ids in encoded: + pad_width = max_len - len(ids) + if pad_width: + pad = torch.full((pad_width,), self.pad_token_id, dtype=torch.long) + ids = torch.cat([pad, ids]) if side == "left" else torch.cat([ids, pad]) + padded.append(ids) + + return DummyBatch(torch.stack(padded)) + + def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=False): + del skip_special_tokens, clean_up_tokenization_spaces + return [" ".join(str(token) for token in sequence) for sequence in sequences] + + trainer = GOLDTrainer.__new__(GOLDTrainer) + trainer.accelerator = SimpleNamespace(device=torch.device("cpu")) + trainer.processing_class = RecordingTokenizer() + trainer.args = SimpleNamespace(max_length=None) + trainer._buffered_inputs = [None] + trainer._buffered_text_logs = [None] + + GOLDTrainer._process_completions_to_buffer( + trainer, + slices=[{"slice": "original"}], + on_policy_indices=[0], + local_slice_indices=[0, 0], + completion_ids=[[31], [41]], + prompts_text=["short", "longer"], + prompts_text_with_special=["short", "longer"], + max_completion_length=1, + ) + + buffered_inputs = trainer._buffered_inputs[0] + assert trainer.processing_class.calls == ["left"] + assert trainer.processing_class.padding_side == "right" + assert torch.equal(buffered_inputs["input_ids"], torch.tensor([[0, 11, 31], [21, 22, 41]], dtype=torch.long)) + assert torch.equal(buffered_inputs["attention_mask"], torch.tensor([[0, 1, 1], [1, 1, 1]], dtype=torch.long)) + assert torch.equal(buffered_inputs["labels"], torch.tensor([[-100, -100, 31], [-100, -100, 41]])) + + def test_alignment_groups_cover_all_tokens(llama_tokenizer, qwen_tokenizer): config = build_config() loss = ULDLoss(config, student_tokenizer=llama_tokenizer, teacher_tokenizer=qwen_tokenizer) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index e2496de3281..b6a3ad535b2 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -1504,6 +1504,7 @@ def _process_completions_to_buffer( prompt_txts, return_tensors="pt", padding="longest", + padding_side="left", truncation=True if prompt_max_length else False, max_length=prompt_max_length, add_special_tokens=False, From ea72770d6d941b7e3c4560cdf569984082965ead Mon Sep 17 00:00:00 2001 From: cmpatino Date: Thu, 12 Mar 2026 17:39:13 +0100 Subject: [PATCH 28/32] Fix prompt completion calculation for transformers --- trl/experimental/gold/gold_trainer.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index b6a3ad535b2..a243a6513f3 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -2118,14 +2118,18 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token prompt_mask = inputs.get("prompt_attention_mask") pad_token_id = pad_token_id if pad_token_id is not None else self.processing_class.pad_token_id - # Use the full padded prompt width for label masking, since model.generate() returns - # sequences where completions start after the full prompt tensor (including padding). - prompt_lengths = torch.full( - (batch_size,), - inputs["prompts"].shape[1], - dtype=torch.long, - device=device, - ) + if self.use_transformers_paged: + # generate_batch() returns completion-only tokens, so the entire tensor is completion. + prompt_lengths = torch.zeros(batch_size, dtype=torch.long, device=device) + else: + # model.generate() returns full sequences (prompt + completion), so completions start + # after the full padded prompt width. + prompt_lengths = torch.full( + (batch_size,), + inputs["prompts"].shape[1], + dtype=torch.long, + device=device, + ) new_input_ids = generated_tokens new_attention_mask, new_labels = self._build_sequence_batch(new_input_ids, prompt_lengths, pad_token_id) From 6fea90b12c1811dfe1afcbbf82552da7ad7a982b Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 17 Mar 2026 11:02:35 +0000 Subject: [PATCH 29/32] Lint files with precommit --- trl/experimental/gold/gold_config.py | 3 +-- trl/experimental/gold/gold_trainer.py | 17 ++++++++--------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index f479649cae9..d550abd5b57 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -405,8 +405,7 @@ def __post_init__(self): self.generation_batch_size = local_sequence_batch_size // self.num_generations if self.generation_batch_size < 1: raise ValueError( - "generation_batch_size must be at least 1. " - f"Got generation_batch_size={self.generation_batch_size}." + f"generation_batch_size must be at least 1. Got generation_batch_size={self.generation_batch_size}." ) if self.generation_batch_size * self.num_generations != local_sequence_batch_size: raise ValueError( diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index a243a6513f3..cc4f3f2553a 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -233,7 +233,6 @@ def build_teacher_inputs_from_texts( return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length - class ULDLoss(nn.Module): """ Universal Logit Distillation Loss. @@ -1149,7 +1148,9 @@ def _decode_completion_texts_from_labels(self, slice_inputs: dict[str, torch.Ten clean_up_tokenization_spaces=False, ) - def _ensure_original_text_fields(self, slice_inputs: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: + def _ensure_original_text_fields( + self, slice_inputs: dict[str, torch.Tensor | Any] + ) -> dict[str, torch.Tensor | Any]: """Populate original prompt/completion text fields when missing.""" if "original_prompt_text" in slice_inputs and "original_completion_text" in slice_inputs: return slice_inputs @@ -1255,7 +1256,9 @@ def _generate_on_policy_for_slices( max_completion_length = self.generation_config.max_new_tokens temperature = self.generation_config.temperature - top_k = self.generation_config.top_k if self.generation_config.top_k and self.generation_config.top_k > 0 else -1 + top_k = ( + self.generation_config.top_k if self.generation_config.top_k and self.generation_config.top_k > 0 else -1 + ) top_p = self.args.top_p if hasattr(self.args, "top_p") else 1.0 repetition_penalty = self.args.repetition_penalty if hasattr(self.args, "repetition_penalty") else 1.0 min_p = self.args.min_p if hasattr(self.args, "min_p") else 0.0 @@ -1439,9 +1442,7 @@ def _generate_vllm_colocate( return completion_ids - def _generate_non_vllm_for_slices( - self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int] - ): + def _generate_non_vllm_for_slices(self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int]): """Fallback generation without vLLM (uses model.generate per slice).""" with unwrap_model_for_generation( self.model, @@ -1542,9 +1543,7 @@ def _process_completions_to_buffer( new_input_ids = torch.cat([prompt_ids, completion_ids_padded], dim=1) prompt_lengths = torch.full((prompt_ids.shape[0],), prompt_ids.shape[1], device=device) - new_attention_mask, new_labels = self._build_sequence_batch( - new_input_ids, prompt_lengths, pad_token_id - ) + new_attention_mask, new_labels = self._build_sequence_batch(new_input_ids, prompt_lengths, pad_token_id) completion_texts = self.processing_class.batch_decode( completion_ids_for_text, From bfa74064061c1dbc6a5003d05b2db3469765d7cf Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 17 Mar 2026 11:03:14 +0000 Subject: [PATCH 30/32] Remove reference to `student_model_revision` --- docs/source/gold_trainer.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/gold_trainer.md b/docs/source/gold_trainer.md index e72a254545a..72685f7ae80 100644 --- a/docs/source/gold_trainer.md +++ b/docs/source/gold_trainer.md @@ -31,8 +31,7 @@ messages). Important configuration flags on [`GOLDConfig`] include: sampling ratio. * `num_generations`, `generation_batch_size` – control buffered rollout generation across gradient accumulation windows. `generation_batch_size` is the number of unique prompts per worker per optimizer step. -* `student_model_revision` and `model_revision` – if `student_model_revision` is unset, GOLD uses `model_revision`. - If both are set and differ, GOLD raises an error to avoid loading different revisions for training vs generation. +* `model_revision` – controls which student model revision GOLD loads for training and generation. A minimal end-to-end example: From d4d5ae293d55fcb2aa7abadf22b3167b7a7156ae Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 17 Mar 2026 11:04:32 +0000 Subject: [PATCH 31/32] Remove duplicated arg in config --- trl/experimental/gold/gold_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index d550abd5b57..2a0984e6892 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -379,7 +379,6 @@ class GOLDConfig(SFTConfig): hub_model_revision: str | None = field( default="main", metadata={"help": "The Hub model branch to push the model to."} ) - num_completions_to_print: int = field(default=5, metadata={"help": "Number of completions to print."}) overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."}) push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."}) From abfadc13ea8266503a6feb06a2b65c5fa453b5f9 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Tue, 17 Mar 2026 11:38:53 +0000 Subject: [PATCH 32/32] Update test to reflect full generated output from transformers --- tests/experimental/test_gold_trainer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/experimental/test_gold_trainer.py b/tests/experimental/test_gold_trainer.py index a6940e32b1e..50800c0a136 100644 --- a/tests/experimental/test_gold_trainer.py +++ b/tests/experimental/test_gold_trainer.py @@ -464,7 +464,9 @@ def test_generate_on_policy_outputs_masks_prompt(llama_tokenizer): prompt_tensor[0, pad_width:] = torch.tensor(prompt_ids, dtype=torch.long) prompt_mask = (prompt_tensor != pad_id).long() - generated_sequence = torch.tensor(prompt_ids + completion_ids, dtype=torch.long).unsqueeze(0) + # model.generate() returns full sequences including left-padding from the input + completion_tensor = torch.tensor(completion_ids, dtype=torch.long).unsqueeze(0) + generated_sequence = torch.cat([prompt_tensor, completion_tensor], dim=1) class DummyModel: def generate(self, input_ids, attention_mask, generation_config, return_dict_in_generate): @@ -488,9 +490,9 @@ def generate(self, input_ids, attention_mask, generation_config, return_dict_in_ else: assert torch.all(new_mask == 1) - prompt_len = len(prompt_ids) - assert torch.all(new_labels[0, :prompt_len] == -100) - assert torch.equal(new_labels[0, prompt_len:], torch.tensor(completion_ids, dtype=torch.long)) + padded_prompt_len = prompt_tensor.shape[1] + assert torch.all(new_labels[0, :padded_prompt_len] == -100) + assert torch.equal(new_labels[0, padded_prompt_len:], torch.tensor(completion_ids, dtype=torch.long)) assert prompt_texts[0] == llama_tokenizer.decode(prompt_ids, skip_special_tokens=False) assert completion_texts[0] == llama_tokenizer.decode(completion_ids, skip_special_tokens=False)