
[BREAKING][model, data] feat: add support for Mistral3 #2338

Open
diqiuzhuanzhuan wants to merge 7 commits into verl-project:main from diqiuzhuanzhuan:feature/support-mistral3

Conversation

@diqiuzhuanzhuan
Contributor

What does this PR do?

This PR adds support for Mistral3. It also introduces a registry-based mechanism for managing dataset collate functions, enabling flexible selection of batch collation logic by name and supporting both the default and PixtralProcessor-specific collation.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results such as training curve plots, evaluation results, etc.

[training curve screenshots]

wandb log: https://wandb.ai/diqiuzhuanzhuan/verl_grpo_example_geo3k/runs/84fvuv9s

API and Usage Example

Usage Example (One node with 8 A800/A100 GPUs)

set -x
ENGINE=${1:-vllm}
# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
# export VLLM_ATTENTION_BACKEND=XFORMERS
echo $HOME
train_files=$HOME/data/geo3k/train.parquet
test_files=$HOME/data/geo3k/test.parquet
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=${train_files} \
    data.val_files=${test_files} \
    data.train_batch_size=256 \
    data.max_prompt_length=8196 \
    data.max_response_length=2048 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.image_key=images \
    actor_rollout_ref.model.path=mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.01 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.actor.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap="['MistralDecoderLayer']" \
    actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=5 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.ref.strategy=fsdp2 \
    actor_rollout_ref.actor.strategy=fsdp2 \
    actor_rollout_ref.ref.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap="['MistralDecoderLayer']" \
    actor_rollout_ref.ref.fsdp_config.model_dtype=bfloat16 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.rollout.name=$ENGINE \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.free_cache_engine=False \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=5 \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_grpo_example_geo3k' \
    trainer.experiment_name='mistral3_1_24b_ai' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@

'get_collate_fn_manager_cls' is added in verl/utils/dataset/dataset_utils.py so that the appropriate 'collate_fn' can be obtained for different models.

from verl.utils.dataset.dataset_utils import get_collate_fn_manager_cls
# for non-Mistral models
collate_fn = get_collate_fn_manager_cls('default')
# Mistral 3 family models use the PixtralProcessor.
collate_fn = get_collate_fn_manager_cls('PixtralProcessor')
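For reference, a minimal sketch of what a 'default' collate function could look like (illustrative only; the actual collation in this PR batches tensors and handles multimodal inputs):

```python
# Hypothetical sketch of a "default" collate function, NOT the PR's exact code:
# turn a list of per-sample dicts into a single dict of batched lists.
# A real implementation would additionally stack tensors along a batch dim.
def default_collate_fn(data_list):
    batch = {}
    for sample in data_list:
        for key, value in sample.items():
            batch.setdefault(key, []).append(value)
    return batch
```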

In main_ppo.py,

from verl.utils.dataset.dataset_utils import get_collate_fn_manager_cls
if processor:
    collate_fn = get_collate_fn_manager_cls(processor.__class__.__name__)
else:
    collate_fn = get_collate_fn_manager_cls("default")

High-Level Design

Design and Purpose of get_collate_fn_manager_cls

The get_collate_fn_manager_cls function, together with the register_collate_fn decorator and the COLLATE_FN_MANAGER_REGISTRY dictionary, provides a flexible and extensible mechanism for managing and retrieving collate functions in the dataset pipeline.

Key Points:

  1. Registration Mechanism
    The register_collate_fn decorator allows different collate functions to be registered under a unique string key (e.g., "default", "PixtralProcessor"). This enables easy extension and modularization of batch collation logic for different data formats or model requirements.

  2. Centralized Retrieval
    The get_collate_fn_manager_cls(name) function serves as a unified interface to retrieve the appropriate collate function by name. If the requested name is not registered, it falls back to the "default" collate function, ensuring robustness.

  3. Extensibility
    New collate functions can be added simply by defining them and registering with the decorator, without modifying the main data pipeline logic. This design supports future expansion for new models or data modalities.

  4. Maintainability
    By decoupling collate function registration and retrieval, the codebase becomes easier to maintain and reason about, especially as the number of supported data types grows.

Example Usage

@register_collate_fn("MyCustomProcessor")
def my_custom_collate_fn(data_list):
    # custom collation logic
    ...

collate_fn = get_collate_fn_manager_cls("MyCustomProcessor")
batch = collate_fn(data_list)
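Put together, the registry described above can be sketched as follows (names taken from this PR's description; the actual code in dataset_utils.py may differ):

```python
# Minimal sketch of the registry mechanism. The names
# (COLLATE_FN_MANAGER_REGISTRY, register_collate_fn,
# get_collate_fn_manager_cls) are assumed from the PR description,
# not copied from the implementation.
COLLATE_FN_MANAGER_REGISTRY = {}

def register_collate_fn(name):
    """Register a collate function under a unique string key."""
    def decorator(fn):
        COLLATE_FN_MANAGER_REGISTRY[name] = fn
        return fn
    return decorator

def get_collate_fn_manager_cls(name):
    """Retrieve a collate function by name, falling back to 'default'
    when the requested name is not registered."""
    return COLLATE_FN_MANAGER_REGISTRY.get(name, COLLATE_FN_MANAGER_REGISTRY["default"])

@register_collate_fn("default")
def default_collate(data_list):
    # Placeholder; the real default batches tensors and multimodal fields.
    return data_list
```

The fallback in get_collate_fn_manager_cls is what lets main_ppo.py pass processor.__class__.__name__ unconditionally: unknown processors silently get the default collation.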

Design Summary

This design introduces a registry-based mechanism for managing dataset collate functions. The register_collate_fn decorator allows for easy registration of new collate functions, while get_collate_fn_manager_cls provides a unified interface for retrieval. This design improves flexibility, extensibility, and maintainability of the data collation pipeline, making it straightforward to support multiple data formats and model requirements.

Specific Changes

  1. Add run_mistral3_1_24b.sh to examples/grpo_trainer
  2. Add the file verl/utils/dataset/dataset_utils.py, which determines the appropriate collate_fn to use via get_collate_fn_manager_cls.
  3. Add two options to the FSDP config in ppo_trainer.yaml:
    Add 'transformer_layer_cls_to_wrap' to fsdp_config.wrap_policy.
    Add 'model_dtype' to fsdp_config.
  4. Modify several parts of the codebase to adapt to the updated collate_fn interface.
  5. Add _estimate_mistral3_1_flops function to verl/utils/flops_counter.py
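As a rough illustration of change 3, the class-name strings in wrap_policy.transformer_layer_cls_to_wrap typically have to be resolved to actual module classes before they can drive FSDP wrapping. A hedged sketch (helper names are hypothetical, not from this PR):

```python
# Hypothetical helpers (NOT this PR's code) showing how a list of class-name
# strings such as ['MistralDecoderLayer'] could be turned into a wrap
# predicate for FSDP-style transformer wrapping.
def resolve_layer_classes(class_names, modules):
    """Look up each class name in the given modules; fail loudly on typos."""
    resolved = set()
    for name in class_names:
        for mod in modules:
            cls = getattr(mod, name, None)
            if cls is not None:
                resolved.add(cls)
                break
        else:
            raise ValueError(f"unknown transformer layer class: {name}")
    return resolved

def make_wrap_predicate(layer_classes):
    """Wrap a submodule iff it is an instance of one of the target layers."""
    targets = tuple(layer_classes)
    def should_wrap(module):
        return isinstance(module, targets)
    return should_wrap
```

In practice, PyTorch users pass something equivalent via functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=...) from torch.distributed.fsdp.wrap.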

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

min_num_params: 0

# list of transformer layer classes to wrap with FSDP
transformer_layer_cls_to_wrap: []
Collaborator


This configuration doesn't seem to take effect in the code?

Contributor Author


This is actually meant to be overridden by the argument passed in through the launch script; look at this line:
actor_rollout_ref.actor.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap="['MistralDecoderLayer']" \

Contributor Author


That's right!

@eric-haibin-lin eric-haibin-lin self-assigned this Jul 4, 2025
@@ -0,0 +1,155 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
Collaborator

@eric-haibin-lin Jul 7, 2025


Thanks for the contribution. Currently, when adding a new model, the adaptors are scattered around and many files are touched. I am thinking of reorganizing the files that need to be changed so that we can have one folder per model. For instance:

verl/models/transformers/llama
verl/models/transformers/qwen2_5_vl
verl/models/transformers/qwen2
verl/models/transformers/[model_name] # model name should be the same as the one in https://github.com/huggingface/transformers/tree/main/src/transformers/models 

And in each model folder, the structure is like below (take mistral3 as the example):

mistral3_collate_utils.py
mistral3_flops_counter.py
mistral3_any_other_change_required.py

what do you think? cc @hiyouga @Fazziekey

BTW, I am also not sure if we want to have verl/models/transformers and verl/models/mcore as two folders both containing model-specific code. Maybe we should let model-related code live at the level of

verl/models/transformers # common registry utils. No model-specific code
verl/models/mcore # common registry utils specific to mcore. No model-specific code

verl/models/llama  
verl/models/qwen2_5_vl 
verl/models/qwen2 
verl/models/[model_name]

@ISEEKYAN what do you think?

Similarly, tests can be standardized:

tests/models/test_llama.py
tests/models/test_[model].py 

With a better code structure it will be easier to write a new model onboarding documentation and let the community add new SOTA models.
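One way to realize the one-folder-per-model idea (purely illustrative; no such helper exists in verl today) is to auto-import every model subpackage at startup so that module-level registration decorators run:

```python
# Illustrative sketch only: import every direct submodule of a package so
# that any module-level registration (e.g. a @register_collate_fn decorator
# in verl/models/mistral3/...) executes on startup.
import importlib
import pkgutil

def auto_import_submodules(package):
    """Import all direct submodules of `package` and return their names."""
    names = []
    for info in pkgutil.iter_modules(package.__path__):
        importlib.import_module(f"{package.__name__}.{info.name}")
        names.append(info.name)
    return names
```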

Collaborator


I think a unified model unit test like tests/models/test_[model].py is a good idea after the refactor of the unified training engine APIs.

For megatron, LLMs of different architectures share the same GPTModel API, so the effort of supporting new models comes down to config mapping, weight mapping, and maybe a few patches. VLMs would need more definition files, since LLaVaModel's development is slow.
Mbridge/megatron-hub is the official solution for supporting new megatron models; we recommend obsoleting verl/models/mcore once the code is fully transferred to mbridge, and using verl/models/[model_name] for transformers.
And if we do need to define a megatron model in verl, the solution is to inherit LLMBridge for that model, as slime did. That code would live inside the directory verl/models/[model_name].
