Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/notebooks/grpo_agent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,6 @@
"\n",
" # Memory optimization\n",
" gradient_checkpointing = True, # Enable activation recomputation to save memory\n",
" gradient_checkpointing_kwargs = {\"use_reentrant\": False}, # Use non-reentrant checkpointing\n",
"\n",
" # Hub integration\n",
" push_to_hub = True, # Set True to automatically push model to Hugging Face Hub\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,6 @@
"\n",
" use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n",
" gradient_checkpointing=True, # Save memory by recomputing activations during backpropagation\n",
" gradient_checkpointing_kwargs={\"use_reentrant\": False}, # Additional args to prevent warnings during gradient checkpointing\n",
")\n"
]
},
Expand Down
1 change: 0 additions & 1 deletion examples/notebooks/grpo_trl_lora_qlora.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,6 @@
" optim = \"paged_adamw_8bit\", # Optimizer\n",
" use_liger_kernel=True, # Enable Liger kernel optimizations for faster training\n",
" gradient_checkpointing=True, # Save memory by re-computing activations during backpropagation\n",
" gradient_checkpointing_kwargs={\"use_reentrant\": False}, # To prevent warning message\n",
"\n",
" # Parameters related to reporting and saving\n",
" output_dir=output_dir, # Where to save model checkpoints and logs\n",
Expand Down
1 change: 0 additions & 1 deletion examples/notebooks/openenv_sudoku_grpo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,6 @@
"\n",
" use_liger_kernel=False, # Enable Liger kernel optimizations for faster training\n",
" gradient_checkpointing=True, # Save memory by recomputing activations during backprop\n",
" gradient_checkpointing_kwargs={\"use_reentrant\": False}, # Prevent warnings with certain models\n",
" # chat_template_kwargs={\"enable_thinking\": False}, # Optional template args for model reasoning. We manage this in the rollout function\n",
"\n",
" temperature=0.8,\n",
Expand Down
1 change: 0 additions & 1 deletion examples/notebooks/openenv_wordle_grpo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,6 @@
"\n",
" # Memory optimization\n",
" gradient_checkpointing = True, # Enable activation recomputation to save memory\n",
" gradient_checkpointing_kwargs = {\"use_reentrant\": False}, # Use non-reentrant checkpointing\n",
"\n",
" # Hub integration\n",
" push_to_hub = True, # Set True to automatically push model to Hugging Face Hub\n",
Expand Down
1 change: 0 additions & 1 deletion examples/notebooks/sft_trl_lora_qlora.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,6 @@
" push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub\n",
" # The model will be saved under your Hub account in the repository named `output_dir`\n",
"\n",
" gradient_checkpointing_kwargs={\"use_reentrant\": False}, # To prevent warning message\n",
")"
]
},
Expand Down
9 changes: 7 additions & 2 deletions examples/scripts/openenv/browsergym.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@
from pathlib import Path

import numpy as np
from datasets import Dataset
from browsergym_env import BrowserGymAction, BrowserGymEnv
from datasets import Dataset
from PIL import Image
from transformers import AutoTokenizer

Expand All @@ -107,7 +107,12 @@ def parse_args() -> argparse.Namespace:
default="Qwen/Qwen3-VL-2B-Instruct",
help="Model identifier passed to GRPOTrainer for fine-tuning.",
)
parser.add_argument("--env-host", type=str, default="https://openenv-browsergym-env.hf.space", help="Host for the BrowserGym environment.")
parser.add_argument(
"--env-host",
type=str,
default="https://openenv-browsergym-env.hf.space",
help="Host for the BrowserGym environment.",
)
parser.add_argument("--env-port", type=int, default=8001, help="Port for the BrowserGym environment.")
parser.add_argument(
"--env-mode",
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/openenv/browsergym_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@
from datetime import datetime
from pathlib import Path

from datasets import Dataset
from browsergym_env import BrowserGymAction, BrowserGymEnv
from datasets import Dataset
from transformers import AutoTokenizer

from trl import GRPOConfig, GRPOTrainer
Expand Down
4 changes: 1 addition & 3 deletions examples/scripts/openenv/sudoku.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,7 @@ def parse_args() -> argparse.Namespace:
# Environment
parser.add_argument("--env-host", type=str, default="https://openenv-sudoku.hf.space")
parser.add_argument("--env-port", type=int, default=8001)
parser.add_argument(
"--env-mode", choices=["docker-local", "docker-image", "docker-hub", "space"], default="space"
)
parser.add_argument("--env-mode", choices=["docker-local", "docker-image", "docker-hub", "space"], default="space")
parser.add_argument("--env-image", type=str, default="textarena-env:latest")

# Prompts
Expand Down
1 change: 0 additions & 1 deletion examples/scripts/prm.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, PRMConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_into_dataclasses()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)

