diff --git a/examples/sft/gsm8k/run_gemma3.sh b/examples/sft/gsm8k/run_gemma3.sh new file mode 100644 index 00000000000..9b4a5f4bc87 --- /dev/null +++ b/examples/sft/gsm8k/run_gemma3.sh @@ -0,0 +1,30 @@ +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: $0 <nproc_per_node> <save_path> [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size_per_gpu=4 \ + model.partial_pretrain=google/gemma-3-4b-it \ + model.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap=[Gemma3DecoderLayer] \ + trainer.default_local_dir=$save_path \ + trainer.project_name=gsm8k-sft-gemma \ + trainer.experiment_name=google/gemma-3-4b-it \ + trainer.total_epochs=4 \ + trainer.logger=['console','wandb'] \ + trainer.default_hdfs_dir=null "$@" \ No newline at end of file diff --git a/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py index a3042cabcc4..204f0f78241 100644 --- a/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py +++ b/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py @@ -350,6 +350,7 @@ def _process_parameter_names(name): "Phi3ForCausalLM": llama_dtensor_weight_loader, "GemmaForCausalLM": gemma_dtensor_weight_loader, "Gemma2ForCausalLM": gemma_dtensor_weight_loader, + "Gemma3ForCausalLM": gemma_dtensor_weight_loader, "GPTBigCodeForCausalLM": gptbigcode_dtensor_load_weights, "Starcoder2ForCausalLM": starcoder2_dtensor_load_weights, "Qwen2ForCausalLM": qwen2_dtensor_weight_loader, diff --git a/verl/trainer/config/generation.yaml b/verl/trainer/config/generation.yaml index 
a40699a745b..29bf824100b 100644 --- a/verl/trainer/config/generation.yaml +++ b/verl/trainer/config/generation.yaml @@ -43,4 +43,9 @@ actor: strategy: fsdp # This is for backward-compatibility ulysses_sequence_parallel_size: 1 # sp size fsdp_config: - fsdp_size: -1 \ No newline at end of file + fsdp_size: -1 + +multi_model: + fsdp_config: + wrap_policy: + transformer_layer_cls_to_wrap: null \ No newline at end of file diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 6d2b2dbe437..f5923c3017e 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -52,7 +52,7 @@ actor_rollout_ref: shuffle: False ulysses_sequence_parallel_size: 1 # sp size checkpoint: - contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + contents: ['model', 'optimizer', 'extra', 'hf_model'] # 'hf_model' additionally saves the whole model in hf format alongside the sharded checkpoint optim: lr: 1e-6 lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. 
@@ -63,7 +63,7 @@ actor_rollout_ref: weight_decay: 0.01 fsdp_config: wrap_policy: - # transformer_layer_cls_to_wrap: None + transformer_layer_cls_to_wrap: null min_num_params: 0 param_offload: False optimizer_offload: False @@ -73,7 +73,7 @@ actor_rollout_ref: fsdp_config: param_offload: False wrap_policy: - # transformer_layer_cls_to_wrap: None + transformer_layer_cls_to_wrap: null min_num_params: 0 log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu log_prob_micro_batch_size_per_gpu: null @@ -140,7 +140,7 @@ critic: param_offload: False optimizer_offload: False wrap_policy: - # transformer_layer_cls_to_wrap: None + transformer_layer_cls_to_wrap: null min_num_params: 0 fsdp_size: -1 ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} diff --git a/verl/trainer/config/sft_trainer.yaml b/verl/trainer/config/sft_trainer.yaml index 86e10ccd01d..2b8407619a8 100644 --- a/verl/trainer/config/sft_trainer.yaml +++ b/verl/trainer/config/sft_trainer.yaml @@ -25,6 +25,7 @@ model: fsdp_config: wrap_policy: min_num_params: 0 + transformer_layer_cls_to_wrap: null cpu_offload: False offload_params: False external_lib: null @@ -52,4 +53,5 @@ trainer: total_training_steps: null logger: ['console'] seed: 1 +attn_implementation: flash_attention_2 diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py index 2efdd69eae8..e0234601c0b 100644 --- a/verl/trainer/fsdp_sft_trainer.py +++ b/verl/trainer/fsdp_sft_trainer.py @@ -197,19 +197,38 @@ def _build_model_optimizer(self): trust_remote_code = self.config.model.trust_remote_code # load config first - config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code) + model_config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code) + + # For Gemma3 models, we need to use the text config + from transformers import Gemma3TextConfig + 
model_config = Gemma3TextConfig.from_pretrained(local_model_path) + model_config.architectures = ["Gemma3ForCausalLM"] + if self.config.ulysses_sequence_parallel_size > 1: assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled" # This may be very large - init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings, + init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh) with init_context(): - self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path, - config=config, + # Use Gemma3ForCausalLM specifically if it's a Gemma3 model + if "gemma3" in str(type(model_config)).lower(): + if self.device_mesh.get_rank() == 0: + print("Using Gemma3ForCausalLM model class") + + from transformers import Gemma3ForCausalLM + self.model: PreTrainedModel = Gemma3ForCausalLM.from_pretrained(local_model_path, + config=model_config, + torch_dtype=torch.float32, + attn_implementation=self.config.attn_implementation, + trust_remote_code=trust_remote_code) + else: + self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path, + config=model_config, torch_dtype=torch.float32, - attn_implementation='flash_attention_2', + attn_implementation=self.config.attn_implementation, trust_remote_code=trust_remote_code) if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1: diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 5901d35451f..c4649f8f9ca 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -389,6 +389,10 @@ def init_model(self): else: optim_config = None fsdp_config = OmegaConf.create() + if "multi_model" in self.config: + # This is for models like gemma3 that have both text and vision components. 
+ fsdp_config = self.config.multi_model.fsdp_config + self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer( model_path=self.config.model.path, fsdp_config=fsdp_config,