Conversation
**eric-haibin-lin** left a comment:
thx for contributing!
**examples/sft/gsm8k/run_gemma3.sh** (Outdated)
```shell
set -x

if [ "$#" -lt 2 ]; then
    echo "Usage: run_deepseek_6b7.sh <nproc_per_node> <save_path> [other_configs...]"
```
please replace the message with `$0` so it's less confusing
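Applied to the snippet above, the suggested fix could look like this sketch (the `exit 1` and the `set --` demo arguments are assumptions for illustration, not part of the original script):

```shell
#!/bin/bash
# Sketch of the suggested change: use $0 in the usage message so it always
# reflects the actual script name instead of a hardcoded one. The `set --`
# line just simulates passing two arguments for this standalone demo.
set -- 4 ./checkpoints

if [ "$#" -lt 2 ]; then
    echo "Usage: $0 <nproc_per_node> <save_path> [other_configs...]"
    exit 1
fi

nproc_per_node=$1
save_path=$2
echo "nproc_per_node=$nproc_per_node save_path=$save_path"
```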
**verl/trainer/config/ppo_trainer.yaml** (Outdated)
```diff
 ulysses_sequence_parallel_size: 1  # sp size
 checkpoint:
-  contents: ['model', 'optimizer', 'extra']  # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
+  contents: ['model', 'optimizer', 'extra','hf_model']  # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
```
nit: add space after comma
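For reference, the line as it would read with the nit applied (only the space after the comma changes):

```yaml
contents: ['model', 'optimizer', 'extra', 'hf_model']  # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
```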
```python
model_config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)

# For Gemma3 models, we need to use the text config
if "gemma3" in str(type(model_config)).lower():
```
using the model name as a string looks a bit ad hoc. @hiyouga do you know if there are better ways to make the code more general to other multi-modal LLMs?
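One more general alternative (a sketch, not verl's actual code): many Hugging Face multimodal configs, Gemma3 included, nest the language-model settings under a `text_config` attribute, so an attribute lookup avoids matching on the class name. The helper name and the stand-in config classes below are hypothetical:

```python
def get_text_config(model_config):
    """Return the nested text config of a multimodal HF config, or the
    config itself for text-only models. Relies on the common `text_config`
    attribute convention rather than a specific class name."""
    return getattr(model_config, "text_config", model_config)


# Minimal stand-ins for HF config objects, for illustration only.
class TextOnlyConfig:
    pass


class MultimodalConfig:
    def __init__(self):
        self.text_config = TextOnlyConfig()
```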
```shell
data.prompt_key=extra_info \
data.response_key=extra_info \
data.prompt_dict_keys=['question'] \
+data.response_dict_keys=['answer'] \
```
would you mind also providing a reference training log to https://verl.readthedocs.io/en/latest/experiment/ppo.html (ppo.rst)? It's not required, but it helps the community to have more reference baselines
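For readers puzzled by the `extra_info` indirection above, my understanding (a sketch, not verl's actual implementation) is that `prompt_key`/`response_key` select a column whose cells are dicts, and `prompt_dict_keys`/`response_dict_keys` index into those dicts:

```python
def extract_field(row, key, dict_keys):
    # Select the column named by `key`, then drill into the nested dict
    # using each key in `dict_keys` in turn.
    value = row[key]
    for k in dict_keys:
        value = value[k]
    return value


# With the config above, a gsm8k row might look like this (illustrative data):
row = {"extra_info": {"question": "What is 2+2?", "answer": "4"}}
prompt = extract_field(row, "extra_info", ["question"])   # "What is 2+2?"
response = extract_field(row, "extra_info", ["answer"])   # "4"
```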
Do we need to change RL-related components to support gemma3 for PPO training?

Yes, I think this PR only makes the SFT training of gemma3 work.
I think there is no need to maintain or modify this file anymore. It is only archived for vllm versions before 0.7.

The solution in my PR assumes that you're starting from an SFT checkpoint where Gemma3ForCausalLM was used during the SFT step. Then, the RL can be initialized from that checkpoint. I've tested the code for SFT, generating samples with SFT checkpoints, and RL training, all of which seem to work without any issues. Also, I explained here that starting RL with gemma3 would require more changes.

Are you getting this error with the version of the code in this PR?

@rasoolfa I'm using 0.3.0.post1 with vllm 0.8.2. After manually applying the changes in your PR and fixing the param checks, gemma3 on gsm8k can run. But I found the training will hang; I'm still trying to figure it out.

I see. Can you post your steps in detail here so I can see if I can reproduce it? I haven't seen it on my side. It is usually beneficial to post your solution (and detailed steps) if you find one.

Hi! Could you please share your solution? I tried to load gemma like this, but I cannot get valid rollouts.

@Raf-Chen Hi, I think you can use tests/generation/run_gemma2.sh (changed to gemma3) first. I use verl==0.3.0.post1; when I use the latest master code, it generates weird outputs.
Thank you so much for your contribution. I have been tracking your work since you contributed to this PR. I am suffering exactly the same problem that @zihaolucky mentioned. Suppose I am using 2 GPUs and run the script shown below:

```shell
#!/bin/bash
python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=1024 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=[SFT Trained Model] \
    actor_rollout_ref.actor.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap=[Gemma3DecoderLayer] \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=8 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_grpo_example_gsm8k' \
    trainer.experiment_name='test' \
    trainer.n_gpus_per_node=2 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@
```

and you will see this kind of logs on the terminal. How can I solve the problem?
this should address your issue, as I explained here: #1013 (comment)
Any update?

No description provided.