diff --git a/pyproject.toml b/pyproject.toml index 1b84ff50d7f..61e8c0f4726 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ kernels = [ "kernels" ] liger = [ - "liger-kernel>=0.6.2" + "liger-kernel>=0.6.4" ] peft = [ "peft>=0.8.0" @@ -104,7 +104,7 @@ dev = [ # kernels "kernels", # liger - "liger-kernel>=0.6.2", + "liger-kernel>=0.6.4", # peft "peft>=0.8.0", # quality diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index b3844a399c1..c747e0a1aa6 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1518,7 +1518,6 @@ def reward_func(completions, **kwargs): num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=8, # reduce the completion length to reduce memory usage use_liger_kernel=True, # enable Liger kernel - loss_type="bnpo", # default dapo is not supported yet report_to="none", ) trainer = GRPOTrainer( @@ -1839,7 +1838,6 @@ def test_training_with_liger_grpo_kernel(self, model_name): max_completion_length=self.max_length, report_to="none", logging_strategy="no", - loss_type="bnpo", # liger-kernel does not support "dapo" default; see https://github.com/linkedin/Liger-Kernel/issues/620 ) model = AutoModelForCausalLM.from_pretrained(model_name) @@ -1888,7 +1886,6 @@ def test_training_with_liger_grpo_kernel_and_peft(self, model_name): max_completion_length=self.max_length, report_to="none", logging_strategy="no", - loss_type="bnpo", # liger-kernel does not support "dapo" default; see https://github.com/linkedin/Liger-Kernel/issues/620 ) model = AutoModelForCausalLM.from_pretrained(model_name) diff --git a/trl/import_utils.py b/trl/import_utils.py index 4d8a9c84ce0..7a062464ada 100644 --- a/trl/import_utils.py +++ b/trl/import_utils.py @@ -23,7 +23,7 @@ from transformers.utils.import_utils import _is_package_available -LIGER_KERNEL_MIN_VERSION = "0.5.8" +LIGER_KERNEL_MIN_VERSION = "0.6.4" # Use same as transformers.utils.import_utils _deepspeed_available = _is_package_available("deepspeed") diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index d6c4e4501ea..c33781d608a 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -822,18 +822,19 @@ def __init__( ) # Loss function - if args.loss_type == "nll": - pass # use the default loss - elif args.loss_type == "dft": - if compute_loss_func is not None: - raise ValueError( - "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. " - "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so passing a " - "`compute_loss_func` is not allowed." - ) - compute_loss_func = dft_loss - else: - raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.") + if not args.use_liger_kernel: # liger supports dft loss by just passing use_token_scaling=True + if args.loss_type == "nll": + pass # use the default loss + elif args.loss_type == "dft": + if compute_loss_func is not None: + raise ValueError( + "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. " + "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so " + "passing a `compute_loss_func` is not allowed." + ) + compute_loss_func = dft_loss + else: + raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.") # Initialize the metrics self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} @@ -1113,6 +1114,11 @@ def compute_loss( # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing inputs["use_cache"] = False + # Request token accuracy from Liger kernel and set token scaling if using DFT loss + if self.args.use_liger_kernel: + inputs["return_token_accuracy"] = True + inputs["use_token_scaling"] = self.args.loss_type == "dft" + (loss, outputs) = super().compute_loss( model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch ) @@ -1151,8 +1157,11 @@ def compute_loss( self._total_train_tokens += num_tokens_in_batch self._metrics[mode]["num_tokens"] = [self._total_train_tokens] - # Compute token accuracy if we have labels and if the model is not using Liger (no logits) - if not self.args.use_liger_kernel: + if self.args.use_liger_kernel: + token_accuracy = self.accelerator.gather_for_metrics(outputs.token_accuracy).mean().item() + self._metrics[mode]["mean_token_accuracy"].append(token_accuracy) + else: + # Compute accuracy from logits using argmax (traditional method) with torch.no_grad(): if "shift_labels" in inputs: # When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because: @@ -1190,10 +1199,12 @@ def compute_loss( total_sum = total_tokens.sum() accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0 self._metrics[mode]["mean_token_accuracy"].append(accuracy) - if self.aux_loss_enabled: - aux_loss = outputs.aux_loss - aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item() - self._metrics[mode]["aux_loss"].append(aux_loss) + + # Log auxiliary loss if enabled (applies to both Liger and non-Liger) + if self.aux_loss_enabled: + aux_loss = outputs.aux_loss + aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item() + self._metrics[mode]["aux_loss"].append(aux_loss) return (loss, outputs) if return_outputs else loss