-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Align GRPO and RLOO initialization #4685
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
4d7345a
align rloo and grpo
qgallouedec 21549b2
style
qgallouedec c336c9a
Merge branch 'main' into align-rloo
qgallouedec 8be5fc0
Move `prepare_model_for_kbit_training`, `enable_gradient_checkpointin…
qgallouedec c0f649a
Merge branch 'main' into align-rloo
qgallouedec 5d42fcd
Merge branch 'main' into align-rloo
qgallouedec File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,15 +27,13 @@ | |
| import pandas as pd | ||
| import torch | ||
| import torch.utils.data | ||
| import transformers | ||
| from accelerate import logging | ||
| from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed | ||
| from datasets import Dataset, IterableDataset | ||
| from torch import nn | ||
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | ||
| from torch.utils.data import DataLoader, Sampler | ||
| from transformers import ( | ||
| AutoConfig, | ||
| AutoModelForSequenceClassification, | ||
| AutoProcessor, | ||
| AutoTokenizer, | ||
|
|
@@ -60,13 +58,14 @@ | |
| from ..extras.profiling import profiling_context, profiling_decorator | ||
| from ..extras.vllm_client import VLLMClient | ||
| from ..import_utils import is_vllm_available | ||
| from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation | ||
| from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation | ||
| from ..models.utils import disable_gradient_checkpointing | ||
| from .base_trainer import BaseTrainer | ||
| from .callbacks import SyncRefModelCallback | ||
| from .rloo_config import RLOOConfig | ||
| from .utils import ( | ||
| RepeatSampler, | ||
| create_model_from_path, | ||
| disable_dropout_in_model, | ||
| ensure_master_addr_port, | ||
| entropy_from_logits, | ||
|
|
@@ -86,7 +85,7 @@ | |
|
|
||
|
|
||
| if is_peft_available(): | ||
| from peft import PeftConfig, PeftModel | ||
| from peft import PeftConfig, PeftModel, get_peft_model | ||
|
|
||
| if is_vllm_available(): | ||
| from vllm import LLM, SamplingParams | ||
|
|
@@ -120,21 +119,15 @@ class RLOOTrainer(BaseTrainer): | |
| ```python | ||
| from datasets import load_dataset | ||
| from trl import RLOOTrainer | ||
| from trl.rewards import accuracy_reward | ||
|
|
||
| dataset = load_dataset("trl-lib/tldr", split="train") | ||
|
|
||
|
|
||
| def reward_func(completions, **kwargs): | ||
| # Dummy reward function that rewards completions with more unique letters. | ||
| return [float(len(set(completion))) for completion in completions] | ||
|
|
||
| dataset = load_dataset("trl-lib/DeepMath-103K", split="train") | ||
|
|
||
| trainer = RLOOTrainer( | ||
| model="Qwen/Qwen2-0.5B-Instruct", | ||
| reward_funcs=reward_func, | ||
| reward_funcs=accuracy_reward, | ||
| train_dataset=dataset, | ||
| ) | ||
|
|
||
| trainer.train() | ||
| ``` | ||
|
|
||
|
|
@@ -231,8 +224,8 @@ def reward_func(completions, **kwargs): | |
|
|
||
| def __init__( | ||
| self, | ||
| model: str | PreTrainedModel = None, | ||
| reward_funcs: RewardFunc | list[RewardFunc] = None, | ||
| model: str | PreTrainedModel, | ||
| reward_funcs: RewardFunc | list[RewardFunc], | ||
| args: RLOOConfig | None = None, | ||
| train_dataset: Dataset | IterableDataset | None = None, | ||
| eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, | ||
|
|
@@ -248,28 +241,14 @@ def __init__( | |
| model_name = model_name.split("/")[-1] | ||
| args = RLOOConfig(f"{model_name}-RLOO") | ||
|
|
||
| # Models | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. all other changes come from #4577 |
||
| # Trained model | ||
| model_init_kwargs = args.model_init_kwargs or {} | ||
| # Model | ||
| if isinstance(model, str): | ||
| model_id = model | ||
| dtype = model_init_kwargs.get("dtype", "auto") | ||
| if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: | ||
| pass # dtype is already a torch.dtype or "auto" or None | ||
| elif isinstance(dtype, str): # it's a str, but not "auto" | ||
| dtype = getattr(torch, dtype) | ||
| model_init_kwargs["dtype"] = dtype | ||
| else: | ||
| raise ValueError( | ||
| "Invalid `dtype` passed to `RLOOConfig`. Expected either 'auto' or a string representing " | ||
| f"a `torch.dtype` (e.g., 'float32'), but got {dtype}." | ||
| ) | ||
| model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto") | ||
| config = AutoConfig.from_pretrained(model_id) | ||
| architecture = getattr(transformers, config.architectures[0]) | ||
| model = architecture.from_pretrained(model_id, **model_init_kwargs) | ||
| model_init_kwargs = args.model_init_kwargs or {} | ||
| # Special case for DeepSpeed: requires device_map=None ("auto" fails) | ||
| if args.distributed_state.distributed_type == "DEEPSPEED": | ||
| model_init_kwargs["device_map"] = None | ||
| model = create_model_from_path(model, **model_init_kwargs) | ||
| else: | ||
| model_id = get_config_model_id(model.config) | ||
| if args.model_init_kwargs is not None: | ||
| logger.warning( | ||
| "You passed `model_init_kwargs` to the `RLOOConfig`, but your model is already instantiated. " | ||
|
|
@@ -284,12 +263,11 @@ def __init__( | |
| else inspect.signature(model.get_base_model().forward).parameters.keys() | ||
| ) | ||
|
|
||
| if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): | ||
| model = prepare_peft_model(model, peft_config, args) | ||
|
|
||
| # Processing class | ||
| if processing_class is None: | ||
| processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config), truncation_side="left") | ||
| processing_class = AutoProcessor.from_pretrained( | ||
| get_config_model_id(model.config), truncation_side="left", padding_side="left" | ||
| ) | ||
|
|
||
| # Handle pad token for processors or tokenizers | ||
| if isinstance(processing_class, ProcessorMixin): | ||
|
|
@@ -306,12 +284,40 @@ def __init__( | |
| self.pad_token_id = tokenizer.pad_token_id | ||
| self.eos_token_id = tokenizer.eos_token_id | ||
|
|
||
| if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None: | ||
| # If the model is already a PeftModel, we need to merge and unload it. | ||
| # Further information: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft | ||
| model = model.merge_and_unload() | ||
|
|
||
| # Create PEFT model | ||
| if peft_config is not None: | ||
| model = get_peft_model(model, peft_config) | ||
|
|
||
| # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally | ||
| # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 | ||
| if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing: | ||
| model.enable_input_require_grads() | ||
|
|
||
| # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the | ||
| # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by | ||
| # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for | ||
| # quantized models. See: https://github.com/huggingface/peft/issues/2889 | ||
| # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do | ||
| if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False): | ||
| for param in model.parameters(): | ||
| if param.requires_grad: | ||
| param.data = param.data.to(torch.bfloat16) | ||
|
|
||
| # Reward functions | ||
| if not isinstance(reward_funcs, list): | ||
| reward_funcs = [reward_funcs] | ||
| self.reward_func_names = [] | ||
| for i, reward_func in enumerate(reward_funcs): | ||
| if isinstance(reward_func, str): | ||
| model_init_kwargs = args.model_init_kwargs or {} | ||
| # Special case for DeepSpeed: requires device_map=None ("auto" fails) | ||
| if args.distributed_state.distributed_type == "DEEPSPEED": | ||
| model_init_kwargs["device_map"] = None | ||
| reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( | ||
| reward_func, num_labels=1, **model_init_kwargs | ||
| ) | ||
|
|
@@ -362,7 +368,7 @@ def __init__( | |
| self.max_prompt_length = args.max_prompt_length | ||
| self.max_completion_length = args.max_completion_length | ||
| self.num_generations = args.num_generations | ||
| self.num_generations_eval = args.num_generations_eval or args.num_generations | ||
| self.num_generations_eval = args.num_generations_eval or self.num_generations | ||
| self.chat_template_kwargs = args.chat_template_kwargs or {} | ||
| self.temperature = args.temperature | ||
| self.top_p = args.top_p | ||
|
|
@@ -433,9 +439,11 @@ def __init__( | |
| self.ref_model = None | ||
| else: | ||
| # For deepspeed, fsdp or non-distributed models, create a reference model from scratch | ||
| config = AutoConfig.from_pretrained(model_id) | ||
| architecture = getattr(transformers, config.architectures[0]) | ||
| self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs) | ||
| model_init_kwargs = args.model_init_kwargs or {} | ||
| # Special case for DeepSpeed: requires device_map=None ("auto" fails) | ||
| if self.args.distributed_state.distributed_type == "DEEPSPEED": | ||
| model_init_kwargs["device_map"] = None | ||
| self.ref_model = create_model_from_path(get_config_model_id(self.model.config), **model_init_kwargs) | ||
|
|
||
| # Disable dropout in the models | ||
| if args.disable_dropout: | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see #4524