add support for gemma3 #1123

Open
rasoolfa wants to merge 2 commits into verl-project:main from rasoolfa:support-gemma3

Conversation

@rasoolfa

No description provided.

@CLAassistant

CLAassistant commented Apr 16, 2025

CLA assistant check
All committers have signed the CLA.

@rasoolfa rasoolfa mentioned this pull request Apr 16, 2025
Collaborator

@eric-haibin-lin eric-haibin-lin left a comment


thx for contributing!

set -x

if [ "$#" -lt 2 ]; then
echo "Usage: run_deepseek_6b7.sh <nproc_per_node> <save_path> [other_configs...]"
Collaborator


please replace the msg with $0 so it's less confusing

Author


updated.
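The updated check presumably looks something like this (a sketch, not the PR's exact code; `$0` expands to the script name as invoked, so the usage message stays accurate even if the file is renamed or copied):

```shell
set -x

# $0 expands to the script path as invoked, so the usage message
# stays in sync with the actual script name.
usage="Usage: $0 <nproc_per_node> <save_path> [other_configs...]"

if [ "$#" -lt 2 ]; then
    echo "$usage"
fi
```

In the real script the branch would also `exit 1` after printing the message.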

ulysses_sequence_parallel_size: 1 # sp size
checkpoint:
contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
contents: ['model', 'optimizer', 'extra','hf_model'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
Collaborator


nit: add space after comma

Author


updated.
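Presumably the updated line just adds the space after the comma, i.e.:

```yaml
checkpoint:
  contents: ['model', 'optimizer', 'extra', 'hf_model'] # with 'hf_model' you can save the whole model in hf format; by default only the sharded model checkpoint is used to save space
```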

model_config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)

# For Gemma3 models, we need to use the text config
if "gemma3" in str(type(model_config)).lower():
Collaborator


Using the model name as a string looks a bit ad hoc. @hiyouga, do you know if there are better ways to make the code more general to other multi-modal LLMs?

Author


sure, will add that later
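For reference, one more general alternative to matching on the class name (a sketch, not this PR's code): multimodal configs in transformers, Gemma3Config included, typically nest the language-model settings in a `text_config` attribute, so probing for that attribute covers other multi-modal LLMs as well. The helper name below is hypothetical.

```python
def resolve_text_config(model_config):
    """Return the nested text config if the config is multimodal
    (e.g. Gemma3Config wraps a Gemma3TextConfig); otherwise return
    the config itself for plain causal-LM models."""
    text_config = getattr(model_config, "text_config", None)
    return text_config if text_config is not None else model_config
```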

data.prompt_key=extra_info \
data.response_key=extra_info \
data.prompt_dict_keys=['question'] \
+data.response_dict_keys=['answer'] \
Collaborator


Would you mind also providing a reference training log to https://verl.readthedocs.io/en/latest/experiment/ppo.html (ppo.rst)? It's not required, but it helps the community to have more reference baselines.

@zihaolucky

Do we need to change RL-related components to support gemma3 for PPO training?

  File "/data/pe-task/ai_storage_rag_search_with_rl/ReSearch-main/src/verl/single_controller/ray/base.py", line 419, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/data/pe-task/ai_storage_rag_search_with_rl/ReSearch-main/src/verl/single_controller/base/decorator.py", line 404, in inner
    return func(*args, **kwargs)
  File "/data/pe-task/ai_storage_rag_search_with_rl/ReSearch-main/src/verl/workers/fsdp_workers.py", line 415, in init_model
    self.rollout, self.rollout_sharding_manager = self._build_rollout()
  File "/data/pe-task/ai_storage_rag_search_with_rl/ReSearch-main/src/verl/workers/fsdp_workers.py", line 321, in _build_rollout
    rollout = vLLMRollout(model_path=local_path,
  File "/data/pe-task/ai_storage_rag_search_with_rl/ReSearch-main/src/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py", line 97, in __init__
    assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \
  File "/usr/local/lib/python3.10/dist-packages/transformers/configuration_utils.py", line 214, in __getattribute__
    return super().__getattribute__(key)
AttributeError: 'Gemma3Config' object has no attribute 'max_position_embeddings'
(TaskRunner pid=6477) Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::WorkerDict.actor_rollout_init_model() (pid=7264, ip=10.154.32.60, actor_id=875679a5ddf0b95c1278b23401000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7fb7e79058a0>)
(TaskRunner pid=6477)   File "/data/pe-task/ai_storage_rag_search_with_rl/ReSearch-main/src/verl/single_controller/ray/base.py", line 419, in func
(TaskRunner pid=6477)     return getattr(self.worker_dict[key], name)(*args, **kwargs)
(TaskRunner pid=6477)   File "/data/pe-task/ai_storage_rag_search_with_rl/ReSearch-main/src/verl/single_controller/base/decorator.py", line 404, in inner
(TaskRunner pid=6477)     return func(*args, **kwargs)
(TaskRunner pid=6477)   File "/data/pe-task/ai_storage_rag_search_with_rl/ReSearch-main/src/verl/workers/fsdp_workers.py", line 415, in init_model
(TaskRunner pid=6477)     self.rollout, self.rollout_sharding_manager = self._build_rollout()
(TaskRunner pid=6477)   File "/data/pe-task/ai_storage_rag_search_with_rl/ReSearch-main/src/verl/workers/fsdp_workers.py", line 321, in _build_rollout
(TaskRunner pid=6477)     rollout = vLLMRollout(model_path=local_path,
(TaskRunner pid=6477)   File "/data/pe-task/ai_storage_rag_search_with_rl/ReSearch-main/src/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py", line 97, in __init__
(TaskRunner pid=6477)     assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \
(TaskRunner pid=6477)   File "/usr/local/lib/python3.10/dist-packages/transformers/configuration_utils.py", line 214, in __getattribute__
(TaskRunner pid=6477)     return super().__getattribute__(key)
(TaskRunner pid=6477) AttributeError: 'Gemma3Config' object has no attribute 'max_position_embeddings'

@ZSL98
Collaborator

ZSL98 commented Apr 17, 2025

Do we need to change RL-related components to support gemma3 for PPO training?

Yes, I think this PR only makes the SFT training of gemma3 work.

Collaborator


I think there is no need to maintain or modify this file anymore. It is only kept as an archive for vLLM versions before 0.7.

@rasoolfa
Author

rasoolfa commented Apr 17, 2025

Do we need to change RL-related components to support gemma3 for PPO training?

Yes, I think this PR only makes the SFT training of gemma3 work.

The solution in my PR assumes that you're starting from an SFT checkpoint where Gemma3ForCausalLM was used during the SFT step. Then RL can be initialized from that checkpoint. I've tested the code for SFT, for generating samples with SFT checkpoints, and for RL training, all of which seem to work without any issues.

Also, I explained here that starting RL with gemma3 would require more changes.
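For anyone hitting the assertion above, the failure comes from Gemma3Config nesting `max_position_embeddings` inside its `text_config` rather than at the top level. A guard along these lines (a sketch, not part of this PR; the function name is hypothetical) could replace the direct attribute access in `vllm_rollout_spmd.py`:

```python
def get_max_position_embeddings(model_hf_config):
    """Read max_position_embeddings from a HF config, falling back to a
    nested text_config (multimodal configs such as Gemma3Config keep the
    language-model settings there)."""
    value = getattr(model_hf_config, "max_position_embeddings", None)
    if value is None:
        text_config = getattr(model_hf_config, "text_config", None)
        value = getattr(text_config, "max_position_embeddings", None)
    if value is None:
        raise ValueError("config does not define max_position_embeddings")
    return value
```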

@rasoolfa
Author

Do we need to change RL-related components to support gemma3 for PPO training?

AttributeError: 'Gemma3Config' object has no attribute 'max_position_embeddings'

Are you getting this error with the version of the code in this PR?

@zihaolucky

zihaolucky commented Apr 20, 2025

Are you getting this error with the version of the code in this PR?

@rasoolfa I'm using 0.3.0.post1 with vllm 0.8.2. After manually applying the changes in your PR and fixing the param check, gemma3 on gsm8k can run. But I found that the training will hang; I'm still trying to figure it out.

@rasoolfa
Author

rasoolfa commented Apr 20, 2025

Are you getting this error with the version of the code in this PR?

@rasoolfa I'm using 0.3.0.post1 with vllm 0.8.2. After manually applying the changes in your PR and fixing the param check, gemma3 on gsm8k can run. But I found that the training will hang; I'm still trying to figure it out.

I see. Can you post your steps in detail here so I can check whether I can reproduce it? I haven't seen it on my side. It is usually beneficial to post your solution (and detailed steps) if you find one.

@Raf-Chen
Contributor

Are you getting this error with the version of the code in this PR?

@rasoolfa I'm using 0.3.0.post1 with vllm 0.8.2. After manually applying the changes in your PR and fixing the param check, gemma3 on gsm8k can run. But I found that the training will hang; I'm still trying to figure it out.

Hi! Could you please share your solution? I tried to load gemma like this, but I cannot get a valid rollout.

@zihaolucky

Are you getting this error with the version of the code in this PR?

@rasoolfa I'm using 0.3.0.post1 with vllm 0.8.2. After manually applying the changes in your PR and fixing the param check, gemma3 on gsm8k can run. But I found that the training will hang; I'm still trying to figure it out.

Hi! Could you please share your solution? I tried to load gemma like this, but I cannot get a valid rollout.

@Raf-Chen Hi, I think you can try tests/generation/run_gemma2.sh (changed to gemma3) first. I use verl==0.3.0.post1; when I use the latest master code, it generates weird outputs.

@tryumanshow

@rasoolfa

Thank you so much for your contribution. I have been tracking your work since you contributed this PR.

I am hitting exactly the same problem that @zihaolucky mentioned.

Suppose I am using 2 GPUs and run the script shown below:

#!/bin/bash

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=1024 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=[SFT Trained Model] \
    actor_rollout_ref.actor.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap=[Gemma3DecoderLayer] \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=8 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_grpo_example_gsm8k' \
    trainer.experiment_name='test' \
    trainer.n_gpus_per_node=2 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@

and you will see logs like this on the terminal:
[screenshot of terminal output]

How can I solve the problem?

@rasoolfa
Author

This should address your issue, as I explained here: #1013 (comment)

@jacklanda
Contributor

Any update?

