Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions .github/workflows/e2e_ppo_trainer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,13 @@ jobs:
fetch-depth: 0
- name: Install the current repository
run: |
pip3 install -e .[test,geo,vllm]
pip3 install -e .[test,gpu,vllm,geo,trl]
# Geo3k
- name: Prepare Geo3k dataset
run: |
ray stop --force
python3 examples/data_preprocess/geo3k.py
- name: Running Geo3k VLM E2E training tests on 8 L20 GPUs with rmpad using function rm
- name: Running Geo3k VLM GRPO E2E training tests on 8 L20 GPUs with rmpad using function rm
run: |
ray stop --force
TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \
Expand All @@ -193,6 +193,16 @@ jobs:
SP_SIZE=2 \
bash tests/e2e/ppo_trainer/run_function_reward.sh

- name: Running Geo3k VLM PPO E2E training tests on 8 L20 GPUs with rmpad using function rm
run: |
ray stop --force
TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \
MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \
MODEL_ID=Qwen/Qwen2-VL-2B-Instruct \
ADV_ESTIMATOR=gae RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \
SP_SIZE=2 \
bash tests/e2e/ppo_trainer/run_function_reward.sh

e2e_ppo_trainer_sglang:
runs-on: [L20x8]
needs: pre_commit_for_ppo
Expand Down Expand Up @@ -360,4 +370,4 @@ jobs:
ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \
ENGINE=sglang GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \
ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \
bash tests/e2e/ppo_trainer/run_function_reward.sh
bash tests/e2e/ppo_trainer/run_function_reward.sh
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"torch-memory-saver>=0.0.5",
"torch==2.6.0",
]
TRL_REQUIRES = ["trl<=0.9.6"]

extras_require = {
"test": TEST_REQUIRES,
Expand All @@ -64,6 +65,7 @@
"math": MATH_REQUIRES,
"vllm": VLLM_REQUIRES,
"sglang": SGLANG_REQUIRES,
"trl": TRL_REQUIRES,
}


Expand Down
11 changes: 11 additions & 0 deletions verl/models/transformers/monkey_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.modeling_utils import PreTrainedModel

