Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions examples/grpo_trainer/run_mistral3_1_24b.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
set -x

# Rollout engine backend; override by passing it as the first positional argument.
ENGINE=${1:-vllm}
# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
# export VLLM_ATTENTION_BACKEND=XFORMERS
echo "$HOME"

train_files=$HOME/data/geo3k/train.parquet
test_files=$HOME/data/geo3k/test.parquet

# Reduce CUDA allocator fragmentation when training a 24B model.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# NOTE(review): max_prompt_length=8196 looks like a typo for 8192 — confirm before changing.
# Fix: "$@" (quoted) so extra Hydra overrides containing spaces are forwarded intact;
# "$ENGINE" quoted for the same reason.
python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files="${train_files}" \
    data.val_files="${test_files}" \
    data.train_batch_size=256 \
    data.max_prompt_length=8196 \
    data.max_response_length=2048 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.image_key=images \
    actor_rollout_ref.model.path=mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.01 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.actor.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap="['MistralDecoderLayer']" \
    actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=5 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.ref.strategy=fsdp2 \
    actor_rollout_ref.actor.strategy=fsdp2 \
    actor_rollout_ref.ref.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap="['MistralDecoderLayer']" \
    actor_rollout_ref.ref.fsdp_config.model_dtype=bfloat16 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.rollout.name="$ENGINE" \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.free_cache_engine=False \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=5 \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_grpo_example_geo3k' \
    trainer.experiment_name='mistral3_1_24b_ai' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 "$@"
4 changes: 3 additions & 1 deletion recipe/entropy/main_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ def run(self, config):
val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1, **reward_kwargs)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

from verl.utils.dataset.rl_dataset import collate_fn
from verl.utils.dataset.collate_utils import get_collate_fn_manager_cls

collate_fn = get_collate_fn_manager_cls("default")

train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
Expand Down
3 changes: 2 additions & 1 deletion recipe/prime/prime_ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
from verl.trainer.ppo.metric_utils import _compute_response_info
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.utils.dataset.collate_utils import collate_fn
from verl.utils.dataset.rl_dataset import RLHFDataset
from verl.utils.metric import reduce_metrics
from verl.utils.profiler.performance import simple_timer

Expand Down
4 changes: 2 additions & 2 deletions recipe/spin/spin_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,9 +571,9 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
if train_sampler is None:
train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
if collate_fn is None:
from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn
from verl.utils.dataset.dataset_utils import get_collate_fn_manager_cls

collate_fn = default_collate_fn
collate_fn = get_collate_fn_manager_cls("default")

self.train_dataloader = StatefulDataLoader(
dataset=self.train_dataset,
Expand Down
6 changes: 4 additions & 2 deletions tests/utils/dataset/test_rl_dataset_on_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def get_gsm8k_data():

def test_rl_dataset():
from verl.utils import hf_tokenizer
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.utils.dataset.collate_utils import collate_fn
from verl.utils.dataset.rl_dataset import RLHFDataset

tokenizer = hf_tokenizer("deepseek-ai/deepseek-coder-1.3b-instruct")
local_path = get_gsm8k_data()
Expand Down Expand Up @@ -68,7 +69,8 @@ def test_rl_dataset():

def test_image_rl_data():
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.utils.dataset.collate_utils import collate_fn
from verl.utils.dataset.rl_dataset import RLHFDataset

tokenizer = hf_tokenizer("Qwen/Qwen2-VL-2B-Instruct")
processor = hf_processor("Qwen/Qwen2-VL-2B-Instruct")
Expand Down
4 changes: 3 additions & 1 deletion tests/workers/rollout/perf/vllm_async_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
from verl.protocol import DataProto
from verl.utils import hf_tokenizer
from verl.utils.dataset import RLHFDataset
from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn
from verl.utils.dataset.dataset_utils import get_collate_fn_manager_cls

default_collate_fn = get_collate_fn_manager_cls("default")


def init_config(n_gpus_per_node) -> DictConfig:
Expand Down
21 changes: 21 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,12 @@ actor_rollout_ref:
# Minimum number of parameters to trigger wrapping a layer with FSDP
min_num_params: 0

# list of transformer layer classes to wrap with FSDP
transformer_layer_cls_to_wrap: []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This configuration seems not take effect in code?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, this is supposed to be overridden by the parameter passed in through the script.
image
look at this line:
actor_rollout_ref.actor.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap="['MistralDecoderLayer']" \

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's right!


# Model data type for FSDP, default is float32.
model_dtype: float32

# Whether to offload model parameters to CPU (trades speed for memory)
param_offload: false

Expand Down Expand Up @@ -364,6 +370,12 @@ actor_rollout_ref:
# minimum number of params in a wrapped module
min_num_params: 0

# list of transformer layer classes to wrap with FSDP
transformer_layer_cls_to_wrap: []

# Model data type for FSDP, default is float32.
model_dtype: float32

# whether to enable torch.compile
use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile}

