diff --git a/docs/source/gold_trainer.md b/docs/source/gold_trainer.md index dbe7e7b01e6..72685f7ae80 100644 --- a/docs/source/gold_trainer.md +++ b/docs/source/gold_trainer.md @@ -29,6 +29,9 @@ messages). Important configuration flags on [`GOLDConfig`] include: matched/unmatched loss. * `beta`, `lmbda`, `seq_kd` – inherited from [`experimental.gkd.GKDConfig`], controlling the generalized JSD interpolation and on-policy sampling ratio. +* `num_generations`, `generation_batch_size` – control buffered rollout generation across gradient accumulation windows. + `generation_batch_size` is the number of unique prompts per worker per optimizer step. +* `model_revision` – controls which student model revision GOLD loads for training and generation. A minimal end-to-end example: @@ -79,7 +82,7 @@ train_dataset = load_dataset( training_args = GOLDConfig( output_dir="gold-model", per_device_train_batch_size=1, - teacher_model=teacher_name, + teacher_model_name_or_path=teacher_name, teacher_tokenizer_name_or_path=teacher_name, use_uld_loss=True, uld_use_hybrid_loss=True, @@ -95,6 +98,11 @@ trainer = GOLDTrainer( trainer.train() ``` +> [!NOTE] +> GOLD buffers one full optimizer-window generation batch (`per_device_train_batch_size * gradient_accumulation_steps`) +> and reuses it across accumulation steps. If the final batch is undersized, GOLD warns and drops that last batch +> (`Dropping last batch due to unexpected batch size`). Set `dataloader_drop_last=True` to avoid this warning. + ### Expected dataset type GOLD requires a [conversational](dataset_formats#conversational) [language modeling](dataset_formats#language_modeling) dataset, e.g.: diff --git a/tests/experimental/test_gold_trainer.py b/tests/experimental/test_gold_trainer.py index f48f3464654..50800c0a136 100644 --- a/tests/experimental/test_gold_trainer.py +++ b/tests/experimental/test_gold_trainer.py @@ -289,6 +289,88 @@ def pad_labels(labels, target_length): return labels + [-100] * (target_length - len(labels)) +def test_process_completions_to_buffer_left_pads_prompt_retokenization(): + class DummyBatch: + def __init__(self, input_ids): + self.input_ids = input_ids + + def to(self, device): + self.input_ids = self.input_ids.to(device) + return self + + class RecordingTokenizer: + pad_token_id = 0 + pad_token = "" + + def __init__(self): + self.padding_side = "right" + self.calls = [] + self._prompt_ids = { + "short": [11], + "longer": [21, 22], + } + + def __call__( + self, + texts, + return_tensors, + padding, + truncation, + max_length, + add_special_tokens, + padding_side=None, + ): + assert return_tensors == "pt" + assert padding == "longest" + assert not truncation + assert max_length is None + assert not add_special_tokens + self.calls.append(padding_side) + + side = padding_side or self.padding_side + encoded = [torch.tensor(self._prompt_ids[text], dtype=torch.long) for text in texts] + max_len = max(len(ids) for ids in encoded) + + padded = [] + for ids in encoded: + pad_width = max_len - len(ids) + if pad_width: + pad = torch.full((pad_width,), self.pad_token_id, dtype=torch.long) + ids = torch.cat([pad, ids]) if side == "left" else torch.cat([ids, pad]) + padded.append(ids) + + return DummyBatch(torch.stack(padded)) + + def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=False): + del skip_special_tokens, clean_up_tokenization_spaces + return [" ".join(str(token) for token in sequence) for sequence in sequences] + + trainer = GOLDTrainer.__new__(GOLDTrainer) + trainer.accelerator = SimpleNamespace(device=torch.device("cpu")) + trainer.processing_class = RecordingTokenizer() + trainer.args = SimpleNamespace(max_length=None) + trainer._buffered_inputs = [None] + trainer._buffered_text_logs = [None] + + GOLDTrainer._process_completions_to_buffer( + trainer, + slices=[{"slice": "original"}], + on_policy_indices=[0], + local_slice_indices=[0, 0], + completion_ids=[[31], [41]], + prompts_text=["short", "longer"], + prompts_text_with_special=["short", "longer"], + max_completion_length=1, + ) + + buffered_inputs = trainer._buffered_inputs[0] + assert trainer.processing_class.calls == ["left"] + assert trainer.processing_class.padding_side == "right" + assert torch.equal(buffered_inputs["input_ids"], torch.tensor([[0, 11, 31], [21, 22, 41]], dtype=torch.long)) + assert torch.equal(buffered_inputs["attention_mask"], torch.tensor([[0, 1, 1], [1, 1, 1]], dtype=torch.long)) + assert torch.equal(buffered_inputs["labels"], torch.tensor([[-100, -100, 31], [-100, -100, 41]])) + + def test_alignment_groups_cover_all_tokens(llama_tokenizer, qwen_tokenizer): config = build_config() loss = ULDLoss(config, student_tokenizer=llama_tokenizer, teacher_tokenizer=qwen_tokenizer) @@ -382,7 +464,9 @@ def test_generate_on_policy_outputs_masks_prompt(llama_tokenizer): prompt_tensor[0, pad_width:] = torch.tensor(prompt_ids, dtype=torch.long) prompt_mask = (prompt_tensor != pad_id).long() - generated_sequence = torch.tensor(prompt_ids + completion_ids, dtype=torch.long).unsqueeze(0) + # model.generate() returns full sequences including left-padding from the input + completion_tensor = torch.tensor(completion_ids, dtype=torch.long).unsqueeze(0) + generated_sequence = torch.cat([prompt_tensor, completion_tensor], dim=1) class DummyModel: def generate(self, input_ids, attention_mask, generation_config, return_dict_in_generate): @@ -406,9 +490,9 @@ def generate(self, input_ids, attention_mask, generation_config, return_dict_in_ else: assert torch.all(new_mask == 1) - prompt_len = len(prompt_ids) - assert torch.all(new_labels[0, :prompt_len] == -100) - assert torch.equal(new_labels[0, prompt_len:], torch.tensor(completion_ids, dtype=torch.long)) + padded_prompt_len = prompt_tensor.shape[1] + assert torch.all(new_labels[0, :padded_prompt_len] == -100) + assert torch.equal(new_labels[0, padded_prompt_len:], torch.tensor(completion_ids, dtype=torch.long)) assert prompt_texts[0] == llama_tokenizer.decode(prompt_ids, skip_special_tokens=False) assert completion_texts[0] == llama_tokenizer.decode(completion_ids, skip_special_tokens=False) diff --git a/trl/experimental/gold/gold.py b/trl/experimental/gold/gold.py index 81954b5c753..399344eab20 100644 --- a/trl/experimental/gold/gold.py +++ b/trl/experimental/gold/gold.py @@ -80,7 +80,7 @@ ################ quantization_config = get_quantization_config(model_args) model_kwargs = dict( - revision=training_args.student_model_revision, + revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, attn_implementation=model_args.attn_implementation, torch_dtype=model_args.dtype, @@ -93,7 +93,7 @@ if training_args.teacher_tokenizer_name_or_path is None and training_args.use_uld_loss: training_args.teacher_tokenizer_name_or_path = training_args.teacher_model_name_or_path teacher_model_kwargs = dict( - revision=model_args.model_revision, + revision=training_args.teacher_model_revision, trust_remote_code=model_args.trust_remote_code, attn_implementation=model_args.attn_implementation, torch_dtype=model_args.dtype, @@ -101,13 +101,14 @@ device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) + if training_args.teacher_model_init_kwargs is not None: + teacher_model_kwargs.update(training_args.teacher_model_init_kwargs) training_args.teacher_model_init_kwargs = teacher_model_kwargs tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, - padding_side="left", ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token @@ -120,7 +121,6 @@ ################ # Training ################ - # Handle eval dataset - check if test split exists, fallback to validation or None eval_dataset = None if training_args.eval_strategy != "no": if script_args.dataset_test_split in dataset: diff --git a/trl/experimental/gold/gold_config.py b/trl/experimental/gold/gold_config.py index 41839318f76..2a0984e6892 100644 --- a/trl/experimental/gold/gold_config.py +++ b/trl/experimental/gold/gold_config.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass, field from typing import Any @@ -40,6 +41,9 @@ class GOLDConfig(SFTConfig): teacher_model_name_or_path (`str`, *optional*): Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being trained. + teacher_model_revision (`str` or `None`, *optional*, defaults to `None`): + Model revision of the teacher model (e.g., branch name, tag, or commit hash). If `None`, the default + revision is used. teacher_model_init_kwargs (`dict[str, Any]`, *optional*): Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model from a string. @@ -51,6 +55,11 @@ class GOLDConfig(SFTConfig): seq_kd (`bool`, *optional*, defaults to `False`): Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated output). + num_generations (`int`, *optional*, defaults to `1`): + Number of generations per prompt. Each prompt is repeated this many times in the generation batch. + generation_batch_size (`int` or `None`, *optional*, defaults to `None`): + Number of unique prompts per worker per optimizer step. If `None`, it is computed from + `(per_device_train_batch_size * gradient_accumulation_steps) // num_generations`. use_uld_loss (`bool`, *optional*, defaults to `False`): Whether to use Universal Logit Distillation (ULD) loss instead of Generalized Jensen-Shannon Divergence loss. @@ -138,12 +147,6 @@ class GOLDConfig(SFTConfig): default=128, metadata={"help": "Maximum number of tokens to generate per completion."}, ) - student_model_revision: str = field( - default="main", - metadata={ - "help": "Revision of the student model to use. If not specified, the default revision of the model will be used." - }, - ) teacher_model_name_or_path: str | None = field( default=None, metadata={ @@ -151,6 +154,13 @@ class GOLDConfig(SFTConfig): "model being trained." }, ) + teacher_model_revision: str | None = field( + default=None, + metadata={ + "help": "Model revision of the teacher model (e.g., branch name, tag, or commit hash). If `None`, the " + "default revision is used." + }, + ) teacher_model_init_kwargs: dict[str, Any] | str | None = field( default=None, metadata={ @@ -176,10 +186,17 @@ class GOLDConfig(SFTConfig): "FT on teacher-generated output)." }, ) - steps_per_generation: int | None = field( + num_generations: int = field( + default=1, + metadata={ + "help": "Number of generations per prompt. Increasing this will decrease the number of unique prompts per optimization step." + }, + ) + generation_batch_size: int | None = field( default=None, metadata={ - "help": "Number of optimization steps per generation. If `None`, it defaults to gradient_accumulation_steps." + "help": "Number of unique prompts per worker per optimizer step. " + "If None, computed from (per_device_train_batch_size * gradient_accumulation_steps) // num_generations." }, ) @@ -362,15 +379,8 @@ class GOLDConfig(SFTConfig): hub_model_revision: str | None = field( default="main", metadata={"help": "The Hub model branch to push the model to."} ) - num_completions_to_print: int = field(default=5, metadata={"help": "Number of completions to print."}) overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."}) push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."}) - trl_project: str = field( - default="smollm3", - metadata={ - "help": "The TRL project to use for evaluation. This is used to determine the path to the evaluation script." - }, - ) def __post_init__(self): super().__post_init__() @@ -387,8 +397,29 @@ def __post_init__(self): f"to leave room for the prompt. Consider increasing max_length or reducing max_completion_length." ) - if self.steps_per_generation is None: - self.steps_per_generation = self.gradient_accumulation_steps + if self.num_generations < 1: + raise ValueError(f"num_generations must be at least 1, got {self.num_generations}.") + local_sequence_batch_size = self.per_device_train_batch_size * self.gradient_accumulation_steps + if self.generation_batch_size is None: + self.generation_batch_size = local_sequence_batch_size // self.num_generations + if self.generation_batch_size < 1: + raise ValueError( + f"generation_batch_size must be at least 1. Got generation_batch_size={self.generation_batch_size}." + ) + if self.generation_batch_size * self.num_generations != local_sequence_batch_size: + raise ValueError( + "generation_batch_size and num_generations must exactly partition the local optimizer-step batch. " + "Expected generation_batch_size * num_generations == per_device_train_batch_size * " + f"gradient_accumulation_steps, got {self.generation_batch_size} * {self.num_generations} != " + f"{self.per_device_train_batch_size} * {self.gradient_accumulation_steps}." + ) + if self.num_generations > 1 and self.lmbda < 1.0: + warnings.warn( + f"num_generations={self.num_generations} with lmbda={self.lmbda} means off-policy batches include " + f"{self.num_generations} copies of each sample; consider lmbda=1.0 when num_generations > 1.", + UserWarning, + stacklevel=2, + ) # Validate ULD parameters if self.use_uld_loss: diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index c588b3b2a42..cc4f3f2553a 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -19,6 +19,7 @@ from collections import defaultdict, deque from collections.abc import Callable from contextlib import nullcontext +from functools import partial from typing import Any, Optional import torch @@ -29,6 +30,7 @@ from accelerate.utils import DistributedType, broadcast_object_list, gather_object, is_peft_model from datasets import Dataset, IterableDataset from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.utils.data import DataLoader from transformers import AutoTokenizer, TrainerCallback, TrainerControl, TrainerState, is_bitsandbytes_available from transformers.data.data_collator import DataCollator from transformers.feature_extraction_utils import FeatureExtractionMixin @@ -38,8 +40,9 @@ from transformers.modeling_utils import PreTrainedModel from transformers.processing_utils import ProcessorMixin from transformers.tokenization_utils_base import PreTrainedTokenizerBase -from transformers.trainer_utils import EvalPrediction +from transformers.trainer_utils import EvalPrediction, seed_worker from transformers.utils import ( + is_datasets_available, is_flash_attn_2_available, is_liger_kernel_available, is_peft_available, @@ -53,7 +56,14 @@ from ...models import prepare_deepspeed from ...models.utils import unwrap_model_for_generation from ...trainer.sft_trainer import SFTTrainer -from ...trainer.utils import create_model_from_path, disable_dropout_in_model, ensure_master_addr_port, pad +from ...trainer.utils import ( + RepeatSampler, + create_model_from_path, + disable_dropout_in_model, + ensure_master_addr_port, + pad, + split_tensor_dict, +) from ..utils import DataCollatorForChatML, empty_cache from .gold_config import GOLDConfig @@ -796,10 +806,7 @@ def __init__( peft_config: Optional["PeftConfig"] = None, ): self.model_name_or_path = model if isinstance(model, str) else model.config._name_or_path - self.model_revision = getattr(args, "student_model_revision", None) - if isinstance(model, str) and self.model_revision is not None: - args.model_init_kwargs = args.model_init_kwargs or {} - args.model_init_kwargs.setdefault("revision", self.model_revision) + self.model_revision = (args.model_init_kwargs or {}).get("revision") # Respect a user-provided data_collator; otherwise, provide a ChatML collator that if data_collator is None: @@ -813,6 +820,8 @@ def __init__( ignore_index=-100, temperature=args.temperature, compiled=False, + weight_hard_loss=0.0, + weight_soft_loss=1.0, ) self.use_liger_gkd_loss = True @@ -840,6 +849,8 @@ def __init__( if isinstance(teacher_model, str): init_kwargs = dict(teacher_model_init_kwargs) + if args.teacher_model_revision is not None: + init_kwargs.setdefault("revision", args.teacher_model_revision) if "torch_dtype" in init_kwargs and "dtype" not in init_kwargs: init_kwargs["dtype"] = init_kwargs.pop("torch_dtype") teacher_model = create_model_from_path(teacher_model, **init_kwargs) @@ -881,6 +892,7 @@ def __init__( self.temperature = args.temperature self.top_p = args.top_p self.seq_kd = args.seq_kd + self.num_generations = args.num_generations # Track per-step loss statistics for on/off-policy batches (used in logging) self._on_policy_loss_total = 0.0 @@ -888,6 +900,12 @@ def __init__( self._on_policy_step_equiv = 0.0 self._off_policy_step_equiv = 0.0 + # Buffering for rollouts across gradient accumulation steps + self._buffered_inputs = None + self._buffered_on_policy = None + self._buffered_text_logs = None + self._step = 0 + # Hybrid ULD matched/unmatched accumulators (logged every step when ULD hybrid is used) self._matched_sum = 0.0 self._unmatched_sum = 0.0 @@ -931,7 +949,7 @@ def __init__( self.num_completions_to_print = args.num_completions_to_print # maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the # final optimization step. - maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.steps_per_generation + maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.gradient_accumulation_steps self._textual_logs = { "prompt": deque(maxlen=maxlen), "completion": deque(maxlen=maxlen), @@ -1044,6 +1062,505 @@ def _set_signature_columns_if_needed(self): if column not in self._signature_columns: self._signature_columns.append(column) + def _get_train_sampler(self, dataset=None): + if dataset is None: + dataset = self.train_dataset + return RepeatSampler( + data_source=dataset, + mini_repeat_count=self.num_generations, + batch_size=self.args.generation_batch_size * self.accelerator.num_processes, + repeat_count=self.args.gradient_accumulation_steps, + shuffle=True, + seed=self.args.seed, + ) + + def get_train_dataloader(self): + """ + Override Trainer.get_train_dataloader to load one generation batch per optimizer window. + + The dataloader yields local batches of size `per_device_train_batch_size * gradient_accumulation_steps`. + The `RepeatSampler` (with `repeat_count=gradient_accumulation_steps`) ensures each generation batch is + sampled `gradient_accumulation_steps` times so Trainer's loop iterates the correct number of times. + Only the first batch in each window triggers `_fill_buffer`; the rest are ignored by `_prepare_inputs`. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size * self.args.gradient_accumulation_steps, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = partial( + seed_worker, + num_workers=self.args.dataloader_num_workers, + rank=self.args.process_index, + ) + if self.args.dataloader_num_workers > 0: + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + @profiling_decorator + def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: + if not self.model.training: + return generation_batch + + buffer_steps = self.args.gradient_accumulation_steps + if self._step % buffer_steps == 0 or self._buffered_inputs is None: + self._fill_buffer(generation_batch, buffer_steps) + + slice_idx = self._step % buffer_steps + inputs = self._buffered_inputs[slice_idx] + self._step += 1 + return inputs + + def _decode_completion_texts_from_labels(self, slice_inputs: dict[str, torch.Tensor | Any]) -> list[str] | None: + """Decode completion text from labels when raw text is absent.""" + labels = slice_inputs.get("labels") + if labels is None or not isinstance(labels, torch.Tensor): + return None + + labels_cpu = labels.detach().cpu() + decoded_completion_tokens: list[list[int]] = [] + for row in labels_cpu: + token_ids = row[row != -100].tolist() + if self.processing_class.pad_token_id is not None: + token_ids = [tok for tok in token_ids if tok != self.processing_class.pad_token_id] + decoded_completion_tokens.append(token_ids) + + return self.processing_class.batch_decode( + decoded_completion_tokens, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + + def _ensure_original_text_fields( + self, slice_inputs: dict[str, torch.Tensor | Any] + ) -> dict[str, torch.Tensor | Any]: + """Populate original prompt/completion text fields when missing.""" + if "original_prompt_text" in slice_inputs and "original_completion_text" in slice_inputs: + return slice_inputs + + prompts = slice_inputs.get("prompts") + if prompts is None or not isinstance(prompts, torch.Tensor): + return slice_inputs + + prompt_texts = self.processing_class.batch_decode( + prompts, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + completion_texts = self._decode_completion_texts_from_labels(slice_inputs) + if completion_texts is None: + return slice_inputs + + updated_slice = dict(slice_inputs) + updated_slice["original_prompt_text"] = prompt_texts + updated_slice["original_completion_text"] = completion_texts + return updated_slice + + @staticmethod + def _build_sequence_batch( + new_input_ids: torch.Tensor, prompt_lengths: torch.Tensor, pad_token_id: int | None + ) -> tuple[torch.Tensor, torch.Tensor]: + """Build attention mask and labels from full sequences and prompt lengths.""" + prompt_lengths = prompt_lengths.to(device=new_input_ids.device, dtype=torch.long) + positions = torch.arange(new_input_ids.shape[1], device=new_input_ids.device).unsqueeze(0) + completion_mask = positions >= prompt_lengths.unsqueeze(1) + + new_attention_mask = torch.ones_like(new_input_ids) + if pad_token_id is not None: + new_attention_mask[new_input_ids == pad_token_id] = 0 + + new_labels = torch.full_like(new_input_ids, -100) + new_labels[completion_mask] = new_input_ids[completion_mask] + if pad_token_id is not None: + new_labels[new_input_ids == pad_token_id] = -100 + + return new_attention_mask, new_labels + + @profiling_decorator + def _fill_buffer(self, generation_batch: dict[str, torch.Tensor | Any], buffer_steps: int): + slices = split_tensor_dict(generation_batch, buffer_steps) + + if self.accelerator.is_main_process: + on_policy_flags = [random.random() <= self.lmbda for _ in range(buffer_steps)] + else: + on_policy_flags = [False] * buffer_steps + + on_policy_flags = broadcast_object_list(on_policy_flags, from_process=0) + on_policy_indices = [i for i, flag in enumerate(on_policy_flags) if flag] + + self._buffered_inputs = [None] * buffer_steps + self._buffered_on_policy = on_policy_flags + self._buffered_text_logs = [None] * buffer_steps + + for i, flag in enumerate(on_policy_flags): + if not flag: + slice_inputs = slices[i] + + if self.use_uld_loss and self.teacher_tokenizer is not None: + slice_inputs = self._ensure_original_text_fields(slice_inputs) + if "original_prompt_text" not in slice_inputs or "original_completion_text" not in slice_inputs: + raise ValueError( + "Off-policy batch missing 'original_prompt_text' or 'original_completion_text' fields. " + "When using ULD loss with cross-tokenizer alignment, datasets must be prepared with " + "_prepare_dataset_with_original_text(). Ensure your dataset includes these fields." + ) + + self._buffered_inputs[i] = slice_inputs + + if on_policy_indices: + self._generate_on_policy_for_slices(slices, on_policy_indices) + + @profiling_decorator + def _generate_on_policy_for_slices( + self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int] + ): + local_prompts = [] + local_slice_indices = [] + for slice_idx in on_policy_indices: + slice_inputs = slices[slice_idx] + for prompt in slice_inputs["prompts"]: + local_prompts.append(prompt) + local_slice_indices.append(slice_idx) + + prompts_text_for_vllm = self.processing_class.batch_decode( + torch.stack(local_prompts) if local_prompts else torch.empty(0, dtype=torch.long), + skip_special_tokens=True, + ) + if self.processing_class.pad_token: + prompts_text_for_vllm = [p.replace(self.processing_class.pad_token, "") for p in prompts_text_for_vllm] + + prompts_text_with_special = self.processing_class.batch_decode( + torch.stack(local_prompts) if local_prompts else torch.empty(0, dtype=torch.long), + skip_special_tokens=False, + ) + + if self.use_vllm: + self._wake_vllm_if_needed() + + max_completion_length = self.generation_config.max_new_tokens + temperature = self.generation_config.temperature + top_k = ( + self.generation_config.top_k if self.generation_config.top_k and self.generation_config.top_k > 0 else -1 + ) + top_p = self.args.top_p if hasattr(self.args, "top_p") else 1.0 + repetition_penalty = self.args.repetition_penalty if hasattr(self.args, "repetition_penalty") else 1.0 + min_p = self.args.min_p if hasattr(self.args, "min_p") else 0.0 + + if self.use_vllm and self.vllm_mode == "server": + completion_ids = self._generate_vllm_server_global( + prompts_text_for_vllm, + max_completion_length, + temperature, + top_k, + top_p, + repetition_penalty, + min_p, + n=self.num_generations, + ) + elif self.use_vllm and self.vllm_mode == "colocate": + completion_ids = self._generate_vllm_colocate( + prompts_text_for_vllm, + max_completion_length, + temperature, + top_k, + top_p, + repetition_penalty, + min_p, + n=self.num_generations, + ) + else: + self._generate_non_vllm_for_slices(slices, on_policy_indices) + return + + self._process_completions_to_buffer( + slices, + on_policy_indices, + local_slice_indices, + completion_ids, + prompts_text_for_vllm, + prompts_text_with_special, + max_completion_length, + ) + + @staticmethod + def _deduplicate_prompts( + prompts: list[str], num_generations: int + ) -> tuple[list[str], list[tuple[int, int]]] | None: + """Deduplicate prompts and build a completion remapping.""" + seen: dict[str, list[int]] = {} + unique_prompts: list[str] = [] + dedup_mapping: list[tuple[int, int]] = [] + + for prompt in prompts: + if prompt not in seen: + seen[prompt] = [len(unique_prompts), 0] + unique_prompts.append(prompt) + entry = seen[prompt] + if entry[1] >= num_generations: + return None + dedup_mapping.append((entry[0], entry[1])) + entry[1] += 1 + + return unique_prompts, dedup_mapping + + def _generate_vllm_server_global( + self, + prompts_text: list[str], + max_tokens: int, + temperature: float, + top_k: int, + top_p: float, + repetition_penalty: float, + min_p: float, + n: int = 1, + ) -> list: + all_prompts_text = gather_object(prompts_text) + local_count = len(prompts_text) + + if self.accelerator.is_main_process: + if all_prompts_text: + dedup_mapping = None + if n > 1: + dedup_result = self._deduplicate_prompts(all_prompts_text, n) + if dedup_result is not None: + gen_prompts, dedup_mapping = dedup_result + gen_n = n + else: + gen_prompts = all_prompts_text + gen_n = 1 + else: + gen_prompts = all_prompts_text + gen_n = 1 + + completion_ids = self.vllm_client.generate( + prompts=gen_prompts, + n=gen_n, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + max_tokens=max_tokens, + structured_outputs_regex=self.vllm_structured_outputs_regex, + )["completion_ids"] + + if dedup_mapping is not None: + completion_ids = [completion_ids[uid * gen_n + gid] for uid, gid in dedup_mapping] + else: + completion_ids = [] + else: + completion_ids = [None] * len(all_prompts_text) if all_prompts_text else [] + + completion_ids = broadcast_object_list(completion_ids, from_process=0) + process_slice = slice( + self.accelerator.process_index * local_count, + (self.accelerator.process_index + 1) * local_count, + ) + return completion_ids[process_slice] + + def _generate_vllm_colocate( + self, + prompts_text: list[str], + max_tokens: int, + temperature: float, + top_k: int, + top_p: float, + repetition_penalty: float, + min_p: float, + n: int = 1, + ) -> list: + if self.vllm_structured_outputs_regex: + structured_outputs = StructuredOutputsParams(backend="outlines", regex=self.vllm_structured_outputs_regex) + else: + structured_outputs = None + + if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1: + orig_size = len(prompts_text) + gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.vllm_tp_group) + all_prompts_text = [p for sublist in gathered_prompts for p in sublist] + else: + all_prompts_text = prompts_text + + dedup_mapping = None + if n > 1 and all_prompts_text: + dedup_result = self._deduplicate_prompts(all_prompts_text, n) + if dedup_result is not None: + gen_prompts, dedup_mapping = dedup_result + gen_n = n + else: + gen_prompts = all_prompts_text + gen_n = 1 + else: + gen_prompts = all_prompts_text + gen_n = 1 + + sampling_params = SamplingParams( + n=gen_n, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + max_tokens=max_tokens, + structured_outputs=structured_outputs, + ) + + if gen_prompts: + all_outputs = self.vllm_engine.generate(gen_prompts, sampling_params=sampling_params, use_tqdm=False) + completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + else: + completion_ids = [] + + if dedup_mapping is not None: + completion_ids = [completion_ids[uid * gen_n + gid] for uid, gid in dedup_mapping] + + if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1: + local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) + tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) + completion_ids = completion_ids[tp_slice] + + if self.vllm_enable_sleep_mode: + self.vllm_engine.sleep(level=2) + + return completion_ids + + def _generate_non_vllm_for_slices(self, slices: list[dict[str, torch.Tensor | Any]], on_policy_indices: list[int]): + """Fallback generation without vLLM (uses model.generate per slice).""" + with unwrap_model_for_generation( + self.model, + self.accelerator, + generation_kwargs=self.generation_kwargs, + ) as unwrapped_model: + for slice_idx in on_policy_indices: + slice_inputs = slices[slice_idx] + result = self.generate_on_policy_outputs( + unwrapped_model, + slice_inputs, + self.generation_config, + self.processing_class.pad_token_id, + ) + new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts = result + + updated_slice = dict(slice_inputs) + updated_slice["input_ids"] = new_input_ids + updated_slice["attention_mask"] = new_attention_mask + updated_slice["labels"] = new_labels + updated_slice["original_prompt_text"] = prompt_texts + updated_slice["original_completion_text"] = completion_texts + + self._buffered_inputs[slice_idx] = updated_slice + self._buffered_text_logs[slice_idx] = (prompt_texts, completion_texts) + + def _process_completions_to_buffer( + self, + slices: list[dict[str, torch.Tensor | Any]], + on_policy_indices: list[int], + local_slice_indices: list[int], + completion_ids: list, + prompts_text: list[str], + prompts_text_with_special: list[str], + max_completion_length: int, + ): + """ + Process vLLM completions and update buffered inputs for on-policy slices. + """ + device = self.accelerator.device + pad_token_id = self.processing_class.pad_token_id if self.processing_class.pad_token_id is not None else 0 + + slice_completions = {idx: [] for idx in on_policy_indices} + slice_prompts = {idx: [] for idx in on_policy_indices} + slice_prompts_special = {idx: [] for idx in on_policy_indices} + + for i, slice_idx in enumerate(local_slice_indices): + slice_completions[slice_idx].append(completion_ids[i]) + slice_prompts[slice_idx].append(prompts_text[i]) + slice_prompts_special[slice_idx].append(prompts_text_with_special[i]) + + for slice_idx in on_policy_indices: + slice_inputs = slices[slice_idx] + completion_ids_for_slice = slice_completions[slice_idx] + prompt_txts = slice_prompts[slice_idx] + prompt_txts_with_special = slice_prompts_special[slice_idx] + + prompt_max_length = max(1, self.args.max_length - max_completion_length) if self.args.max_length else None + prompt_tokenized = self.processing_class( + prompt_txts, + return_tensors="pt", + padding="longest", + padding_side="left", + truncation=True if prompt_max_length else False, + max_length=prompt_max_length, + add_special_tokens=False, + ).to(device) + prompt_ids = prompt_tokenized.input_ids + + completion_ids_tensors = [torch.tensor(ids, device=device) for ids in completion_ids_for_slice] + completion_ids_for_text: list[list[int]] = [] + padded_completion_ids_list = [] + for completion_tensor in completion_ids_tensors: + if len(completion_tensor) > max_completion_length: + truncated_completion_tensor = completion_tensor[:max_completion_length] + padded_completion_ids_list.append(truncated_completion_tensor) + completion_ids_for_text.append(truncated_completion_tensor.tolist()) + elif len(completion_tensor) < max_completion_length: + padding_needed = max_completion_length - len(completion_tensor) + padded_tensor = torch.cat( + [ + completion_tensor, + torch.full( + (padding_needed,), + pad_token_id, + device=device, + dtype=completion_tensor.dtype, + ), + ] + ) + padded_completion_ids_list.append(padded_tensor) + completion_ids_for_text.append(completion_tensor.tolist()) + else: + padded_completion_ids_list.append(completion_tensor) + completion_ids_for_text.append(completion_tensor.tolist()) + + completion_ids_padded = torch.stack(padded_completion_ids_list) + + new_input_ids = torch.cat([prompt_ids, completion_ids_padded], dim=1) + prompt_lengths = torch.full((prompt_ids.shape[0],), prompt_ids.shape[1], device=device) + new_attention_mask, new_labels = self._build_sequence_batch(new_input_ids, prompt_lengths, pad_token_id) + + completion_texts = self.processing_class.batch_decode( + completion_ids_for_text, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + + updated_slice = dict(slice_inputs) + updated_slice["input_ids"] = new_input_ids + updated_slice["attention_mask"] = new_attention_mask + updated_slice["labels"] = new_labels + updated_slice["original_prompt_text"] = prompt_txts_with_special + updated_slice["original_completion_text"] = completion_texts + + self._buffered_inputs[slice_idx] = updated_slice + self._buffered_text_logs[slice_idx] = (prompt_txts, completion_texts) + def _prepare_dataset( self, dataset: Dataset | IterableDataset, @@ -1053,24 +1570,15 @@ def _prepare_dataset( formatting_func: Callable[[dict], str] | None, dataset_name: str, ) -> Dataset | IterableDataset: - """ - Override dataset preparation to preserve original text for cross-tokenizer distillation and ensure - attention_mask is always added for DataCollatorForChatML compatibility. - """ - # Check if dataset is already processed + """Preserve original text fields for ULD when needed.""" column_names = list(next(iter(dataset)).keys()) is_processed = "input_ids" in column_names - # Use our enhanced dataset preparation for: - # 1. ULD loss with cross-tokenizer (need original text preservation) - # 2. Any unprocessed dataset (need attention_mask for DataCollatorForChatML) if not is_processed or (self.use_uld_loss and self.teacher_tokenizer is not None): - # For unprocessed datasets, use our enhanced tokenization return self._prepare_dataset_with_original_text( dataset, processing_class, args, packing, formatting_func, dataset_name ) - # Use parent implementation for all other cases return super()._prepare_dataset(dataset, processing_class, args, packing, formatting_func, dataset_name) def _prepare_dataset_with_original_text( @@ -1296,6 +1804,7 @@ def tokenize_with_original_text(example, processing_class, dataset_text_field, a "attention_mask", "position_ids", "completion_mask", + "messages", "assistant_masks", "original_prompt_text", "original_completion_text", @@ -1468,25 +1977,23 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N use_cache=False, ) - # hidden states (shifted) student_hidden = student_outputs.last_hidden_state[:, :-1] teacher_hidden = teacher_outputs.last_hidden_state[:, :-1] - # Release full outputs to free memory del student_outputs, teacher_outputs - # labels mask and labels (shifted) + student_hidden = student_hidden.reshape(-1, student_hidden.shape[-1]) + teacher_hidden = teacher_hidden.reshape(-1, teacher_hidden.shape[-1]) + labels_mask = inputs["labels"] != -100 masked_input_ids = torch.where( labels_mask, inputs["input_ids"], torch.full_like(inputs["input_ids"], -100) ) - true_labels = masked_input_ids[:, 1:].contiguous() + true_labels = masked_input_ids[:, 1:].contiguous().reshape(-1) - # heads student_head = unwrapped_student.get_output_embeddings() teacher_head = unwrapped_teacher.get_output_embeddings() - # liger fused jsd loss loss = self.liger_jsd_loss( student_input=student_hidden, student_weight=student_head.weight, @@ -1497,10 +2004,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N teacher_bias=getattr(teacher_head, "bias", None), ) - # Release hidden states after loss computation del student_hidden, teacher_hidden, true_labels else: - # Original behavior for same tokenizer or when teacher_tokenizer is not provided outputs_student = model( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], @@ -1522,21 +2027,19 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N teacher_logits=shifted_teacher_logits, labels=shifted_labels, beta=self.beta, + temperature=self.temperature, ) if self.use_uld_loss: student_input_ids = inputs["input_ids"] - # Use the *teacher* labels created above, not the student's. teacher_labels_for_loss = teacher_labels if "teacher_labels" in locals() else inputs["labels"] teacher_input_ids_for_loss = teacher_input_ids if "teacher_input_ids" in locals() else inputs["input_ids"] - # Create properly masked student labels (fixing batch size > 1 issue) student_labels = inputs["labels"].clone() if hasattr(self.processing_class, "pad_token_id") and self.processing_class.pad_token_id is not None: student_labels[student_labels == self.processing_class.pad_token_id] = -100 - # Also mask pad tokens in teacher labels for consistency if ( hasattr(self, "teacher_tokenizer") and hasattr(self.teacher_tokenizer, "pad_token_id") @@ -1553,15 +2056,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N teacher_input_ids=teacher_input_ids_for_loss, ) - # If ULD hybrid mode produced per-step matched/unmatched components, accumulate them for logging. - # Use gradient_accumulation_steps to mirror Trainer's windowing behavior. if hasattr(self.uld_loss_fn, "last_matched_loss") and hasattr(self.uld_loss_fn, "last_unmatched_loss"): - try: - ga = max(1, int(self.args.gradient_accumulation_steps)) - except Exception: - ga = 1 + ga = max(1, int(self.args.gradient_accumulation_steps)) step_eq = 1.0 / ga - # read scalar values for logging matched_val = ( self.uld_loss_fn.last_matched_loss.item() if self.uld_loss_fn.last_matched_loss is not None @@ -1620,31 +2117,21 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token prompt_mask = inputs.get("prompt_attention_mask") pad_token_id = pad_token_id if pad_token_id is not None else self.processing_class.pad_token_id - if prompt_mask is not None: - prompt_lengths = prompt_mask.sum(dim=1).to(torch.long) + if self.use_transformers_paged: + # generate_batch() returns completion-only tokens, so the entire tensor is completion. + prompt_lengths = torch.zeros(batch_size, dtype=torch.long, device=device) else: - if pad_token_id is not None: - prompt_lengths = (inputs["prompts"] != pad_token_id).sum(dim=1).to(torch.long) - else: - prompt_lengths = torch.full( - (batch_size,), - inputs["prompts"].shape[1], - dtype=torch.long, - device=device, - ) + # model.generate() returns full sequences (prompt + completion), so completions start + # after the full padded prompt width. + prompt_lengths = torch.full( + (batch_size,), + inputs["prompts"].shape[1], + dtype=torch.long, + device=device, + ) new_input_ids = generated_tokens - new_attention_mask = torch.ones_like(new_input_ids) - if pad_token_id is not None: - new_attention_mask[new_input_ids == pad_token_id] = 0 - - new_labels = torch.full_like(new_input_ids, -100) - for idx in range(batch_size): - length = int(prompt_lengths[idx].item()) - new_labels[idx, length:] = new_input_ids[idx, length:] - - if pad_token_id is not None: - new_labels[new_input_ids == pad_token_id] = -100 + new_attention_mask, new_labels = self._build_sequence_batch(new_input_ids, prompt_lengths, pad_token_id) prompt_texts = [] completion_texts = [] @@ -1673,174 +2160,6 @@ def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token return new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts - @profiling_decorator - def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_id=None): - device = self.accelerator.device - - # Decode prompts for vLLM (without special tokens - vLLM expects clean text) - prompts_text_for_vllm = self.processing_class.batch_decode( - inputs["prompts"], - skip_special_tokens=True, - # clean_up_tokenization_spaces=False # Keep this commented unless specific issues arise - ) - # Remove padding token text if it appears, as vLLM expects clean prompts - if self.processing_class.pad_token: - prompts_text_for_vllm = [p.replace(self.processing_class.pad_token, "") for p in prompts_text_for_vllm] - - # Also decode prompts WITH special tokens for ULD loss computation - prompts_text_with_special = self.processing_class.batch_decode( - inputs["prompts"], - skip_special_tokens=False, - ) - - # system_prompt = "Please reason step by step, and put your final answer within \\boxed{}." - # target_system_prompt = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant." - # prompts_text = [p.replace(target_system_prompt, system_prompt) for p in prompts_text] - # Add system prompt to prompts - - max_completion_length = generation_config.max_new_tokens - temperature = generation_config.temperature - # vLLM uses top_k=-1 for no top_k, transformers uses 0 or None. - top_k = generation_config.top_k if generation_config.top_k and generation_config.top_k > 0 else -1 - # top_p, repetition_penalty, min_p are not directly in generation_config, get from trainer args - top_p = self.args.top_p if hasattr(self.args, "top_p") else 1.0 - repetition_penalty = self.args.repetition_penalty if hasattr(self.args, "repetition_penalty") else 1.0 - min_p = self.args.min_p if hasattr(self.args, "min_p") else 0.0 - - if self.vllm_mode == "server": - all_prompts_text = gather_object(prompts_text_for_vllm) - if self.accelerator.is_main_process: - completion_ids = self.vllm_client.generate( - prompts=all_prompts_text, - n=1, # In GKD, we generate 1 completion per prompt from student - repetition_penalty=repetition_penalty, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - max_tokens=max_completion_length, - structured_outputs_regex=self.vllm_structured_outputs_regex, - )["completion_ids"] - else: - completion_ids = [None] * len(all_prompts_text) - completion_ids = broadcast_object_list(completion_ids, from_process=0) - process_slice = slice( - self.accelerator.process_index * len(prompts_text_for_vllm), - (self.accelerator.process_index + 1) * len(prompts_text_for_vllm), - ) - completion_ids = completion_ids[process_slice] - elif self.vllm_mode == "colocate": - if self.vllm_structured_outputs_regex is not None: - structured_outputs = StructuredOutputsParams( - backend="outlines", regex=self.vllm_structured_outputs_regex - ) - else: - structured_outputs = None - sampling_params = SamplingParams( - n=1, - repetition_penalty=repetition_penalty, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - max_tokens=max_completion_length, - structured_outputs=structured_outputs, - ) - - if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1: - # Gather prompts from all ranks in the TP group and flatten. - # Each rank starts with its own prompts; after gathering, all ranks see the full group set. - orig_size = len(prompts_text_for_vllm) - gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts_text_for_vllm, group=self.vllm_tp_group) - all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - else: - all_prompts_text = prompts_text_for_vllm - - all_outputs = self.vllm_engine.generate(all_prompts_text, sampling_params=sampling_params, use_tqdm=False) - completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - - if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1: - # Slice completions for this rank within its TP group. - # Each rank generates all outputs — we keep only our share. - local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) - tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - completion_ids = completion_ids[tp_slice] - - if self.vllm_enable_sleep_mode: - self.vllm_engine.sleep(level=2) - else: - raise ValueError(f"Unknown vllm_mode: {self.vllm_mode}") - - # We need to combine prompt and completion for new_input_ids - # Tokenize prompts again to get prompt_ids on the correct device and format - # Use prompts_text_for_vllm (without special tokens) for tokenization since vLLM expects clean text - # Ensure add_special_tokens=False as vLLM typically handles prompts as raw text - # Calculate max_length for prompts, ensuring it's positive - prompt_max_length = max(1, self.args.max_length - max_completion_length) if self.args.max_length else None - prompt_tokenized = self.processing_class( - prompts_text_for_vllm, - return_tensors="pt", - padding="longest", - truncation=True if prompt_max_length else False, - max_length=prompt_max_length, - add_special_tokens=False, - ).to(device) - prompt_ids = prompt_tokenized.input_ids - - completion_ids_tensors = [torch.tensor(ids, device=device) for ids in completion_ids] - # Manually pad/truncate completions to max_completion_length length before using pad function - padded_completion_ids_list = [] - for completion_tensor in completion_ids_tensors: - if len(completion_tensor) > max_completion_length: - # Truncate if longer than max_completion_length - padded_completion_ids_list.append(completion_tensor[:max_completion_length]) - elif len(completion_tensor) < max_completion_length: - # Pad if shorter than max_completion_length - padding_needed = max_completion_length - len(completion_tensor) - padded_tensor = torch.cat( - [ - completion_tensor, - torch.full((padding_needed,), pad_token_id, device=device, dtype=completion_tensor.dtype), - ] - ) - padded_completion_ids_list.append(padded_tensor) - else: - # Already the right length - padded_completion_ids_list.append(completion_tensor) - - # Now all tensors are the same length, so we can stack them - padded_completion_ids = torch.stack(padded_completion_ids_list) - - # Ensure prompt_ids and padded_completion_ids are 2D - if prompt_ids.ndim == 1: - prompt_ids = prompt_ids.unsqueeze(0) - if padded_completion_ids.ndim == 1: - padded_completion_ids = padded_completion_ids.unsqueeze(0) - - new_input_ids = torch.cat([prompt_ids, padded_completion_ids], dim=1) - - new_attention_mask = torch.ones_like(new_input_ids, device=device) - new_labels = new_input_ids.clone() - - if pad_token_id is not None: - new_labels[new_labels == pad_token_id] = -100 - new_attention_mask[new_input_ids == pad_token_id] = 0 - - # Mask prompt tokens in labels - prompt_lengths = prompt_ids.shape[1] - new_labels[:, :prompt_lengths] = -100 - - # IMPORTANT: Preserve original text for cross-tokenizer ULD loss - # Use prompts_text_with_special (with special tokens) for ULD loss computation - # Extract completion texts from the generated completion IDs - completion_texts = [] - for comp_ids in completion_ids: - completion_text = self.processing_class.decode(comp_ids, skip_special_tokens=False) - completion_texts.append(completion_text) - - return new_input_ids, new_attention_mask, new_labels, prompts_text_with_special, completion_texts - def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None): """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with student vLLM.""" if visited is None: @@ -1943,6 +2262,27 @@ def _wake_vllm_if_needed(self): empty_cache() self.vllm_engine.wake_up(tags=["kv_cache"]) + def _get_liger_zero3_lm_head_gather_ctx(self, model: nn.Module): + if not self.use_liger_gkd_loss: + return nullcontext() + + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + if deepspeed_plugin is None or deepspeed_plugin.zero_stage != 3: + return nullcontext() + + import deepspeed + + unwrapped_student = self.accelerator.unwrap_model(model) + unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) + student_head = unwrapped_student.get_output_embeddings() + teacher_head = unwrapped_teacher.get_output_embeddings() + params = [student_head.weight, teacher_head.weight] + if student_head.bias is not None: + params.append(student_head.bias) + if teacher_head.bias is not None: + params.append(teacher_head.bias) + return deepspeed.zero.GatheredParameters(params, modifier_rank=None) + @profiling_decorator def training_step( self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None @@ -1954,47 +2294,25 @@ def training_step( `self.lmbda`, it generates new responses using the student model, which are then used for training instead of the offline original inputs. """ - on_policy = False - if random.random() <= self.lmbda: - on_policy = True - if self.use_vllm: - self._wake_vllm_if_needed() - result = self._generate_on_policy_outputs_vllm( - inputs, self.generation_config, self.processing_class.pad_token_id - ) - new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts = result - else: - with ( - unwrap_model_for_generation( - model, - self.accelerator, - generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762 - ) as unwrapped_model - ): - result = self.generate_on_policy_outputs( - unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id - ) - new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts = result + buffer_steps = self.args.gradient_accumulation_steps - inputs["input_ids"] = new_input_ids - inputs["attention_mask"] = new_attention_mask - inputs["labels"] = new_labels + # Keep lm_head gathered across forward+backward for Liger + ZeRO-3. + with self._get_liger_zero3_lm_head_gather_ctx(model): + loss = super().training_step(model, inputs, num_items_in_batch) - # CRITICAL: Preserve original text for cross-tokenizer ULD loss - # This ensures both off-policy (dataset) and on-policy (generated) samples - # can use proper text-based alignment for different tokenizers - inputs["original_prompt_text"] = prompt_texts - inputs["original_completion_text"] = completion_texts + slice_idx = (self._step - 1) % buffer_steps - # Log prompt and completion texts + on_policy = False + if self._buffered_on_policy is not None and slice_idx < len(self._buffered_on_policy): + on_policy = self._buffered_on_policy[slice_idx] + + if on_policy and self._buffered_text_logs is not None and self._buffered_text_logs[slice_idx] is not None: + prompt_texts, completion_texts = self._buffered_text_logs[slice_idx] self._textual_logs["prompt"].extend(gather_object(prompt_texts)) self._textual_logs["completion"].extend(gather_object(completion_texts)) - loss = super().training_step(model, inputs, num_items_in_batch) - loss_scalar = float(loss.detach()) - ga = max(1, int(self.args.gradient_accumulation_steps)) - step_equiv = 1.0 / ga + step_equiv = 1.0 / self.args.gradient_accumulation_steps if on_policy: self._on_policy_loss_total += loss_scalar @@ -2010,7 +2328,6 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: if mode == "train": device = self.accelerator.device if hasattr(self.accelerator, "device") else torch.device("cpu") - # include matched/unmatched accumulators for distributed reduction vec = torch.tensor( [ self._on_policy_loss_total, @@ -2026,7 +2343,6 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: device=device, ) - # Sum across processes so we mirror Trainer's distributed reduction if ( getattr(self.accelerator, "distributed_type", DistributedType.NO) != DistributedType.NO and dist.is_available() @@ -2045,20 +2361,16 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: unmatched_eq, ) = vec.tolist() - # Compute category averages over the *same window* as Trainer's logs - # (avoid div-by-zero if, e.g., no on-policy steps in the window) if on_eq > 0: logs["on_policy_loss"] = round(on_sum / on_eq, 4) if off_eq > 0: logs["off_policy_loss"] = round(off_sum / off_eq, 4) - # matched/unmatched averaged over same logging window (if present) if matched_eq > 0: logs["matched_loss"] = round(matched_sum / matched_eq, 4) if unmatched_eq > 0: logs["unmatched_loss"] = round(unmatched_sum / unmatched_eq, 4) - # Reset window accumulators after logging (just like Trainer resets its window) self._on_policy_loss_total = self._off_policy_loss_total = 0.0 self._on_policy_step_equiv = self._off_policy_step_equiv = 0.0 self._matched_sum = self._unmatched_sum = 0.0