################
# Model & Tokenizer
Expand Down
1 change: 0 additions & 1 deletion examples/scripts/reward_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, RewardConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_into_dataclasses()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)

################
# Model & Tokenizer
Expand Down
1 change: 0 additions & 1 deletion examples/scripts/rloo.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def make_conversation(example):
output_dir="Qwen3-0.6B-RLOO",
model_init_kwargs={"dtype": torch.bfloat16},
learning_rate=1e-5,
gradient_checkpointing_kwargs=dict(use_reentrant=False),
log_completions=True,
num_completions_to_print=2,
max_prompt_length=2048,
Expand Down
1 change: 0 additions & 1 deletion examples/scripts/sft_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def main():
bf16=True,
use_liger_kernel=True,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
max_length=8192,
per_device_train_batch_size=1,
gradient_accumulation_steps=8,
Expand Down
1 change: 0 additions & 1 deletion examples/scripts/sft_video_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ class CustomScriptArguments(ScriptArguments):
script_args, training_args, model_args = parser.parse_args_and_config()

# Configure training args
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False

# Load dataset
Expand Down
1 change: 0 additions & 1 deletion examples/scripts/sft_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@
if __name__ == "__main__":
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.max_length = None

################
Expand Down
1 change: 0 additions & 1 deletion examples/scripts/sft_vlm_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ def prepare_dataset(dataset: DatasetDict, dataset_name: str) -> DatasetDict:
def main():
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.max_length = None

################
Expand Down
8 changes: 1 addition & 7 deletions tests/distributed/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,7 @@ def test_sft_dataset_streaming(self, config):
)
# fmt: on

@pytest.mark.parametrize(
"config",
[
pytest.param("ddp", marks=pytest.mark.xfail(reason="PEFT + multi-GPU is broken, see #4782")),
"fsdp2",
],
)
@pytest.mark.parametrize("config", ["ddp", "fsdp2"])
def test_sft_peft(self, config):
# fmt: off
run_command(
Expand Down
10 changes: 10 additions & 0 deletions trl/experimental/bco/bco_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from dataclasses import dataclass, field
from typing import Any

import transformers
from packaging.version import Version
from transformers import TrainingArguments


Expand Down Expand Up @@ -219,4 +221,12 @@ class BCOConfig(TrainingArguments):
def __post_init__(self):
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

# Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was
# never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream
# (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we
# default to the recommended non-reentrant behavior here, while preserving any user-provided value.
if self.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"):
self.gradient_checkpointing_kwargs = self.gradient_checkpointing_kwargs or {}
self.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)

super().__post_init__()
10 changes: 10 additions & 0 deletions trl/experimental/cpo/cpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from dataclasses import dataclass, field
from typing import Any

import transformers
from packaging.version import Version
from transformers import TrainingArguments


Expand Down Expand Up @@ -220,6 +222,14 @@ class CPOConfig(TrainingArguments):
def __post_init__(self):
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

# Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was
# never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream
# (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we
# default to the recommended non-reentrant behavior here, while preserving any user-provided value.
if self.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"):
self.gradient_checkpointing_kwargs = self.gradient_checkpointing_kwargs or {}
self.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)

# Syntactic sugar for AlphaPO: set loss_type to "simpo" and cpo_alpha to 0.0
if self.loss_type == "alphapo":
self.loss_type = "simpo"
Expand Down
2 changes: 1 addition & 1 deletion trl/experimental/gfpo/gfpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _generate_and_score_completions(self, inputs):
[token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
)

# When gradient checkpointing is enabled with use_reentrant=True (default), calling the model inside a
# When gradient checkpointing is enabled with use_reentrant=True (non-default), calling the model inside a
# torch.no_grad() block triggers a harmless PyTorch warning ("None of the inputs have requires_grad=True").
# Temporarily disable checkpointing to avoid this warning during inference.
with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _generate_and_score_completions(
[token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
)

# When gradient checkpointing is enabled with use_reentrant=True (default), calling the model inside a
# When gradient checkpointing is enabled with use_reentrant=True (non-default), calling the model inside a
# torch.no_grad() block triggers a harmless PyTorch warning ("None of the inputs have requires_grad=True").
# Temporarily disable checkpointing to avoid this warning during inference.
with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):
Expand Down
10 changes: 10 additions & 0 deletions trl/experimental/kto/kto_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from dataclasses import dataclass, field
from typing import Any

import transformers
from packaging.version import Version
from transformers import TrainingArguments


