From c4f9a642bd775e0512c733bcd529af1ac4bd482d Mon Sep 17 00:00:00 2001 From: cmpatino Date: Sat, 28 Feb 2026 13:24:01 +0100 Subject: [PATCH 1/6] Improve clarity of buffer implementation --- docs/source/gold_trainer.md | 4 +- trl/experimental/gold/gold_config.py | 39 ++++++--------- trl/experimental/gold/gold_trainer.py | 70 ++++++++++++++++++--------- 3 files changed, 64 insertions(+), 49 deletions(-) diff --git a/docs/source/gold_trainer.md b/docs/source/gold_trainer.md index a95e7d084a9..98af2eaa240 100644 --- a/docs/source/gold_trainer.md +++ b/docs/source/gold_trainer.md @@ -29,8 +29,8 @@ messages). Important configuration flags on [`GOLDConfig`] include: matched/unmatched loss. * `beta`, `lmbda`, `seq_kd` – inherited from [`experimental.gkd.GKDConfig`], controlling the generalized JSD interpolation and on-policy sampling ratio. -* `steps_per_generation`, `num_generations`, `generation_batch_size` – control buffered rollout generation across - gradient accumulation windows, including multi-generation sampling per prompt. +* `num_generations`, `generation_batch_size` – control buffered rollout generation across gradient accumulation windows. + `generation_batch_size` is the number of unique prompts per worker per optimizer step. A minimal end-to-end example: diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index a4a4d60be96..dfef6e735e2 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -54,14 +54,11 @@ class GOLDConfig(SFTConfig): seq_kd (`bool`, *optional*, defaults to `False`): Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated output). - steps_per_generation (`int` or `None`, *optional*, defaults to `None`): - Number of optimization steps per generation. If `None`, it defaults to - `gradient_accumulation_steps`. num_generations (`int`, *optional*, defaults to `1`): Number of generations per prompt. Each prompt is repeated this many times in the generation batch. generation_batch_size (`int` or `None`, *optional*, defaults to `None`): - Number of prompts per generation batch (global, across all processes). If `None`, it is computed from - `per_device_train_batch_size * world_size * steps_per_generation`. + Number of unique prompts per worker per optimizer step. If `None`, it is computed from + `(per_device_train_batch_size * gradient_accumulation_steps) // num_generations`. use_uld_loss (`bool`, *optional*, defaults to `False`): Whether to use Universal Logit Distillation (ULD) loss instead of Generalized Jensen-Shannon Divergence loss. @@ -187,12 +184,6 @@ class GOLDConfig(SFTConfig): "FT on teacher-generated output)." }, ) - steps_per_generation: int | None = field( - default=None, - metadata={ - "help": "Number of optimization steps per generation. If `None`, it defaults to gradient_accumulation_steps." - }, - ) num_generations: int = field( default=1, metadata={ @@ -202,8 +193,8 @@ class GOLDConfig(SFTConfig): generation_batch_size: int | None = field( default=None, metadata={ - "help": "Number of prompts per generation batch (global, across all processes). " - "If None, computed from per_device_train_batch_size * num_processes * steps_per_generation." + "help": "Number of unique prompts per worker per optimizer step. " + "If None, computed from (per_device_train_batch_size * gradient_accumulation_steps) // num_generations." }, ) @@ -405,25 +396,23 @@ def __post_init__(self): f"to leave room for the prompt. Consider increasing max_length or reducing max_completion_length." ) - if self.steps_per_generation is None: - self.steps_per_generation = self.gradient_accumulation_steps - + local_sequence_batch_size = self.per_device_train_batch_size * self.gradient_accumulation_steps if self.generation_batch_size is None: - self.generation_batch_size = ( - self.per_device_train_batch_size * self.world_size * self.steps_per_generation - ) + self.generation_batch_size = local_sequence_batch_size // self.num_generations if self.num_generations < 1: raise ValueError(f"num_generations must be at least 1, got {self.num_generations}.") - if self.generation_batch_size % self.num_generations != 0: + if self.generation_batch_size < 1: raise ValueError( - f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations " - f"({self.num_generations})." + "generation_batch_size must be at least 1. " + f"Got generation_batch_size={self.generation_batch_size}." ) - if self.generation_batch_size // self.num_generations < 1: + if self.generation_batch_size * self.num_generations != local_sequence_batch_size: raise ValueError( - f"generation_batch_size ({self.generation_batch_size}) must be at least num_generations " - f"({self.num_generations}) so that each generation batch contains at least one unique prompt." + "generation_batch_size and num_generations must exactly partition the local optimizer-step batch. " + f"Expected generation_batch_size * num_generations == per_device_train_batch_size * " + f"gradient_accumulation_steps, but got {self.generation_batch_size} * {self.num_generations} != " + f"{self.per_device_train_batch_size} * {self.gradient_accumulation_steps} ({local_sequence_batch_size})." ) if self.num_generations > 1 and self.lmbda < 1.0: warnings.warn( diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 676130dd62a..268a1866dd0 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -234,6 +234,31 @@ def build_teacher_inputs_from_texts( return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length +class _RepeatEachBatchDataLoader: + """Repeats each dataloader batch `repeat_count` times without re-sampling.""" + + def __init__(self, dataloader, repeat_count: int): + if repeat_count < 1: + raise ValueError(f"repeat_count must be at least 1, got {repeat_count}.") + self.dataloader = dataloader + self.repeat_count = repeat_count + + def __iter__(self): + for batch in self.dataloader: + for _ in range(self.repeat_count): + yield batch + + def __len__(self): + return len(self.dataloader) * self.repeat_count + + def set_epoch(self, epoch: int): + if hasattr(self.dataloader, "set_epoch"): + self.dataloader.set_epoch(epoch) + + def __getattr__(self, attr): + return getattr(self.dataloader, attr) + + class ULDLoss(nn.Module): """ Universal Logit Distillation Loss. @@ -951,7 +976,7 @@ def __init__( self.num_completions_to_print = args.num_completions_to_print # maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the # final optimization step. - maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.steps_per_generation + maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.gradient_accumulation_steps self._textual_logs = { "prompt": deque(maxlen=maxlen), "completion": deque(maxlen=maxlen), @@ -1070,20 +1095,20 @@ def _get_train_sampler(self, dataset=None): return RepeatSampler( data_source=dataset, mini_repeat_count=self.num_generations, - batch_size=self.args.generation_batch_size // self.num_generations, - repeat_count=self.args.steps_per_generation, + batch_size=self.args.generation_batch_size * self.accelerator.num_processes, + repeat_count=1, shuffle=True, seed=self.args.seed, ) def get_train_dataloader(self): """ - Override Trainer.get_train_dataloader to load a generation batch covering one optimizer window. + Override Trainer.get_train_dataloader to load one generation batch per optimizer window. - Instead of returning a standard per-step batch (i.e., `per_device_batch_size`), this dataloader loads - a batch of size `per_device_batch_size * steps_per_generation`. Combined with the `RepeatSampler` - (which inflates the sampler length by `steps_per_generation`), this prevents the Trainer from - double-dividing by `gradient_accumulation_steps` when computing optimizer steps per epoch. + The base dataloader yields local batches of size + `per_device_train_batch_size * gradient_accumulation_steps`. Each base batch is then repeated + `gradient_accumulation_steps` times so Trainer still executes one mini-step per accumulation slot without + re-sampling prompts. """ if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") @@ -1096,7 +1121,7 @@ def get_train_dataloader(self): data_collator = self._get_collator_with_removed_columns(data_collator, description="training") dataloader_params = { - "batch_size": self._train_batch_size * self.args.steps_per_generation, + "batch_size": self._train_batch_size * self.args.gradient_accumulation_steps, "collate_fn": data_collator, "num_workers": self.args.dataloader_num_workers, "pin_memory": self.args.dataloader_pin_memory, @@ -1114,18 +1139,19 @@ def get_train_dataloader(self): if self.args.dataloader_num_workers > 0: dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + base_dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + return _RepeatEachBatchDataLoader(base_dataloader, repeat_count=self.args.gradient_accumulation_steps) @profiling_decorator def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: if not self.model.training: return generation_batch - spg = self.args.steps_per_generation - if self._step % spg == 0 or self._buffered_inputs is None: - self._fill_buffer(generation_batch, spg) + buffer_steps = max(1, int(self.args.gradient_accumulation_steps)) + if self._step % buffer_steps == 0 or self._buffered_inputs is None: + self._fill_buffer(generation_batch, buffer_steps) - slice_idx = self._step % spg + slice_idx = self._step % buffer_steps inputs = self._buffered_inputs[slice_idx] self._step += 1 return inputs @@ -1178,20 +1204,20 @@ def _ensure_original_text_fields(self, slice_inputs: dict[str, torch.Tensor | An return updated_slice @profiling_decorator - def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], spg: int): - slices = split_tensor_dict(generation_batch, spg) + def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_steps: int): + slices = split_tensor_dict(generation_batch, buffer_steps) if self.accelerator.is_main_process: - on_policy_flags = [random.random() <= self.lmbda for _ in range(spg)] + on_policy_flags = [random.random() <= self.lmbda for _ in range(buffer_steps)] else: - on_policy_flags = [False] * spg + on_policy_flags = [False] * buffer_steps on_policy_flags = broadcast_object_list(on_policy_flags, from_process=0) on_policy_indices = [i for i, flag in enumerate(on_policy_flags) if flag] - self._buffered_inputs = [None] * spg + self._buffered_inputs = [None] * buffer_steps self._buffered_on_policy = on_policy_flags - self._buffered_text_logs = [None] * spg + self._buffered_text_logs = [None] * buffer_steps for i, flag in enumerate(on_policy_flags): if not flag: @@ -2472,12 +2498,12 @@ def training_step( `self.lmbda`, it generates new responses using the student model, which are then used for training instead of the offline original inputs. """ - spg = self.args.steps_per_generation + buffer_steps = max(1, int(self.args.gradient_accumulation_steps)) ga = max(1, int(self.args.gradient_accumulation_steps)) loss = super().training_step(model, inputs, num_items_in_batch) - slice_idx = (self._step - 1) % spg + slice_idx = (self._step - 1) % buffer_steps on_policy = False if self._buffered_on_policy is not None and slice_idx < len(self._buffered_on_policy): From 111b85e72d8c20e3c29af0c88f087fd887786b4e Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 2 Mar 2026 13:52:43 +0000 Subject: [PATCH 2/6] Add validation for num_generations --- trl/experimental/gold/gold_config.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index dfef6e735e2..828f2f3595b 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -396,12 +396,11 @@ def __post_init__(self): f"to leave room for the prompt. Consider increasing max_length or reducing max_completion_length." ) + if self.num_generations < 1: + raise ValueError(f"num_generations must be at least 1, got {self.num_generations}.") local_sequence_batch_size = self.per_device_train_batch_size * self.gradient_accumulation_steps if self.generation_batch_size is None: self.generation_batch_size = local_sequence_batch_size // self.num_generations - - if self.num_generations < 1: - raise ValueError(f"num_generations must be at least 1, got {self.num_generations}.") if self.generation_batch_size < 1: raise ValueError( "generation_batch_size must be at least 1. " From 022af620b915f168bf4ec6791958a84e89193be8 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 2 Mar 2026 15:32:41 +0000 Subject: [PATCH 3/6] Add clarifying comment to num_generations --- trl/experimental/gold/gold_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index 828f2f3595b..c0c0b0731d5 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -187,7 +187,7 @@ class GOLDConfig(SFTConfig): num_generations: int = field( default=1, metadata={ - "help": "Number of generations per prompt. Each prompt is repeated this many times in the batch." + "help": "Number of generations per prompt. Increasing this will decrease the number of unique prompts per optimization step." }, ) generation_batch_size: int | None = field( From 33e0a82f3799cdcdd2827a10dcab9f8e0c213a27 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 2 Mar 2026 15:45:41 +0000 Subject: [PATCH 4/6] Patch issue with ZeRO-3 --- trl/experimental/gold/gold_trainer.py | 28 ++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 268a1866dd0..020989d8479 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -2030,6 +2030,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N teacher_head = unwrapped_teacher.get_output_embeddings() # liger fused jsd loss + # Note: with ZeRO-3 the lm_head weights are partitioned. The + # gathering is handled in training_step() which wraps both + # forward and backward in a GatheredParameters context. loss = self.liger_jsd_loss( student_input=student_hidden, student_weight=student_head.weight, @@ -2501,7 +2504,30 @@ def training_step( buffer_steps = max(1, int(self.args.gradient_accumulation_steps)) ga = max(1, int(self.args.gradient_accumulation_steps)) - loss = super().training_step(model, inputs, num_items_in_batch) + # With Liger + ZeRO-3, the lm_head weights are partitioned across + # ranks. The Liger fused kernel saves weight references during forward + # and reads them during backward; both must see the full (gathered) + # weight. Wrapping super().training_step() (which runs compute_loss + + # accelerator.backward) keeps the weights gathered for both passes. + _gather_ctx = nullcontext() + if self.use_liger_gkd_loss: + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3: + import deepspeed + + unwrapped = self.accelerator.unwrap_model(model) + unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) + student_head = unwrapped.get_output_embeddings() + teacher_head = unwrapped_teacher.get_output_embeddings() + params = [student_head.weight, teacher_head.weight] + if getattr(student_head, "bias", None) is not None: + params.append(student_head.bias) + if getattr(teacher_head, "bias", None) is not None: + params.append(teacher_head.bias) + _gather_ctx = deepspeed.zero.GatheredParameters(params, modifier_rank=None) + + with _gather_ctx: + loss = super().training_step(model, inputs, num_items_in_batch) slice_idx = (self._step - 1) % buffer_steps From dbb6e7059c4599aa0a53ce8c1813383f06ddb788 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 2 Mar 2026 16:55:24 +0100 Subject: [PATCH 5/6] Refactor context for ZeRO-3 + Liger --- trl/experimental/gold/gold_trainer.py | 46 +++++++++++++-------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 020989d8479..452cc7e32af 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -2490,6 +2490,27 @@ def _wake_vllm_if_needed(self): empty_cache() self.vllm_engine.wake_up(tags=["kv_cache"]) + def _get_liger_zero3_lm_head_gather_ctx(self, model: nn.Module): + if not self.use_liger_gkd_loss: + return nullcontext() + + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + if deepspeed_plugin is None or deepspeed_plugin.zero_stage != 3: + return nullcontext() + + import deepspeed + + unwrapped_student = self.accelerator.unwrap_model(model) + unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) + student_head = unwrapped_student.get_output_embeddings() + teacher_head = unwrapped_teacher.get_output_embeddings() + params = [student_head.weight, teacher_head.weight] + if student_head.bias is not None: + params.append(student_head.bias) + if teacher_head.bias is not None: + params.append(teacher_head.bias) + return deepspeed.zero.GatheredParameters(params, modifier_rank=None) + @profiling_decorator def training_step( self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None @@ -2504,29 +2525,8 @@ def training_step( buffer_steps = max(1, int(self.args.gradient_accumulation_steps)) ga = max(1, int(self.args.gradient_accumulation_steps)) - # With Liger + ZeRO-3, the lm_head weights are partitioned across - # ranks. The Liger fused kernel saves weight references during forward - # and reads them during backward; both must see the full (gathered) - # weight. Wrapping super().training_step() (which runs compute_loss + - # accelerator.backward) keeps the weights gathered for both passes. - _gather_ctx = nullcontext() - if self.use_liger_gkd_loss: - deepspeed_plugin = self.accelerator.state.deepspeed_plugin - if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3: - import deepspeed - - unwrapped = self.accelerator.unwrap_model(model) - unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) - student_head = unwrapped.get_output_embeddings() - teacher_head = unwrapped_teacher.get_output_embeddings() - params = [student_head.weight, teacher_head.weight] - if getattr(student_head, "bias", None) is not None: - params.append(student_head.bias) - if getattr(teacher_head, "bias", None) is not None: - params.append(teacher_head.bias) - _gather_ctx = deepspeed.zero.GatheredParameters(params, modifier_rank=None) - - with _gather_ctx: + # Keep lm_head gathered across forward+backward for Liger + ZeRO-3. + with self._get_liger_zero3_lm_head_gather_ctx(model): loss = super().training_step(model, inputs, num_items_in_batch) slice_idx = (self._step - 1) % buffer_steps From 9da54b3a67b5defcc7467ab7ac1aed631476b418 Mon Sep 17 00:00:00 2001 From: cmpatino Date: Mon, 2 Mar 2026 16:05:04 +0000 Subject: [PATCH 6/6] Simplify comments and code logic --- trl/experimental/gold/gold_config.py | 11 +++++------ trl/experimental/gold/gold_trainer.py | 12 +++++------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index c0c0b0731d5..8fbb42359d8 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -409,15 +409,14 @@ def __post_init__(self): if self.generation_batch_size * self.num_generations != local_sequence_batch_size: raise ValueError( "generation_batch_size and num_generations must exactly partition the local optimizer-step batch. " - f"Expected generation_batch_size * num_generations == per_device_train_batch_size * " - f"gradient_accumulation_steps, but got {self.generation_batch_size} * {self.num_generations} != " - f"{self.per_device_train_batch_size} * {self.gradient_accumulation_steps} ({local_sequence_batch_size})." + "Expected generation_batch_size * num_generations == per_device_train_batch_size * " + f"gradient_accumulation_steps, got {self.generation_batch_size} * {self.num_generations} != " + f"{self.per_device_train_batch_size} * {self.gradient_accumulation_steps}." ) if self.num_generations > 1 and self.lmbda < 1.0: warnings.warn( - f"num_generations={self.num_generations} with lmbda={self.lmbda} means off-policy batches will " - f"contain {self.num_generations} identical copies of each dataset sample. Consider setting " - f"lmbda=1.0 (fully on-policy) when using num_generations > 1.", + f"num_generations={self.num_generations} with lmbda={self.lmbda} means off-policy batches include " + f"{self.num_generations} copies of each sample; consider lmbda=1.0 when num_generations > 1.", UserWarning, stacklevel=2, ) diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index 452cc7e32af..0a83121efe6 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -1106,9 +1106,8 @@ def get_train_dataloader(self): Override Trainer.get_train_dataloader to load one generation batch per optimizer window. The base dataloader yields local batches of size - `per_device_train_batch_size * gradient_accumulation_steps`. Each base batch is then repeated - `gradient_accumulation_steps` times so Trainer still executes one mini-step per accumulation slot without - re-sampling prompts. + `per_device_train_batch_size * gradient_accumulation_steps`, then repeats each batch + `gradient_accumulation_steps` times so Trainer can run accumulation mini-steps without re-sampling prompts. """ if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") @@ -1147,7 +1146,7 @@ def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> di if not self.model.training: return generation_batch - buffer_steps = max(1, int(self.args.gradient_accumulation_steps)) + buffer_steps = self.args.gradient_accumulation_steps if self._step % buffer_steps == 0 or self._buffered_inputs is None: self._fill_buffer(generation_batch, buffer_steps) @@ -2522,8 +2521,7 @@ def training_step( `self.lmbda`, it generates new responses using the student model, which are then used for training instead of the offline original inputs. """ - buffer_steps = max(1, int(self.args.gradient_accumulation_steps)) - ga = max(1, int(self.args.gradient_accumulation_steps)) + buffer_steps = self.args.gradient_accumulation_steps # Keep lm_head gathered across forward+backward for Liger + ZeRO-3. with self._get_liger_zero3_lm_head_gather_ctx(model): @@ -2541,7 +2539,7 @@ def training_step( self._textual_logs["completion"].extend(gather_object(completion_texts)) loss_scalar = float(loss.detach()) - step_equiv = 1.0 / ga + step_equiv = 1.0 / self.args.gradient_accumulation_steps if on_policy: self._on_policy_loss_total += loss_scalar