Expand Down Expand Up @@ -691,6 +703,12 @@ critic:
# Minimum number of parameters to trigger wrapping
min_num_params: 0

# list of transformer layer classes to wrap with FSDP
transformer_layer_cls_to_wrap: []

# Model data type for FSDP, default is float32.
model_dtype: float32

# Number of GPUs in each FSDP shard group; -1 means auto
fsdp_size: -1

Expand Down Expand Up @@ -822,6 +840,9 @@ reward_model:

# Minimum number of parameters to trigger wrapping
min_num_params: 0

# list of transformer layer classes to wrap with FSDP
transformer_layer_cls_to_wrap: []

# Whether to offload model parameters to CPU
param_offload: False
Expand Down
7 changes: 6 additions & 1 deletion verl/trainer/main_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,12 @@ def run(self, config):
)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

from verl.utils.dataset.rl_dataset import collate_fn
from verl.utils.dataset.collate_utils import get_collate_fn_manager_cls

if processor:
collate_fn = get_collate_fn_manager_cls(processor.__class__.__name__)
else:
collate_fn = get_collate_fn_manager_cls("default")

# Create training and validation datasets.
train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
Expand Down
4 changes: 2 additions & 2 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,9 +555,9 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
if train_sampler is None:
train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
if collate_fn is None:
from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn
from verl.utils.dataset.collate_utils import get_collate_fn_manager_cls

collate_fn = default_collate_fn
collate_fn = get_collate_fn_manager_cls("default")

num_workers = self.config.data["dataloader_num_workers"]
if isinstance(train_sampler, AbstractCurriculumSampler):
Expand Down
155 changes: 155 additions & 0 deletions verl/utils/dataset/collate_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
Copy link
Collaborator

@eric-haibin-lin eric-haibin-lin Jul 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution. Currently, when adding a new model, the adaptors are scattered around and many files are touched. I am thinking of reorganizing the files that need to be changed so that we can have one folder per model. For instance:

verl/models/transformers/llama
verl/models/transformers/qwen2_5_vl
verl/models/transformers/qwen2
verl/models/transformers/[model_name] # model name should be the same as the one in https://github.com/huggingface/transformers/tree/main/src/transformers/models 

And in each model folder, the structure is like below (take mistral3 as the example):

mistral3_collate_utils.py
mistral3_flops_counter.py
mistral3_any_other_change_required.py

what do you think? cc @hiyouga @Fazziekey

BTW I am also not sure if we want to have verl/models/transformers and verl/models/mcore as two folders both containing model specific code. Maybe we should let model related code to be at the level of

verl/models/transformers # common registry utils. No model specific code
verl/models/mcore # common registry utils specfic for mcore. No model specific code

verl/models/llama  
verl/models/qwen2_5_vl 
verl/models/qwen2 
verl/models/[model_name]

@ISEEKYAN what do you think?

Similarly, tests can be standardized:

tests/models/test_llama.py
tests/models/test_[model].py 

With a better code structure it will be easier to write a new model onboarding documentation and let the community add new SOTA models.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a unified model unit test like tests/models/test_[model].py is good after the refactor of unified training engine APIs.

For megatron, LLMs of different archs share the same GPTModel API. The efforts of supporting new models will be config mapping and weights mapping and maybe some few patches. VLM would need more definition files since the LLaVaModel's development is slow.
Mbridge/megatron-hub is the official solution for supporting new megatron models; we recommend obsoleting verl/models/mcore once the code is fully transferred to mbridge, and using verl/models/[model_name] for transformers.
And if we need to define a megatron model in verl anyway, the solution is to inherit LLMBridge for that model, like how slime did. And that will be inside the directory verl/models/[model_name]

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import List

import numpy as np
import torch

# Maps a registry key (a processor class name such as "PixtralProcessor", or
# "default") to its registered collate function. Populated via @register_collate_fn.
COLLATE_FN_MANAGER_REGISTRY = {}


