Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/gold_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ messages). Important configuration flags on [`GOLDConfig`] include:
matched/unmatched loss.
* `beta`, `lmbda`, `seq_kd` – inherited from [`experimental.gkd.GKDConfig`], controlling the generalized JSD interpolation and on-policy
sampling ratio.
* `steps_per_generation`, `num_generations`, `generation_batch_size` – control buffered rollout generation across
gradient accumulation windows, including multi-generation sampling per prompt.
* `num_generations`, `generation_batch_size` – control buffered rollout generation across gradient accumulation windows.
`generation_batch_size` is the number of unique prompts per worker per optimizer step.

A minimal end-to-end example:

Expand Down
49 changes: 18 additions & 31 deletions trl/experimental/gold/gold_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,11 @@ class GOLDConfig(SFTConfig):
seq_kd (`bool`, *optional*, defaults to `False`):
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on
teacher-generated output).
steps_per_generation (`int` or `None`, *optional*, defaults to `None`):
Number of optimization steps per generation. If `None`, it defaults to
`gradient_accumulation_steps`.
num_generations (`int`, *optional*, defaults to `1`):
Number of generations per prompt. Each prompt is repeated this many times in the generation batch.
generation_batch_size (`int` or `None`, *optional*, defaults to `None`):
Number of prompts per generation batch (global, across all processes). If `None`, it is computed from
`per_device_train_batch_size * world_size * steps_per_generation`.
Number of unique prompts per worker per optimizer step. If `None`, it is computed from
`(per_device_train_batch_size * gradient_accumulation_steps) // num_generations`.
use_uld_loss (`bool`, *optional*, defaults to `False`):
Whether to use Universal Logit Distillation (ULD) loss instead of Generalized Jensen-Shannon Divergence
loss.
Expand Down Expand Up @@ -187,23 +184,17 @@ class GOLDConfig(SFTConfig):
"FT on teacher-generated output)."
},
)
steps_per_generation: int | None = field(
default=None,
metadata={
"help": "Number of optimization steps per generation. If `None`, it defaults to gradient_accumulation_steps."
},
)
num_generations: int = field(
default=1,
metadata={
"help": "Number of generations per prompt. Each prompt is repeated this many times in the batch."
"help": "Number of generations per prompt. Increasing this will decrease the number of unique prompts per optimization step."
},
)
generation_batch_size: int | None = field(
default=None,
metadata={
"help": "Number of prompts per generation batch (global, across all processes). "
"If None, computed from per_device_train_batch_size * num_processes * steps_per_generation."
"help": "Number of unique prompts per worker per optimizer step. "
"If None, computed from (per_device_train_batch_size * gradient_accumulation_steps) // num_generations."
},
)

Expand Down Expand Up @@ -405,31 +396,27 @@ def __post_init__(self):
f"to leave room for the prompt. Consider increasing max_length or reducing max_completion_length."
)

if self.steps_per_generation is None:
self.steps_per_generation = self.gradient_accumulation_steps

if self.generation_batch_size is None:
self.generation_batch_size = (
self.per_device_train_batch_size * self.world_size * self.steps_per_generation
)

