Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 48 additions & 34 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@
import pandas as pd
import torch
import torch.utils.data
import transformers
from accelerate.logging import get_logger
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from datasets import Dataset, IterableDataset
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, Sampler
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoProcessor,
AutoTokenizer,
Expand All @@ -60,13 +58,14 @@
from ..extras.profiling import profiling_context, profiling_decorator
from ..extras.vllm_client import VLLMClient
from ..import_utils import is_vllm_available
from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation
from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
from ..models.utils import disable_gradient_checkpointing
from .base_trainer import BaseTrainer
from .callbacks import SyncRefModelCallback
from .rloo_config import RLOOConfig
from .utils import (
RepeatSampler,
create_model_from_path,
disable_dropout_in_model,
ensure_master_addr_port,
entropy_from_logits,
Expand All @@ -86,7 +85,7 @@


if is_peft_available():
from peft import PeftConfig, PeftModel
from peft import PeftConfig, PeftModel, get_peft_model

if is_vllm_available():
from vllm import LLM, SamplingParams
Expand Down Expand Up @@ -225,8 +224,8 @@ class RLOOTrainer(BaseTrainer):

def __init__(
self,
model: str | PreTrainedModel = None,
reward_funcs: RewardFunc | list[RewardFunc] = None,
model: str | PreTrainedModel,
reward_funcs: RewardFunc | list[RewardFunc],
args: RLOOConfig | None = None,
train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
Expand All @@ -242,28 +241,14 @@ def __init__(
model_name = model_name.split("/")[-1]
args = RLOOConfig(f"{model_name}-RLOO")

# Models

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All other changes come from #4577.

# Trained model
model_init_kwargs = args.model_init_kwargs or {}
# Model
if isinstance(model, str):
model_id = model
dtype = model_init_kwargs.get("dtype", "auto")
if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
pass # dtype is already a torch.dtype or "auto" or None
elif isinstance(dtype, str): # it's a str, but not "auto"
dtype = getattr(torch, dtype)
model_init_kwargs["dtype"] = dtype
else:
raise ValueError(
"Invalid `dtype` passed to `RLOOConfig`. Expected either 'auto' or a string representing "
f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
)
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
config = AutoConfig.from_pretrained(model_id)
architecture = getattr(transformers, config.architectures[0])
model = architecture.from_pretrained(model_id, **model_init_kwargs)
model_init_kwargs = args.model_init_kwargs or {}
# Special case for DeepSpeed: requires device_map=None ("auto" fails)
if args.distributed_state.distributed_type == "DEEPSPEED":
model_init_kwargs["device_map"] = None
model = create_model_from_path(model, **model_init_kwargs)
else:
model_id = get_config_model_id(model.config)
if args.model_init_kwargs is not None:
logger.warning(
"You passed `model_init_kwargs` to the `RLOOConfig`, but your model is already instantiated. "
Expand All @@ -278,12 +263,11 @@ def __init__(
else inspect.signature(model.get_base_model().forward).parameters.keys()
)

if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
model = prepare_peft_model(model, peft_config, args)

# Processing class
if processing_class is None:
processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config), truncation_side="left")
processing_class = AutoProcessor.from_pretrained(
get_config_model_id(model.config), truncation_side="left", padding_side="left"
)

# Handle pad token for processors or tokenizers
if isinstance(processing_class, ProcessorMixin):
Expand All @@ -300,12 +284,40 @@ def __init__(
self.pad_token_id = tokenizer.pad_token_id
self.eos_token_id = tokenizer.eos_token_id

if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None:
# If the model is already a PeftModel, we need to merge and unload it.
# Further information: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
model = model.merge_and_unload()

# Create PEFT model
if peft_config is not None:
model = get_peft_model(model, peft_config)

# When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally
# handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489
if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing:
model.enable_input_require_grads()

# When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the
# original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by
# passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for
# quantized models. See: https://github.com/huggingface/peft/issues/2889
# Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do
if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False):
for param in model.parameters():
if param.requires_grad:
param.data = param.data.to(torch.bfloat16)

# Reward functions
if not isinstance(reward_funcs, list):
reward_funcs = [reward_funcs]
self.reward_func_names = []
for i, reward_func in enumerate(reward_funcs):
if isinstance(reward_func, str):
model_init_kwargs = args.model_init_kwargs or {}
# Special case for DeepSpeed: requires device_map=None ("auto" fails)
if args.distributed_state.distributed_type == "DEEPSPEED":
model_init_kwargs["device_map"] = None
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
reward_func, num_labels=1, **model_init_kwargs
)
Expand Down Expand Up @@ -356,7 +368,7 @@ def __init__(
self.max_prompt_length = args.max_prompt_length
self.max_completion_length = args.max_completion_length
self.num_generations = args.num_generations
self.num_generations_eval = args.num_generations_eval or args.num_generations
self.num_generations_eval = args.num_generations_eval or self.num_generations
self.chat_template_kwargs = args.chat_template_kwargs or {}
self.temperature = args.temperature
self.top_p = args.top_p
Expand Down Expand Up @@ -427,9 +439,11 @@ def __init__(
self.ref_model = None
else:
# For deepspeed, fsdp or non-distributed models, create a reference model from scratch
config = AutoConfig.from_pretrained(model_id)
architecture = getattr(transformers, config.architectures[0])
self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
model_init_kwargs = args.model_init_kwargs or {}
# Special case for DeepSpeed: requires device_map=None ("auto" fails)
if self.args.distributed_state.distributed_type == "DEEPSPEED":
model_init_kwargs["device_map"] = None
self.ref_model = create_model_from_path(get_config_model_id(self.model.config), **model_init_kwargs)

# Disable dropout in the models
if args.disable_dropout:
Expand Down
Loading