def _pad_for_batching(
pixel_values: List[torch.Tensor],
image_sizes: List[List[int]],
):
"""
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
Args:
pixel_values (`List[torch.Tensor]`):
An array of pixel values of each images of shape (`batch_size`, `channels`, `height`, `width`)
image_sizes (`List[List[int]]`):
A list of sizes for each image in `pixel_values` in (height, width) format.
Returns:
List[`torch.Tensor`]: The padded images.
"""
max_shape = (max([size[0] for size in image_sizes]), max([size[1] for size in image_sizes]))
pixel_values = [
torch.nn.functional.pad(image, pad=(0, max_shape[1] - size[1], 0, max_shape[0] - size[0])).unsqueeze(0)
for image, size in zip(pixel_values, image_sizes)
]
return pixel_values


def register_collate_fn(name):
    """Decorator to register a collate function under a given name.

    (Fixed docstring: this registers collate functions, not reward managers —
    the previous text was copy-pasted from the reward-manager registry.)

    Args:
        name: `(str)`
            The registry key for the collate function, e.g. a processor class
            name such as "PixtralProcessor", or "default".

    Returns:
        The decorator, which returns the decorated callable unchanged.

    Raises:
        ValueError: If `name` is already registered to a different callable.
    """

    def decorator(cls):
        # Re-registering the exact same callable is tolerated (idempotent);
        # only a conflicting registration is an error.
        if name in COLLATE_FN_MANAGER_REGISTRY and COLLATE_FN_MANAGER_REGISTRY[name] != cls:
            raise ValueError(
                f"Collate function manager {name} has already been registered: "
                f"{COLLATE_FN_MANAGER_REGISTRY[name]} vs {cls}"
            )
        COLLATE_FN_MANAGER_REGISTRY[name] = cls
        return cls

    return decorator


def get_collate_fn_manager_cls(name):
    """Look up a registered collate function by name.

    Unknown names silently fall back to the "default" entry, so callers may
    pass e.g. ``processor.__class__.__name__`` without pre-checking the
    registry.

    (Fixed docstring: the registry stores collate *functions*, not classes,
    so the return value is a callable rather than a type.)

    Args:
        name: `(str)`
            The registry key of the collate function.

    Returns:
        `(Callable)`: The registered collate function, or the "default" one
        when `name` is not registered.

    Raises:
        ValueError: If `name` is unknown and no "default" entry exists.
    """
    if name not in COLLATE_FN_MANAGER_REGISTRY:
        default_collate_fn = COLLATE_FN_MANAGER_REGISTRY.get("default", None)
        if default_collate_fn is None:
            raise ValueError(f"Unknown collate function manager: {name}")
        return default_collate_fn

    return COLLATE_FN_MANAGER_REGISTRY[name]


@register_collate_fn("default")
def collate_fn(data_list: List[dict]) -> dict:
"""
Collate a batch of sample dicts into batched tensors and arrays.

Args:
data_list: List of dicts mapping feature names to torch.Tensor or other values.

Returns:
Dict where tensor entries are stacked into a torch.Tensor of shape
(batch_size, *dims) and non-tensor entries are converted to
np.ndarray of dtype object with shape (batch_size,).
"""
tensors = defaultdict(list)
non_tensors = defaultdict(list)

for data in data_list:
for key, val in data.items():
if isinstance(val, torch.Tensor):
tensors[key].append(val)
else:
non_tensors[key].append(val)

for key, val in tensors.items():
tensors[key] = torch.stack(val, dim=0)

for key, val in non_tensors.items():
non_tensors[key] = np.array(val, dtype=object)

return {**tensors, **non_tensors}


@register_collate_fn("PixtralProcessor")
def collate_fn_for_pixtral(data_list: List[dict]) -> dict:
"""
Collate a batch of sample dicts into batched tensors and arrays.

Args:
data_list: List of dicts mapping feature names to torch.Tensor or other values.

Returns:
Dict where tensor entries are stacked into a torch.Tensor of shape
(batch_size, *dims) and non-tensor entries are converted to
np.ndarray of dtype object with shape (batch_size,).
"""
tensors = defaultdict(list)
non_tensors = defaultdict(list)

for data in data_list:
for key, val in data.items():
if isinstance(val, torch.Tensor):
tensors[key].append(val)
else:
non_tensors[key].append(val)

for key, val in tensors.items():
tensors[key] = torch.stack(val, dim=0)

for key, val in non_tensors.items():
if key == "multi_modal_inputs":
val = [ele for ele in val if ele]
if not val:
continue
pixel_values = [v["pixel_values"][0] for v in val]
image_sizes = [v["image_sizes"][0] for v in val]
pixel_values = _pad_for_batching(pixel_values, image_sizes)
for v, pixel_value in zip(val, pixel_values):
v["pixel_values"] = pixel_value
non_tensors[key] = np.array(val, dtype=object)

return {**tensors, **non_tensors}
Loading
Loading