Expand Down Expand Up @@ -219,6 +221,14 @@ class KTOConfig(TrainingArguments):
def __post_init__(self):
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

# Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was
# never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream
# (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we
# default to the recommended non-reentrant behavior here, while preserving any user-provided value.
if self.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"):
self.gradient_checkpointing_kwargs = self.gradient_checkpointing_kwargs or {}
self.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)

if self.use_liger_loss is not None:
warnings.warn(
"The `use_liger_loss` argument is deprecated and will be removed in version 0.28.0. Please use "
Expand Down
10 changes: 10 additions & 0 deletions trl/experimental/minillm/minillm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from dataclasses import dataclass, field
from typing import Any

import transformers
from packaging.version import Version
from transformers import TrainingArguments

from ...trainer.grpo_config import GRPOConfig
Expand Down Expand Up @@ -88,6 +90,14 @@ def __post_init__(self):
# 1. num_generations can be < 2 in MiniLLMConfig. Scale_rewards must be set to "none" to avoid nan.
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

# Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was
# never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream
# (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we
# default to the recommended non-reentrant behavior here, while preserving any user-provided value.
if self.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"):
self.gradient_checkpointing_kwargs = self.gradient_checkpointing_kwargs or {}
self.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)

TrainingArguments.__post_init__(self)

self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards)
Expand Down
10 changes: 10 additions & 0 deletions trl/experimental/online_dpo/online_dpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from dataclasses import dataclass, field
from typing import Any

import transformers
from packaging.version import Version
from transformers import TrainingArguments


Expand Down Expand Up @@ -392,6 +394,14 @@ class may differ from those in [`~transformers.TrainingArguments`].
def __post_init__(self):
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

# Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was
# never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream
# (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we
# default to the recommended non-reentrant behavior here, while preserving any user-provided value.
if self.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"):
self.gradient_checkpointing_kwargs = self.gradient_checkpointing_kwargs or {}
self.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)

super().__post_init__()

if hasattr(self.beta, "__len__") and len(self.beta) == 1:
Expand Down
9 changes: 1 addition & 8 deletions trl/experimental/online_dpo/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,14 +667,7 @@ def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: OnlineDPO
else:
model.gradient_checkpointing_enable()

gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
)

if use_reentrant:
model.enable_input_require_grads()

model.enable_input_require_grads()
return model

def _generate_vllm(self, prompts, images=None):
Expand Down
10 changes: 10 additions & 0 deletions trl/experimental/orpo/orpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from dataclasses import dataclass, field
from typing import Any

import transformers
from packaging.version import Version
from transformers import TrainingArguments


Expand Down Expand Up @@ -176,4 +178,12 @@ class ORPOConfig(TrainingArguments):
def __post_init__(self):
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

# Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was
# never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream
# (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we
# default to the recommended non-reentrant behavior here, while preserving any user-provided value.
if self.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"):
self.gradient_checkpointing_kwargs = self.gradient_checkpointing_kwargs or {}
self.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)

super().__post_init__()
10 changes: 10 additions & 0 deletions trl/experimental/ppo/ppo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from dataclasses import dataclass, field
from typing import Literal

import transformers
from packaging.version import Version
from transformers import TrainingArguments


Expand Down Expand Up @@ -314,4 +316,12 @@ class PPOConfig(TrainingArguments):
def __post_init__(self):
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

# Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was
# never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream
# (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we
# default to the recommended non-reentrant behavior here, while preserving any user-provided value.
if self.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"):
self.gradient_checkpointing_kwargs = self.gradient_checkpointing_kwargs or {}
self.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)

super().__post_init__()
10 changes: 10 additions & 0 deletions trl/experimental/prm/prm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from dataclasses import dataclass, field

import transformers
from packaging.version import Version
from transformers import TrainingArguments


Expand Down Expand Up @@ -109,4 +111,12 @@ class PRMConfig(TrainingArguments):
def __post_init__(self):
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

# Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was
# never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream
# (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we
# default to the recommended non-reentrant behavior here, while preserving any user-provided value.
if self.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"):
self.gradient_checkpointing_kwargs = self.gradient_checkpointing_kwargs or {}
self.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)

super().__post_init__()
2 changes: 1 addition & 1 deletion trl/extras/vllm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def init_communicator(self, device: torch.device | str | int = 0):
host_name=self.host, port=self.group_port, world_size=world_size, is_master=(self.rank == 0)
)
prefixed_store = c10d.PrefixStore("client2server", store)
xccl_options = c10d.ProcessGroupXCCL.Options()
xccl_options = c10d.ProcessGroupXCCL.Options()
pg = c10d.ProcessGroupXCCL(
store=prefixed_store,
rank=self.rank,
Expand Down
Loading
Loading