if self.num_generations < 1:
raise ValueError(f"num_generations must be at least 1, got {self.num_generations}.")
if self.generation_batch_size % self.num_generations != 0:
local_sequence_batch_size = self.per_device_train_batch_size * self.gradient_accumulation_steps
if self.generation_batch_size is None:
self.generation_batch_size = local_sequence_batch_size // self.num_generations
if self.generation_batch_size < 1:
raise ValueError(
f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations "
f"({self.num_generations})."
"generation_batch_size must be at least 1. "
f"Got generation_batch_size={self.generation_batch_size}."
)
if self.generation_batch_size // self.num_generations < 1:
if self.generation_batch_size * self.num_generations != local_sequence_batch_size:
raise ValueError(
f"generation_batch_size ({self.generation_batch_size}) must be at least num_generations "
f"({self.num_generations}) so that each generation batch contains at least one unique prompt."
"generation_batch_size and num_generations must exactly partition the local optimizer-step batch. "
"Expected generation_batch_size * num_generations == per_device_train_batch_size * "
f"gradient_accumulation_steps, got {self.generation_batch_size} * {self.num_generations} != "
f"{self.per_device_train_batch_size} * {self.gradient_accumulation_steps}."
)
if self.num_generations > 1 and self.lmbda < 1.0:
warnings.warn(
f"num_generations={self.num_generations} with lmbda={self.lmbda} means off-policy batches will "
f"contain {self.num_generations} identical copies of each dataset sample. Consider setting "
f"lmbda=1.0 (fully on-policy) when using num_generations > 1.",
f"num_generations={self.num_generations} with lmbda={self.lmbda} means off-policy batches include "
f"{self.num_generations} copies of each sample; consider lmbda=1.0 when num_generations > 1.",
UserWarning,
stacklevel=2,
)
Expand Down
100 changes: 75 additions & 25 deletions trl/experimental/gold/gold_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,31 @@ def build_teacher_inputs_from_texts(
return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length


class _RepeatEachBatchDataLoader:
"""Repeats each dataloader batch `repeat_count` times without re-sampling."""

def __init__(self, dataloader, repeat_count: int):
if repeat_count < 1:
raise ValueError(f"repeat_count must be at least 1, got {repeat_count}.")
self.dataloader = dataloader
self.repeat_count = repeat_count

def __iter__(self):
for batch in self.dataloader:
for _ in range(self.repeat_count):
yield batch

def __len__(self):
return len(self.dataloader) * self.repeat_count

def set_epoch(self, epoch: int):
if hasattr(self.dataloader, "set_epoch"):
self.dataloader.set_epoch(epoch)

def __getattr__(self, attr):
return getattr(self.dataloader, attr)


class ULDLoss(nn.Module):
"""
Universal Logit Distillation Loss.
Expand Down Expand Up @@ -951,7 +976,7 @@ def __init__(
self.num_completions_to_print = args.num_completions_to_print
# maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the
# final optimization step.
maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.steps_per_generation
maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.gradient_accumulation_steps
self._textual_logs = {
"prompt": deque(maxlen=maxlen),
"completion": deque(maxlen=maxlen),
Expand Down Expand Up @@ -1070,20 +1095,19 @@ def _get_train_sampler(self, dataset=None):
return RepeatSampler(
data_source=dataset,
mini_repeat_count=self.num_generations,
batch_size=self.args.generation_batch_size // self.num_generations,
repeat_count=self.args.steps_per_generation,
batch_size=self.args.generation_batch_size * self.accelerator.num_processes,
repeat_count=1,
shuffle=True,
seed=self.args.seed,
)

def get_train_dataloader(self):
"""
Override Trainer.get_train_dataloader to load a generation batch covering one optimizer window.
Override Trainer.get_train_dataloader to load one generation batch per optimizer window.

Instead of returning a standard per-step batch (i.e., `per_device_batch_size`), this dataloader loads
a batch of size `per_device_batch_size * steps_per_generation`. Combined with the `RepeatSampler`
(which inflates the sampler length by `steps_per_generation`), this prevents the Trainer from
double-dividing by `gradient_accumulation_steps` when computing optimizer steps per epoch.
The base dataloader yields local batches of size
`per_device_train_batch_size * gradient_accumulation_steps`, then repeats each batch
`gradient_accumulation_steps` times so Trainer can run accumulation mini-steps without re-sampling prompts.
"""
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
Expand All @@ -1096,7 +1120,7 @@ def get_train_dataloader(self):
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

dataloader_params = {
"batch_size": self._train_batch_size * self.args.steps_per_generation,
"batch_size": self._train_batch_size * self.args.gradient_accumulation_steps,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
Expand All @@ -1114,18 +1138,19 @@ def get_train_dataloader(self):
if self.args.dataloader_num_workers > 0:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
base_dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
return _RepeatEachBatchDataLoader(base_dataloader, repeat_count=self.args.gradient_accumulation_steps)

@profiling_decorator
def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]:
if not self.model.training:
return generation_batch

spg = self.args.steps_per_generation
if self._step % spg == 0 or self._buffered_inputs is None:
self._fill_buffer(generation_batch, spg)
buffer_steps = self.args.gradient_accumulation_steps
if self._step % buffer_steps == 0 or self._buffered_inputs is None:
self._fill_buffer(generation_batch, buffer_steps)

slice_idx = self._step % spg
slice_idx = self._step % buffer_steps
inputs = self._buffered_inputs[slice_idx]
self._step += 1
return inputs
Expand Down Expand Up @@ -1178,20 +1203,20 @@ def _ensure_original_text_fields(self, slice_inputs: dict[str, torch.Tensor | An
return updated_slice

@profiling_decorator
def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], spg: int):
slices = split_tensor_dict(generation_batch, spg)
def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_steps: int):
slices = split_tensor_dict(generation_batch, buffer_steps)

if self.accelerator.is_main_process:
on_policy_flags = [random.random() <= self.lmbda for _ in range(spg)]
on_policy_flags = [random.random() <= self.lmbda for _ in range(buffer_steps)]
else:
on_policy_flags = [False] * spg
on_policy_flags = [False] * buffer_steps

on_policy_flags = broadcast_object_list(on_policy_flags, from_process=0)
on_policy_indices = [i for i, flag in enumerate(on_policy_flags) if flag]

self._buffered_inputs = [None] * spg
self._buffered_inputs = [None] * buffer_steps
self._buffered_on_policy = on_policy_flags
self._buffered_text_logs = [None] * spg
self._buffered_text_logs = [None] * buffer_steps

for i, flag in enumerate(on_policy_flags):
if not flag:
Expand Down Expand Up @@ -2004,6 +2029,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
teacher_head = unwrapped_teacher.get_output_embeddings()

# liger fused jsd loss
# Note: with ZeRO-3 the lm_head weights are partitioned. The
# gathering is handled in training_step() which wraps both
# forward and backward in a GatheredParameters context.
loss = self.liger_jsd_loss(
student_input=student_hidden,
student_weight=student_head.weight,
Expand Down Expand Up @@ -2461,6 +2489,27 @@ def _wake_vllm_if_needed(self):
empty_cache()
self.vllm_engine.wake_up(tags=["kv_cache"])

def _get_liger_zero3_lm_head_gather_ctx(self, model: nn.Module):
if not self.use_liger_gkd_loss:
return nullcontext()

deepspeed_plugin = self.accelerator.state.deepspeed_plugin
if deepspeed_plugin is None or deepspeed_plugin.zero_stage != 3:
return nullcontext()

import deepspeed

unwrapped_student = self.accelerator.unwrap_model(model)
unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model)
student_head = unwrapped_student.get_output_embeddings()
teacher_head = unwrapped_teacher.get_output_embeddings()
params = [student_head.weight, teacher_head.weight]
if student_head.bias is not None:
params.append(student_head.bias)
if teacher_head.bias is not None:
params.append(teacher_head.bias)
return deepspeed.zero.GatheredParameters(params, modifier_rank=None)

@profiling_decorator
def training_step(
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None
Expand All @@ -2472,12 +2521,13 @@ def training_step(
`self.lmbda`, it generates new responses using the student model, which are then used for training instead of
the offline original inputs.
"""
spg = self.args.steps_per_generation
ga = max(1, int(self.args.gradient_accumulation_steps))
buffer_steps = self.args.gradient_accumulation_steps

loss = super().training_step(model, inputs, num_items_in_batch)
# Keep lm_head gathered across forward+backward for Liger + ZeRO-3.
with self._get_liger_zero3_lm_head_gather_ctx(model):
loss = super().training_step(model, inputs, num_items_in_batch)

slice_idx = (self._step - 1) % spg
slice_idx = (self._step - 1) % buffer_steps

on_policy = False
if self._buffered_on_policy is not None and slice_idx < len(self._buffered_on_policy):
Expand All @@ -2489,7 +2539,7 @@ def training_step(
self._textual_logs["completion"].extend(gather_object(completion_texts))

loss_scalar = float(loss.detach())
step_equiv = 1.0 / ga
step_equiv = 1.0 / self.args.gradient_accumulation_steps

if on_policy:
self._on_policy_loss_total += loss_scalar
Expand Down