
LMM-R1

Open-source / Comprehensive / Lightweight / Easy-to-use


LMM-R1 is a fork of OpenRLHF, aimed at providing high-performance LMM Reinforcement Learning infrastructure for reproduction of DeepSeek-R1 on multimodal tasks.

We currently support PPO/REINFORCE++/RLOO training for LMMs, and achieve a 4.7x speedup (RLOO) compared with R1-V (GRPO).

[Figure: training time comparison between LMM-R1 (RLOO) and R1-V (GRPO)]

Team:

Gongrui Zhang | Yingzhe Peng | Miaosen Zhang | Chengzhi Yu | Qipeng Zhu

News

  • [2025/2/13] We release the code of LMM-R1!

Our Findings

Super cross-modal generalization ability of rule-based RL

We train Qwen2.5-VL-3B-Instruct on an 8K text-only MATH (level 3-5) dataset using RLOO with a rule-based reward function, and find that it achieves significant improvements on challenging multimodal math benchmarks (MathVision, MathVerse, Olympiadbench_en).

Model                              MathVision   MathVerse   Olympiadbench_en
Qwen2.5-VL-3B-Instruct               23.09        27.99         10.15
Qwen2.5-VL-3B-Instruct-rloo-math     27.47        35.1          13.32

[Figure: Weights & Biases training log]

This result suggests that the existing abundance of high-quality text-only reasoning data may be beneficial for training strong multimodal reasoning models, especially at this moment when high-quality multimodal reasoning data is scarce.

We provide the data (examples/data/mathlv345_8k_chatml.json) and the script (examples/scripts/r1_scripts/train_rloo_qwenvl2_5_math.sh) for reproduction. Note that during evaluation, the system prompt should be kept consistent with the one used in training.

More findings are coming...

Features

  • Support LMM training (Qwen2-VL, Qwen2.5-VL).
  • Distributed PPO and REINFORCE++/RLOO implementations based on Ray.
  • Ray-based Reinforced Finetuning
  • Support Ray-based PPO and REINFORCE++/RLOO using Hybrid Engine (--colocate_all_models, --vllm_enable_sleep and --vllm_gpu_memory_utilization 0.5)
  • Full RLHF fine-tuning support for models with over 70 billion parameters.
  • Integration with vLLM for accelerated generation in RLHF tasks (--vllm_num_engines).
  • Support for multiple reward models (--reward_pretrain model1,model2...) and remote reward models (--remote_rm_url).
  • Integration of FlashAttention2 (--flash_attn).
  • Support for QLoRA (--load_in_4bit) and LoRA (--lora_rank, --target_modules).
  • Compatibility with HuggingFace's tokenizer.apply_chat_template for datasets (--apply_chat_template and --input_key).
  • Logging support with Wandb (--use_wandb) and TensorBoard (--use_tensorboard).
  • Checkpoint recovery functionality (--load_checkpoint and --save_steps).
  • Provided multi-node training scripts, such as Ray PPO.

Quick Start

Installation

git clone https://github.com/TideDra/lmm-r1.git
cd lmm-r1
pip install -e .[vllm]
pip install flash_attn --no-build-isolation

Note

We recommend using vLLM 0.7.2 or higher. We also provide Dockerfiles for vLLM and a one-click installation script for Nvidia-Docker.
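
As an optional sanity check after installation (a sketch only; it assumes a CUDA-capable environment), you can verify that the key packages import and that the vLLM version meets the recommendation above:

# Optional post-install sanity check.
import torch
import vllm
import openrlhf  # installed by `pip install -e .[vllm]`

print("vLLM version:", vllm.__version__)          # recommended: >= 0.7.2
print("CUDA available:", torch.cuda.is_available())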

Prepare Datasets

LMM-R1 requires the multimodal prompt dataset to be in OpenAI-compatible message format:

[
  {
    "message": "[
      {
        \"role\": \"user\",
        \"content\": [
          {
            \"type\": \"image\",
            \"image\": \"file:///path/to/your/image.jpg\"
          },
          {\"type\": \"text\", \"text\": \"How many cats are in the image?\"}
        ]
      }
    ]",
    "answer": "$3$"
  }
]

Note that message is a stringified list. An example dataset, examples/data/test_message.jsonl, is provided for reference.
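
For reference, here is a minimal sketch (not from the repository; the output file name is only illustrative) of how such an entry can be written to a JSONL prompt file, serializing the inner message list with json.dumps so that message is stored as a stringified list:

import json

# Build one OpenAI-style message list for a single multimodal prompt.
message = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "file:///path/to/your/image.jpg"},
            {"type": "text", "text": "How many cats are in the image?"},
        ],
    }
]

# Serialize the message list into a string, as expected by the "message" field.
entry = {"message": json.dumps(message), "answer": "$3$"}

# Append the entry as one line of a JSONL prompt dataset (illustrative file name).
with open("my_prompts.jsonl", "a") as f:
    f.write(json.dumps(entry, ensure_ascii=False) + "\n")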

  • Use --input_key to specify the JSON key name of the input dataset passed via --prompt_data {name or path} (PPO) or --dataset {name or path}. Do not use --apply_chat_template for multimodal prompts; the message is processed internally.
  • OpenRLHF also supports mixing multiple datasets using --prompt_data_probs 0.1,0.4,0.5 (PPO) or --dataset_probs 0.1,0.4,0.5.

How to specify training and test datasets?

You can specify it using the data_type@data_dir format. For example, the dataset can be set as --dataset json@./data.

data
├── test.jsonl
└── train.jsonl

Note

By default, we use the train and test splits to distinguish training and testing datasets from Hugging Face. The JSON key options depend on the specific dataset. See Reward Dataset and SFT Dataset.
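
As a minimal sketch (not from the repository; file names are illustrative) of producing the layout above from a single message-format JSONL file:

import json
import os
import random

# Split one message-format JSONL file into the train/test layout
# expected by `--dataset json@./data` (illustrative input file name).
random.seed(0)
with open("my_prompts.jsonl") as f:
    rows = [json.loads(line) for line in f]
random.shuffle(rows)

n_test = max(1, len(rows) // 10)  # hold out ~10% as the test split
os.makedirs("data", exist_ok=True)
for name, subset in [("train.jsonl", rows[n_test:]), ("test.jsonl", rows[:n_test])]:
    with open(os.path.join("data", name), "w") as out:
        for row in subset:
            out.write(json.dumps(row, ensure_ascii=False) + "\n")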

LMM RLOO with Ray

Note

Set --train_vlm for LMM training.

DATASET="test_message.jsonl"
MODEL_CPK_NAME="qwenvl25_3B_ins_rloo_mathvision"
PRETRAIN_MODEL="Qwen/Qwen2.5-VL-3B-Instruct"
SAVE_PATH="/ckpts"
mkdir -p "${SAVE_PATH}/${MODEL_CPK_NAME}"

# deploy remote reward function at 127.0.0.1:5000
python -m openrlhf.models.remote_rm.math_verifier --dataset $DATASET --input_key message --prompt-template chatml > "${SAVE_PATH}/${MODEL_CPK_NAME}/remote_rm.log" 2>&1 &
childpid=$!

# start a local Ray cluster with 8 GPUs on this node
ray start --head --node-ip-address 0.0.0.0 --num-gpus 8 --temp-dir ~/.cache/ray

# submit the RLOO training job to the Ray cluster
ray job submit --address="http://127.0.0.1:8265" \
   --runtime-env-json='{"working_dir": "/root/projects/OpenRLHF"}' \
   -- python3 -m openrlhf.cli.train_ppo_ray \
   --ref_num_nodes 1 \
   --ref_num_gpus_per_node 8 \
   --remote_rm_url http://127.0.0.1:5000/get_reward \
   --actor_num_nodes 1 \
   --actor_num_gpus_per_node 8 \
   --vllm_num_engines 8 \
   --vllm_tensor_parallel_size 1 \
   --colocate_all_models \
   --vllm_enable_sleep \
   --vllm_gpu_memory_utilization 0.7 \
   --vllm_sync_backend gloo \
   --enable_prefix_caching \
   --pretrain $PRETRAIN_MODEL \
   --save_path $SAVE_PATH/$MODEL_CPK_NAME \
   --micro_train_batch_size 2 \
   --train_batch_size 128 \
   --micro_rollout_batch_size 4 \
   --rollout_batch_size 256 \
   --temperature 1 \
   --n_samples_per_prompt 16 \
   --max_epochs 1 \
   --num_episodes 30 \
   --prompt_max_len 4096 \
   --max_samples 100000 \
   --generate_max_len 4096 \
   --advantage_estimator rloo \
   --zero_stage 3 \
   --bf16 \
   --actor_learning_rate 1e-6 \
   --init_kl_coef 0.01 \
   --prompt_data $DATASET \
   --input_key message \
   --normalize_reward \
   --flash_attn \
   --gradient_checkpointing \
   --save_steps 10 \
   --ckpt_path $SAVE_PATH/$MODEL_CPK_NAME/ckpt \
   --save_hf_ckpt \
   --use_tensorboard $SAVE_PATH/$MODEL_CPK_NAME/logs \
   --train_vlm

# shut down the Ray cluster and stop the remote reward server
ray stop
kill $childpid

Note

If --vllm_num_engines is not set, the vLLM engine is not used.

Note

RLOO and REINFORCE++-baseline in OpenRLHF are modifications of REINFORCE++:

  • REINFORCE++ integrates key optimization techniques from PPO while eliminating the need for a critic network.
  • REINFORCE++-baseline uses the mean reward of multiple samples from the same prompt as the baseline.
  • RLOO in OpenRLHF modifies the original version by incorporating the per-token KL reward and utilizing the PPO-clip loss (a sketch of the leave-one-out baseline follows this list).
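
As a minimal illustration of the leave-one-out baseline behind RLOO (a sketch only; OpenRLHF's actual implementation additionally folds in the per-token KL reward and the PPO-clip loss), the per-sample advantage for k rollouts of the same prompt can be computed as follows:

import torch

def rloo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    # rewards has shape (num_prompts, k): one scalar reward per rollout.
    # Each rollout's baseline is the mean reward of the other k - 1 rollouts
    # of the same prompt (leave-one-out).
    k = rewards.size(1)
    baseline = (rewards.sum(dim=1, keepdim=True) - rewards) / (k - 1)
    return rewards - baseline

# Example: 2 prompts, 4 rollouts each, with rule-based 0/1 rewards.
rewards = torch.tensor([[1.0, 0.0, 0.0, 1.0],
                        [1.0, 1.0, 1.0, 0.0]])
print(rloo_advantages(rewards))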

Note

If you encounter an index-out-of-range error when DeepSpeed sets up the GPU devices, you can try setting the environment variable RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES as a workaround.

# For NVIDIA GPUs:
export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1

Performance Tuning Guide

To achieve optimal performance, we recommend the following:

  • Allocate nodes in the ratio vLLM:Actor:Critic = 1:1:1. For example, for a 70B model with 48 A100 GPUs, allocate 16 GPUs to the vLLM engine, 16 GPUs to the Actor model, and the remaining 16 GPUs to the Critic model.
  • Enable the --colocate_critic_reward and --colocate_actor_ref options to merge nodes.
  • Increase rollout_micro_batch_size (and minimize the TP size of the vLLM engine) as much as possible.
  • During the training phase, a larger --micro_train_batch_size is better; also enable --packing_samples (not supported for LMM).
  • When there are enough GPUs, disable --adam_offload and enable --overlap_comm.
  • For multi-node RLHF, use --vllm_sync_backend nccl with vLLM 0.7.2+.
  • Enable enable_prefix_caching in vLLM generation when n_samples_per_prompt > 1.
  • Use the hybrid engine (--colocate_all_models and --vllm_enable_sleep) rather than distributed RLHF when the model size and context length are small.

Starchart

Star History Chart

References & Acknowledgements

We sincerely thank DeepSeek for their exploration of LLM reasoning, and OpenRLHF for their incredible RL infrastructure. We also thank open-r1 and simpleRL-reason, which gave us insights into the reproduction of R1. Special thanks to Kai Yang, Jie Liu, and ZhiYuan You for their valuable suggestions, and to the Big Data Computing Center of Southeast University for the hardware support.

Citation

@misc{peng2025lmmr1,
  author       = {YingZhe Peng and Gongrui Zhang and Xin Geng and Xu Yang},
  title        = {LMM-R1},
  howpublished = {\url{https://github.com/TideDra/lmm-r1}},
  note         = {Accessed: 2025-02-13},
  year         = {2025}
}