from verl.utils.import_utils import is_trl_available
from verl.utils.ulysses import (
gather_heads_scatter_seq,
gather_seq_scatter_heads,
Expand Down Expand Up @@ -156,6 +157,16 @@ def apply_monkey_patch(
assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, (
f"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}or vise versa. Upon ulysses_sp_size % num_key_value_heads == 0,kv heads are repeated to ensure correctness."
)

if is_trl_available():
from trl import AutoModelForCausalLMWithValueHead

def state_dict(self, *args, **kwargs):
return torch.nn.Module.state_dict(self, *args, **kwargs)

AutoModelForCausalLMWithValueHead.state_dict = state_dict
print("Monkey patch state_dict in AutoModelForCausalLMWithValueHead. ")

# TODO: VLM models only, unify monkey patch to LLM models.
if model.config.model_type == "qwen2_5_vl":
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Expand Down
9 changes: 9 additions & 0 deletions verl/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ def is_sglang_available():
return sglang_spec is not None


@cache
def is_trl_available():
    """Return whether the optional ``trl`` package is importable."""
    try:
        found = importlib.util.find_spec("trl") is not None
    except ModuleNotFoundError:
        found = False
    return found


def import_external_libs(external_libs=None):
if external_libs is None:
return
Expand Down
68 changes: 68 additions & 0 deletions verl/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)

from verl.models.registry import ModelRegistry
from verl.utils.import_utils import is_trl_available


class LambdaLayer(nn.Module):
Expand Down Expand Up @@ -446,3 +447,70 @@ def get_parallel_gptmodel_from_config(tfconfig, hf_config, pre_process=None, pos

parallel_model.output_layer = LinearForLastLayer(input_size=tfconfig.hidden_size, output_size=1, config=tfconfig)
return parallel_model


def patch_valuehead_model(model) -> None:
    """Patch a ``trl.AutoModelForCausalLMWithValueHead`` instance in place.

    The trl value-head wrapper does not forward several ``PreTrainedModel``
    hooks to the wrapped base model. This attaches delegating implementations
    for ``tie_weights`` / ``get_input_embeddings`` / ``get_output_embeddings``,
    disables generation, and marks the wrapped model's parameters as ignored
    on save (they are persisted through the base model itself).

    Args:
        model: a ``trl.AutoModelForCausalLMWithValueHead`` instance exposing a
            ``pretrained_model`` attribute.
    """
    from types import MethodType

    from transformers import PreTrainedModel

    # NOTE: the trl wrapper class itself is never referenced at runtime here,
    # so no `trl` import is required inside this helper.

    def tie_weights(self) -> None:
        # Delegate weight tying to the wrapped base model; the wrapper itself
        # has no embeddings to tie.
        if isinstance(self.pretrained_model, PreTrainedModel):
            self.pretrained_model.tie_weights()

    def get_input_embeddings(self) -> torch.nn.Module:
        # Implicitly returns None when the wrapped model is not a
        # PreTrainedModel, matching the original behavior.
        if isinstance(self.pretrained_model, PreTrainedModel):
            return self.pretrained_model.get_input_embeddings()

    def get_output_embeddings(self) -> torch.nn.Module:
        if isinstance(self.pretrained_model, PreTrainedModel):
            return self.pretrained_model.get_output_embeddings()

    def can_generate(self):
        # A critic/value model is never used for autoregressive generation.
        return False

    # The base model's parameters are saved via the base model itself, so
    # skip them when the wrapper is checkpointed.
    ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
    setattr(model, "_keys_to_ignore_on_save", ignore_modules)
    setattr(model, "tie_weights", MethodType(tie_weights, model))
    setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
    setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
    setattr(model, "can_generate", MethodType(can_generate, model))
    # Propagate sharding hints (used by FSDP/accelerate) from the base model.
    setattr(model, "_no_split_modules", getattr(model.pretrained_model, "_no_split_modules", []))


def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_code):
    """Load a critic/value model from ``local_path``.

    First tries ``AutoModelForTokenClassification`` (the native value-model
    format). If that load fails, falls back to wrapping a causal-LM (or
    vision2seq) checkpoint with ``trl.AutoModelForCausalLMWithValueHead`` and
    patching it via :func:`patch_valuehead_model`.

    Args:
        local_path: path or hub id of the pretrained checkpoint.
        torch_dtype: torch dtype used to load the weights.
        model_config: a ``transformers`` config instance for the checkpoint.
        trust_remote_code: forwarded to ``from_pretrained``.

    Returns:
        A model producing token-level values.

    Raises:
        RuntimeError: if the checkpoint is not a token-classification model
            and ``trl`` is not installed to provide the value-head fallback.
    """
    from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq

    try:
        return AutoModelForTokenClassification.from_pretrained(
            pretrained_model_name_or_path=local_path,
            torch_dtype=torch_dtype,
            config=model_config,
            attn_implementation="flash_attention_2",
            trust_remote_code=trust_remote_code,
        )
    # Catch Exception (not BaseException) so KeyboardInterrupt/SystemExit
    # still propagate instead of being treated as "not a value model".
    except Exception as e:
        if not is_trl_available():
            raise RuntimeError(
                f"model({local_path}) is not a value head model, please install trl to make it valid"
            ) from e

    from trl import AutoModelForCausalLMWithValueHead

    # Pick the auto class matching the config type (VLM vs. plain causal LM).
    if type(model_config) in AutoModelForVision2Seq._model_mapping.keys():
        module_class = AutoModelForVision2Seq
    else:
        module_class = AutoModelForCausalLM
    ori_model = module_class.from_pretrained(
        pretrained_model_name_or_path=local_path,
        torch_dtype=torch_dtype,
        config=model_config,
        attn_implementation="flash_attention_2",
        trust_remote_code=trust_remote_code,
    )
    model = AutoModelForCausalLMWithValueHead.from_pretrained(ori_model)
    patch_valuehead_model(model)
    return model
17 changes: 13 additions & 4 deletions verl/workers/critic/dp_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,13 @@ def _forward_micro_batch(self, micro_batch):
**multi_modal_inputs,
use_cache=False,
) # prevent model thinks we are generating
values_rmpad = output.logits
values_rmpad = values_rmpad.squeeze(0) # (total_nnz)

if hasattr(self.critic_module, "v_head"):
# For trl.AutoModelForCausalLMWithValueHead
values_rmpad = output[2].squeeze(0).unsqueeze(-1)
else:
values_rmpad = output.logits
values_rmpad = values_rmpad.squeeze(0) # (total_nnz)

# gather output if sp > 1
if self.ulysses_sequence_parallel_size > 1:
Expand All @@ -112,7 +117,11 @@ def _forward_micro_batch(self, micro_batch):
**multi_modal_inputs,
use_cache=False,
) # prevent model thinks we are generating
values = output.logits
if hasattr(self.critic_module, "v_head"):
# For trl.AutoModelForCausalLMWithValueHead
values = output[2]
else:
values = output.logits
values = values[:, -response_length - 1 : -1].squeeze(-1)
return values

Expand Down Expand Up @@ -206,7 +215,7 @@ def update_critic(self, data: DataProto):
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
else:
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu

self.critic_optimizer.zero_grad()

Expand Down
14 changes: 8 additions & 6 deletions verl/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ def _build_critic_model_optimizer(self, config):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision

from verl.utils.model import print_model_size
from verl.utils.model import load_valuehead_model, print_model_size
from verl.utils.torch_dtypes import PrecisionType

use_shm = config.model.get('use_shm', False)
Expand Down Expand Up @@ -881,11 +881,13 @@ def _build_critic_model_optimizer(self, config):
warnings.simplefilter("ignore")
critic_model_config.classifier_dropout = 0.0
critic_model_config.hidden_dropout = "0"
critic_module = AutoModelForTokenClassification.from_pretrained(
pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=critic_model_config,
trust_remote_code=config.model.get("trust_remote_code", False),
critic_model_config.summary_dropout_prob = 0.0

critic_module = load_valuehead_model(
local_path,
torch_dtype,
critic_model_config,
config.model.get("trust_remote_code", False),
)

use_remove_padding = config.model.get("use_remove_padding", False)
Expand Down
Loading