-
Notifications
You must be signed in to change notification settings - Fork 3.7k
add support for gemma3 #1123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
add support for gemma3 #1123
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
#!/usr/bin/env bash
# SFT of google/gemma-3-4b-it on GSM8K via verl's FSDP SFT trainer.
# Usage: $0 <nproc_per_node> <save_path> [other_configs...]
#   nproc_per_node  number of GPUs/processes for torchrun on this node
#   save_path       local directory for checkpoints (trainer.default_local_dir)
#   other_configs   extra Hydra-style overrides, forwarded verbatim
set -x

if [ "$#" -lt 2 ]; then
    echo "Usage: $0 <nproc_per_node> <save_path> [other_configs...]"
    exit 1
fi

nproc_per_node=$1
save_path=$2

# Shift the arguments so "$@" refers to the remaining overrides only.
shift 2

# Bracketed values are quoted: unquoted [...] is a bash glob and may expand
# against files in the current directory. "$@" is quoted to preserve each
# extra override as a single word.
# NOTE(review): prompt_dict_keys uses a plain override while response_dict_keys
# uses Hydra's '+' append — confirm this asymmetry is intentional in the
# trainer's config schema.
torchrun --standalone --nnodes=1 --nproc_per_node="$nproc_per_node" \
    -m verl.trainer.fsdp_sft_trainer \
    data.train_files="$HOME/data/gsm8k/train.parquet" \
    data.val_files="$HOME/data/gsm8k/test.parquet" \
    data.prompt_key=extra_info \
    data.response_key=extra_info \
    "data.prompt_dict_keys=['question']" \
    "+data.response_dict_keys=['answer']" \
    data.micro_batch_size_per_gpu=4 \
    model.partial_pretrain=google/gemma-3-4b-it \
    "model.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap=[Gemma3DecoderLayer]" \
    trainer.default_local_dir="$save_path" \
    trainer.project_name=gsm8k-sft-gemma \
    trainer.experiment_name=google/gemma-3-4b-it \
    trainer.total_epochs=4 \
    "trainer.logger=['console','wandb']" \
    trainer.default_hdfs_dir=null "$@"
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -197,19 +197,38 @@ def _build_model_optimizer(self): | |
|
|
||
| trust_remote_code = self.config.model.trust_remote_code | ||
| # load config first | ||
| config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code) | ||
| model_config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code) | ||
|
|
||
| # For Gemma3 models, we need to use the text config | ||
| if "gemma3" in str(type(model_config)).lower(): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using the model name in a string looks a bit ad hoc. @hiyouga do you know if there are better ways to make the code more general to other multi-modal LLMs?
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sure, will add that later. |
||
| from transformers import Gemma3ForCausalLM, Gemma3TextConfig | ||
| model_config = Gemma3TextConfig.from_pretrained(local_model_path) | ||
| model_config.architectures = ["Gemma3ForCausalLM"] | ||
|
|
||
| if self.config.ulysses_sequence_parallel_size > 1: | ||
| assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled" | ||
|
|
||
| # This may be very large | ||
| init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings, | ||
| init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings, | ||
| mesh=self.device_mesh) | ||
|
|
||
| with init_context(): | ||
| self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path, | ||
| config=config, | ||
| # Use Gemma3ForCausalLM specifically if it's a Gemma model | ||
| if "gemma3" in str(type(model_config)).lower(): | ||
| if self.device_mesh.get_rank() == 0: | ||
| print("Using Gemma3ForCausalLM model class") | ||
|
|
||
| from transformers import Gemma3ForCausalLM | ||
| self.model: PreTrainedModel = Gemma3ForCausalLM.from_pretrained(local_model_path, | ||
| config=model_config, | ||
| torch_dtype=torch.float32, | ||
| attn_implementation=self.config.attn_implementation, | ||
| trust_remote_code=trust_remote_code) | ||
| else: | ||
| self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path, | ||
| config=model_config, | ||
| torch_dtype=torch.float32, | ||
| attn_implementation='flash_attention_2', | ||
| attn_implementation=self.config.attn_implementation, | ||
| trust_remote_code=trust_remote_code) | ||
|
|
||
| if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would you mind also providing a reference training log to https://verl.readthedocs.io/en/latest/experiment/ppo.html (ppo.rst)? It's not required but it helps the community to have more reference baselines