Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions examples/sft/gsm8k/run_gemma3.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
set -x

if [ "$#" -lt 2 ]; then
echo "Usage: $0 <nproc_per_node> <save_path> [other_configs...]"
exit 1
fi

nproc_per_node=$1
save_path=$2

# Shift the arguments so $@ refers to the rest
shift 2

torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
-m verl.trainer.fsdp_sft_trainer \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.prompt_key=extra_info \
data.response_key=extra_info \
data.prompt_dict_keys=['question'] \
+data.response_dict_keys=['answer'] \
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would you mind also providing a reference training log to https://verl.readthedocs.io/en/latest/experiment/ppo.html (ppo.rst)? It's not required but it helps the community to have more reference baselines

data.micro_batch_size_per_gpu=4 \
model.partial_pretrain=google/gemma-3-4b-it \
model.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap=[Gemma3DecoderLayer] \
trainer.default_local_dir=$save_path \
trainer.project_name=gsm8k-sft-gemma \
trainer.experiment_name=google/gemma-3-4b-it \
trainer.total_epochs=4 \
trainer.logger=['console','wandb'] \
trainer.default_hdfs_dir=null $@
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def _process_parameter_names(name):
"Phi3ForCausalLM": llama_dtensor_weight_loader,
"GemmaForCausalLM": gemma_dtensor_weight_loader,
"Gemma2ForCausalLM": gemma_dtensor_weight_loader,
"Gemma3ForCausalLM": gemma_dtensor_weight_loader,
"GPTBigCodeForCausalLM": gptbigcode_dtensor_load_weights,
"Starcoder2ForCausalLM": starcoder2_dtensor_load_weights,
"Qwen2ForCausalLM": qwen2_dtensor_weight_loader,
Expand Down
7 changes: 6 additions & 1 deletion verl/trainer/config/generation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,9 @@ actor:
strategy: fsdp # This is for backward-compatibility
ulysses_sequence_parallel_size: 1 # sp size
fsdp_config:
fsdp_size: -1
fsdp_size: -1

multi_model:
fsdp_config:
wrap_policy:
transformer_layer_cls_to_wrap: null
8 changes: 4 additions & 4 deletions verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ actor_rollout_ref:
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
checkpoint:
contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
contents: ['model', 'optimizer', 'extra', 'hf_model'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
optim:
lr: 1e-6
lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
Expand All @@ -63,7 +63,7 @@ actor_rollout_ref:
weight_decay: 0.01
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
transformer_layer_cls_to_wrap: null
min_num_params: 0
param_offload: False
optimizer_offload: False
Expand All @@ -73,7 +73,7 @@ actor_rollout_ref:
fsdp_config:
param_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
transformer_layer_cls_to_wrap: null
min_num_params: 0
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: null
Expand Down Expand Up @@ -140,7 +140,7 @@ critic:
param_offload: False
optimizer_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
transformer_layer_cls_to_wrap: null
min_num_params: 0
fsdp_size: -1
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
Expand Down
2 changes: 2 additions & 0 deletions verl/trainer/config/sft_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ model:
fsdp_config:
wrap_policy:
min_num_params: 0
transformer_layer_cls_to_wrap: null
cpu_offload: False
offload_params: False
external_lib: null
Expand Down Expand Up @@ -52,4 +53,5 @@ trainer:
total_training_steps: null
logger: ['console']
seed: 1
attn_implementation: flash_attention_2

29 changes: 24 additions & 5 deletions verl/trainer/fsdp_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,19 +197,38 @@ def _build_model_optimizer(self):

trust_remote_code = self.config.model.trust_remote_code
# load config first
config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)
model_config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)

# For Gemma3 models, we need to use the text config
if "gemma3" in str(type(model_config)).lower():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using model name in string looks a bit adhoc. @hiyouga do you know if there're better ways to make the code more general to other multi-modal llms?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, will add that later

from transformers import Gemma3ForCausalLM, Gemma3TextConfig
model_config = Gemma3TextConfig.from_pretrained(local_model_path)
model_config.architectures = ["Gemma3ForCausalLM"]

if self.config.ulysses_sequence_parallel_size > 1:
assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled"

# This may be very large
init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings,
init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings,
mesh=self.device_mesh)

with init_context():
self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path,
config=config,
# Use Gemma3ForCausalLM specifically if it's a Gemma model
if "gemma3" in str(type(model_config)).lower():
if self.device_mesh.get_rank() == 0:
print("Using Gemma3ForCausalLM model class")

from transformers import Gemma3ForCausalLM
self.model: PreTrainedModel = Gemma3ForCausalLM.from_pretrained(local_model_path,
config=model_config,
torch_dtype=torch.float32,
attn_implementation=self.config.attn_implementation,
trust_remote_code=trust_remote_code)
else:
self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path,
config=model_config,
torch_dtype=torch.float32,
attn_implementation='flash_attention_2',
attn_implementation=self.config.attn_implementation,
trust_remote_code=trust_remote_code)

if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1:
Expand Down
4 changes: 4 additions & 0 deletions verl/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ def init_model(self):
else:
optim_config = None
fsdp_config = OmegaConf.create()
if "multi_model" in self.config:
# this is for model like gemma3 which has both text and vision.
fsdp_config = self.config.multi_model.fsdp_config

self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer(
model_path=self.config.model.path,
fsdp_config=fsdp_config,
Expand Down