diff --git a/docs/source/gold_trainer.md b/docs/source/gold_trainer.md index a95e7d084a9..98af2eaa240 100644 --- a/docs/source/gold_trainer.md +++ b/docs/source/gold_trainer.md @@ -29,8 +29,8 @@ messages). Important configuration flags on [`GOLDConfig`] include: matched/unmatched loss. * `beta`, `lmbda`, `seq_kd` – inherited from [`experimental.gkd.GKDConfig`], controlling the generalized JSD interpolation and on-policy sampling ratio. -* `steps_per_generation`, `num_generations`, `generation_batch_size` – control buffered rollout generation across - gradient accumulation windows, including multi-generation sampling per prompt. +* `num_generations`, `generation_batch_size` – control buffered rollout generation across gradient accumulation windows. + `generation_batch_size` is the number of unique prompts per worker per optimizer step. A minimal end-to-end example: diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index a4a4d60be96..8fbb42359d8 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -54,14 +54,11 @@ class GOLDConfig(SFTConfig): seq_kd (`bool`, *optional*, defaults to `False`): Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated output). - steps_per_generation (`int` or `None`, *optional*, defaults to `None`): - Number of optimization steps per generation. If `None`, it defaults to - `gradient_accumulation_steps`. num_generations (`int`, *optional*, defaults to `1`): Number of generations per prompt. Each prompt is repeated this many times in the generation batch. generation_batch_size (`int` or `None`, *optional*, defaults to `None`): - Number of prompts per generation batch (global, across all processes). If `None`, it is computed from - `per_device_train_batch_size * world_size * steps_per_generation`. + Number of unique prompts per worker per optimizer step. 
If `None`, it is computed from + `(per_device_train_batch_size * gradient_accumulation_steps) // num_generations`. use_uld_loss (`bool`, *optional*, defaults to `False`): Whether to use Universal Logit Distillation (ULD) loss instead of Generalized Jensen-Shannon Divergence loss. @@ -187,23 +184,17 @@ class GOLDConfig(SFTConfig): "FT on teacher-generated output)." }, ) - steps_per_generation: int | None = field( - default=None, - metadata={ - "help": "Number of optimization steps per generation. If `None`, it defaults to gradient_accumulation_steps." - }, - ) num_generations: int = field( default=1, metadata={ - "help": "Number of generations per prompt. Each prompt is repeated this many times in the batch." + "help": "Number of generations per prompt. Increasing this will decrease the number of unique prompts per optimization step." }, ) generation_batch_size: int | None = field( default=None, metadata={ - "help": "Number of prompts per generation batch (global, across all processes). " - "If None, computed from per_device_train_batch_size * num_processes * steps_per_generation." + "help": "Number of unique prompts per worker per optimizer step. " + "If None, computed from (per_device_train_batch_size * gradient_accumulation_steps) // num_generations." }, ) @@ -405,31 +396,27 @@ def __post_init__(self): f"to leave room for the prompt. Consider increasing max_length or reducing max_completion_length." 
) - if self.steps_per_generation is None: - self.steps_per_generation = self.gradient_accumulation_steps - - if self.generation_batch_size is None: - self.generation_batch_size = ( - self.per_device_train_batch_size * self.world_size * self.steps_per_generation - ) - if self.num_generations < 1: raise ValueError(f"num_generations must be at least 1, got {self.num_generations}.") - if self.generation_batch_size % self.num_generations != 0: + local_sequence_batch_size = self.per_device_train_batch_size * self.gradient_accumulation_steps + if self.generation_batch_size is None: + self.generation_batch_size = local_sequence_batch_size // self.num_generations + if self.generation_batch_size < 1: raise ValueError( - f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations " - f"({self.num_generations})." + "generation_batch_size must be at least 1. " + f"Got generation_batch_size={self.generation_batch_size}." ) - if self.generation_batch_size // self.num_generations < 1: + if self.generation_batch_size * self.num_generations != local_sequence_batch_size: raise ValueError( - f"generation_batch_size ({self.generation_batch_size}) must be at least num_generations " - f"({self.num_generations}) so that each generation batch contains at least one unique prompt." + "generation_batch_size and num_generations must exactly partition the local optimizer-step batch. " + "Expected generation_batch_size * num_generations == per_device_train_batch_size * " + f"gradient_accumulation_steps, got {self.generation_batch_size} * {self.num_generations} != " + f"{self.per_device_train_batch_size} * {self.gradient_accumulation_steps}." ) if self.num_generations > 1 and self.lmbda < 1.0: warnings.warn( - f"num_generations={self.num_generations} with lmbda={self.lmbda} means off-policy batches will " - f"contain {self.num_generations} identical copies of each dataset sample. 
class _RepeatEachBatchDataLoader:
    """
    Wrap a dataloader so that every batch it yields is emitted `repeat_count` times in a row.

    The wrapped dataloader is iterated once per pass and the very same batch object is yielded repeatedly, so no
    re-sampling happens between the repeats. Attributes that the wrapper does not define itself are looked up on the
    wrapped dataloader.

    Args:
        dataloader:
            Iterable of batches. It must support `len()` for `len()` of the wrapper to work.
        repeat_count (`int`):
            Number of consecutive times each batch is yielded. Must be at least 1.

    Raises:
        ValueError: If `repeat_count` is smaller than 1.
    """

    def __init__(self, dataloader, repeat_count: int):
        if repeat_count < 1:
            raise ValueError(f"repeat_count must be at least 1, got {repeat_count}.")
        self.dataloader = dataloader
        self.repeat_count = repeat_count

    def __iter__(self):
        for batch in self.dataloader:
            for _ in range(self.repeat_count):
                yield batch

    def __len__(self):
        return len(self.dataloader) * self.repeat_count

    def set_epoch(self, epoch: int):
        """Forward `set_epoch` to the wrapped dataloader when it provides one; otherwise do nothing."""
        if hasattr(self.dataloader, "set_epoch"):
            self.dataloader.set_epoch(epoch)

    def __getattr__(self, attr):
        # `__getattr__` only runs after normal lookup failed. Reading `self.dataloader` here would re-enter
        # `__getattr__` forever on an instance whose `dataloader` is not set yet (e.g. the empty object that
        # `copy` / `pickle` create before restoring state), so fetch it from the instance dict instead and
        # report a regular AttributeError when it is missing.
        try:
            wrapped = self.__dict__["dataloader"]
        except KeyError:
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {attr!r}") from None
        return getattr(wrapped, attr)
- maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.steps_per_generation + maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.gradient_accumulation_steps self._textual_logs = { "prompt": deque(maxlen=maxlen), "completion": deque(maxlen=maxlen), @@ -1070,20 +1095,19 @@ def _get_train_sampler(self, dataset=None): return RepeatSampler( data_source=dataset, mini_repeat_count=self.num_generations, - batch_size=self.args.generation_batch_size // self.num_generations, - repeat_count=self.args.steps_per_generation, + batch_size=self.args.generation_batch_size * self.accelerator.num_processes, + repeat_count=1, shuffle=True, seed=self.args.seed, ) def get_train_dataloader(self): """ - Override Trainer.get_train_dataloader to load a generation batch covering one optimizer window. + Override Trainer.get_train_dataloader to load one generation batch per optimizer window. - Instead of returning a standard per-step batch (i.e., `per_device_batch_size`), this dataloader loads - a batch of size `per_device_batch_size * steps_per_generation`. Combined with the `RepeatSampler` - (which inflates the sampler length by `steps_per_generation`), this prevents the Trainer from - double-dividing by `gradient_accumulation_steps` when computing optimizer steps per epoch. + The base dataloader yields local batches of size + `per_device_train_batch_size * gradient_accumulation_steps`, then repeats each batch + `gradient_accumulation_steps` times so Trainer can run accumulation mini-steps without re-sampling prompts. 
""" if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") @@ -1096,7 +1120,7 @@ def get_train_dataloader(self): data_collator = self._get_collator_with_removed_columns(data_collator, description="training") dataloader_params = { - "batch_size": self._train_batch_size * self.args.steps_per_generation, + "batch_size": self._train_batch_size * self.args.gradient_accumulation_steps, "collate_fn": data_collator, "num_workers": self.args.dataloader_num_workers, "pin_memory": self.args.dataloader_pin_memory, @@ -1114,18 +1138,19 @@ def get_train_dataloader(self): if self.args.dataloader_num_workers > 0: dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + base_dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + return _RepeatEachBatchDataLoader(base_dataloader, repeat_count=self.args.gradient_accumulation_steps) @profiling_decorator def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: if not self.model.training: return generation_batch - spg = self.args.steps_per_generation - if self._step % spg == 0 or self._buffered_inputs is None: - self._fill_buffer(generation_batch, spg) + buffer_steps = self.args.gradient_accumulation_steps + if self._step % buffer_steps == 0 or self._buffered_inputs is None: + self._fill_buffer(generation_batch, buffer_steps) - slice_idx = self._step % spg + slice_idx = self._step % buffer_steps inputs = self._buffered_inputs[slice_idx] self._step += 1 return inputs @@ -1178,20 +1203,20 @@ def _ensure_original_text_fields(self, slice_inputs: dict[str, torch.Tensor | An return updated_slice @profiling_decorator - def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], spg: int): - slices = split_tensor_dict(generation_batch, spg) + def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | 
Any], buffer_steps: int): + slices = split_tensor_dict(generation_batch, buffer_steps) if self.accelerator.is_main_process: - on_policy_flags = [random.random() <= self.lmbda for _ in range(spg)] + on_policy_flags = [random.random() <= self.lmbda for _ in range(buffer_steps)] else: - on_policy_flags = [False] * spg + on_policy_flags = [False] * buffer_steps on_policy_flags = broadcast_object_list(on_policy_flags, from_process=0) on_policy_indices = [i for i, flag in enumerate(on_policy_flags) if flag] - self._buffered_inputs = [None] * spg + self._buffered_inputs = [None] * buffer_steps self._buffered_on_policy = on_policy_flags - self._buffered_text_logs = [None] * spg + self._buffered_text_logs = [None] * buffer_steps for i, flag in enumerate(on_policy_flags): if not flag: @@ -2004,6 +2029,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N teacher_head = unwrapped_teacher.get_output_embeddings() # liger fused jsd loss + # Note: with ZeRO-3 the lm_head weights are partitioned. The + # gathering is handled in training_step() which wraps both + # forward and backward in a GatheredParameters context. 
def _get_liger_zero3_lm_head_gather_ctx(self, model: nn.Module):
    """
    Build the context manager that wraps a training step when the Liger loss is combined with ZeRO-3.

    Returns a `nullcontext()` unless `self.use_liger_gkd_loss` is set and the accelerator's DeepSpeed plugin
    reports `zero_stage == 3`. In that case the output-embedding weights of the student and teacher (plus their
    biases when present) are handed to `deepspeed.zero.GatheredParameters` with `modifier_rank=None`.
    """
    if not self.use_liger_gkd_loss:
        return nullcontext()

    plugin = self.accelerator.state.deepspeed_plugin
    zero3_active = plugin is not None and plugin.zero_stage == 3
    if not zero3_active:
        return nullcontext()

    # Imported lazily: deepspeed is only needed on the ZeRO-3 path.
    import deepspeed

    heads = [
        self.accelerator.unwrap_model(model).get_output_embeddings(),
        self.accelerator.unwrap_model(self.teacher_model).get_output_embeddings(),
    ]
    # Weights first (student, teacher), then whichever biases exist, in the same student/teacher order.
    gathered = [head.weight for head in heads]
    gathered.extend(head.bias for head in heads if head.bias is not None)
    return deepspeed.zero.GatheredParameters(gathered, modifier_rank=None)
+ with self._get_liger_zero3_lm_head_gather_ctx(model): + loss = super().training_step(model, inputs, num_items_in_batch) - slice_idx = (self._step - 1) % spg + slice_idx = (self._step - 1) % buffer_steps on_policy = False if self._buffered_on_policy is not None and slice_idx < len(self._buffered_on_policy): @@ -2489,7 +2539,7 @@ def training_step( self._textual_logs["completion"].extend(gather_object(completion_texts)) loss_scalar = float(loss.detach()) - step_equiv = 1.0 / ga + step_equiv = 1.0 / self.args.gradient_accumulation_steps if on_policy: self._on_policy_loss_total += loss_scalar