This repository contains a script for training Gemma3 using only Hugging Face libraries.
- [Phi3-Vision Finetuning]
- [Llama3.2-Vision Finetuning]
- [Qwen2-VL Finetuning]
- [Molmo Finetuning]
- [Pixtral Finetune]
- [SmolVLM Finetune]
- [2025/03/28] 🔥 Supports mixed-modality data.
- Deepspeed
- LoRA, QLoRA
- Full-finetuning
- Multi-image and video training
- Text-only data training
- Mixed-modality training
To simplify the setup process for training, you can use the provided pre-built environments. Everything is already configured inside the conda env named `train`.
You can find more information about the image here.
```bash
docker pull john119/vlm:v1
docker run --gpus all -it -v /host/path:/docker/path --name vlm --ipc=host john119/vlm:v1 /bin/bash
```
- Ubuntu 22.04
- NVIDIA Driver 550.120
- CUDA 12.4
Install the required packages using `environment.yaml`:

```bash
conda env create -f environment.yaml
conda activate gemma
```
Note: It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `flash_attention_2`.
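For reference, this is how that recommendation maps onto a plain Hugging Face model load. This is only a sketch: the checkpoint name is an example, and the training script may wire this up differently.

```python
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model_id = "google/gemma-3-4b-it"  # example checkpoint; substitute your own

# Load Gemma3 with eager attention, per the note above
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",  # instead of "flash_attention_2"
)
processor = AutoProcessor.from_pretrained(model_id)
```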
Note: You can also fine-tune the model with text-only data.
The script requires a dataset formatted according to the LLaVA specification. The dataset should be a JSON file where each entry contains information about conversations and images. Ensure that the image paths in the dataset match the provided `--image_folder`.

When using a multi-image dataset, the image tokens should all be `<image>`, and the image file names should be provided as a list. Please see the examples below and format your data accordingly.
Example for text-only data

```json
[
  {
    "id": "000000033471",
    "conversations": [
      {
        "from": "human",
        "value": "Identify the odd one out: Twitter, Instagram, Telegram"
      },
      {
        "from": "gpt",
        "value": "Telegram"
      },
      {
        "from": "human",
        "value": "What makes Telegram different from Twitter and Instagram?"
      },
      {
        "from": "gpt",
        "value": "Telegram is a cloud-based instant messaging app that focuses on privacy and security. Unlike Twitter and Instagram which are mainly used for following news, celebrities, and sharing images, Telegram was created as a secure messaging app for private and group communication. Telegram also offers more advanced features than Twitter and Instagram, such as the ability to schedule messages, create bots, and send encrypted messages."
      }
    ]
  }
  ...
]
```
Example for single image dataset

```json
[
  {
    "id": "000000033471",
    "image": "000000033471.jpg",
    "conversations": [
      {
        "from": "human",
        "value": "<image>\nWhat are the colors of the bus in the image?"
      },
      {
        "from": "gpt",
        "value": "The bus in the image is white and red."
      },
      {
        "from": "human",
        "value": "What feature can be seen on the back of the bus?"
      },
      {
        "from": "gpt",
        "value": "The back of the bus features an advertisement."
      },
      {
        "from": "human",
        "value": "Is the bus driving down the street or pulled off to the side?"
      },
      {
        "from": "gpt",
        "value": "The bus is driving down the street, which is crowded with people and other vehicles."
      }
    ]
  }
  ...
]
```
Example for multi image dataset

```json
[
  {
    "id": "000000033471",
    "image": ["000000033471.jpg", "000000033472.jpg"],
    "conversations": [
      {
        "from": "human",
        "value": "<image>\n<image>\nIs the perspective of the camera different?"
      },
      {
        "from": "gpt",
        "value": "Yes, the perspective of the camera is different."
      }
    ]
  }
  ...
]
```
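To catch formatting mistakes early, you can sanity-check that each entry's `<image>` token count matches the number of image files it lists. A minimal sketch (the file path `data.json` and the checking logic are illustrative, not part of the training script):

```python
import json

with open("data.json") as f:  # hypothetical path to your LLaVA-format file
    data = json.load(f)

for entry in data:
    images = entry.get("image", [])
    if isinstance(images, str):  # single-image entries may use a plain string
        images = [images]
    # Count <image> tokens across all human turns
    token_count = sum(
        turn["value"].count("<image>")
        for turn in entry["conversations"]
        if turn["from"] == "human"
    )
    assert token_count == len(images), (
        f"Entry {entry['id']}: {token_count} <image> tokens "
        f"but {len(images)} image files"
    )
```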
Example for video dataset

```json
[
  {
    "id": "sample1",
    "video": "sample1.mp4",
    "conversations": [
      {
        "from": "human",
        "value": "<video>\nWhat is going on in this video?"
      },
      {
        "from": "gpt",
        "value": "A man is walking down the road."
      }
    ]
  }
  ...
]
```
Note: Gemma3 treats a video as a sequence of images.
Note: DeepSpeed ZeRO-2 is faster than ZeRO-3, but it consumes more memory. Also, ZeRO-2 is more stable than ZeRO-3 most of the time.
Tip: You can use the `adamw_bnb_8bit` optimizer to save memory.
To run the training script, use the following command:
```bash
bash scripts/finetune.sh
```
If you want to train only the language model with LoRA and perform full training for the vision model:
```bash
bash scripts/finetune_lora.sh
```
If you want to train both the language model and the vision model with LoRA:
```bash
bash scripts/finetune_lora_vision.sh
```
IMPORTANT: If you want to tune `embed_token` with LoRA, you need to tune `lm_head` together.
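With the `peft` library, this pairing is usually expressed via `modules_to_save`. A minimal sketch (the target module names and hyperparameters are illustrative assumptions, not copied from this repo's scripts):

```python
from peft import LoraConfig

# Sketch: train the embeddings and lm_head fully (not as LoRA adapters).
# Gemma ties the input embeddings to the output head, so tuning one
# without the other can leave the weights inconsistent.
lora_config = LoraConfig(
    r=128,
    lora_alpha=256,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumption
    modules_to_save=["embed_tokens", "lm_head"],
)
```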
Training arguments

- `--deepspeed` (str): Path to DeepSpeed config file (default: "scripts/zero2.json").
- `--data_path` (str): Path to the LLaVA formatted training data (a JSON file). (Required)
- `--image_folder` (str): Path to the images folder as referenced in the LLaVA formatted training data. (Required)
- `--model_id` (str): Path to the Gemma3 model. (Required)
- `--optim` (str): Optimizer to use for training (default: `adamw_torch`).
- `--output_dir` (str): Output directory for model checkpoints.
- `--num_train_epochs` (int): Number of training epochs (default: 1).
- `--per_device_train_batch_size` (int): Training batch size per GPU per forward step.
- `--gradient_accumulation_steps` (int): Gradient accumulation steps (default: 4).
- `--freeze_vision_tower` (bool): Option to freeze the vision_model (default: False).
- `--tune_merger` (bool): Option to tune the projector (default: True).
- `--num_lora_modules` (int): Number of target modules to add LoRA to (-1 means all layers).
- `--vision_lr` (float): Learning rate for the vision_model.
- `--projector_lr` (float): Learning rate for the projector.
- `--learning_rate` (float): Learning rate for the language module.
- `--bf16` (bool): Option for using bfloat16.
- `--fp16` (bool): Option for using fp16.
- `--lora_enable` (bool): Option for enabling LoRA (default: False).
- `--vision_lora` (bool): Option for including the vision_tower in the LoRA modules. `lora_enable` must be True to use this option (default: False).
- `--use_dora` (bool): Option for using DoRA instead of LoRA. `lora_enable` must be True to use this option (default: False).
- `--lora_namespan_exclude` (str): Exclude modules whose names contain these namespans from LoRA.
- `--max_seq_length` (int): Maximum sequence length (default: 128K).
- `--bits` (int): Quantization bits (default: 16).
- `--disable_flash_attn2` (bool): Disable Flash Attention 2.
- `--report_to` (str): Reporting tool (choices: 'tensorboard', 'wandb', 'none') (default: 'tensorboard').
- `--logging_dir` (str): Logging directory (default: "./tf-logs").
- `--lora_rank` (int): LoRA rank (default: 128).
- `--lora_alpha` (int): LoRA alpha (default: 256).
- `--lora_dropout` (float): LoRA dropout (default: 0.05).
- `--logging_steps` (int): Logging steps (default: 1).
- `--dataloader_num_workers` (int): Number of data loader workers (default: 4).
Note: The learning rate of the `vision_model` should be 5x to 10x smaller than that of the `language_model`.
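To illustrate the idea, this is one way per-module learning rates can be realized with optimizer parameter groups. The module-name substrings follow the Hugging Face Gemma3 layout and the values are placeholders; this is not this repo's actual code:

```python
import torch

# Sketch: give the vision tower a smaller LR than the language model,
# per the note above. Name substrings are assumptions based on the
# Hugging Face Gemma3 architecture.
def build_param_groups(model, language_lr=1e-5, vision_lr=2e-6, projector_lr=1e-5):
    vision, projector, language = [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "vision_tower" in name:
            vision.append(param)
        elif "multi_modal_projector" in name:
            projector.append(param)
        else:
            language.append(param)
    return [
        {"params": vision, "lr": vision_lr},
        {"params": projector, "lr": projector_lr},
        {"params": language, "lr": language_lr},
    ]

# optimizer = torch.optim.AdamW(build_param_groups(model))
```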
You can train the model using a video dataset. However, Gemma3 processes videos as a sequence of images, so you'll need to select specific frames and treat them as multiple images for training; a frame-sampling sketch follows the command below. You can also set the LoRA configs to train with LoRA.

```bash
bash scripts/finetune_video.sh
```
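A minimal frame-sampling sketch using OpenCV (the helper name, frame count, and output handling are illustrative, not part of the repo's scripts):

```python
import cv2

def sample_frames(video_path, num_frames=8):
    """Uniformly sample up to num_frames RGB frames from a video."""
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    step = max(total // num_frames, 1)
    frames = []
    for idx in range(0, total, step):
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ok, frame = cap.read()
        if not ok:
            break
        # OpenCV decodes as BGR; convert to RGB for downstream processors
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if len(frames) == num_frames:
            break
    cap.release()
    return frames
```

Each sampled frame is then treated as one image in the sequence that Gemma3 consumes.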
If you run out of VRAM, you can use `zero3_offload` instead of `zero3`. However, using `zero3` is preferred.
```bash
bash scripts/merge_lora.sh
```
Note: Remember to replace the paths in `finetune.sh` or `finetune_lora.sh` with your specific paths. (Also in `merge_lora.sh` when using LoRA.)
If you encounter the following error:

```
Could not load library libcudnn_cnn_train.so.8. Error: /usr/local/cuda-12.1/lib/libcudnn_cnn_train.so.8: undefined symbol: _ZN5cudnn3cnn34layerNormFwd_execute_internal_implERKNS_7backend11VariantPackEP11CUstream_stRNS0_18LayerNormFwdParamsERKNS1_20NormForwardOperationEmb, version libcudnn_cnn_infer.so.8
```

you can run `unset LD_LIBRARY_PATH` to work around it. You can see this issue for more details.
- [x] Support for multi-image & video data
- [x] Handle mixed-modality data
This project is licensed under the Apache-2.0 License. See the LICENSE file for details.
If you find this repository useful in your project, please consider giving a ⭐ and citing:
```bibtex
@misc{Gemma3-Finetuning,
  author = {Yuwon Lee},
  title = {Gemma3-Finetune},
  year = {2025},
  publisher = {GitHub},
  url = {https://github.com/2U1/Gemma3-Finetune}
}
```
This project is based on:

- LLaVA-NeXT: An amazing open-source LMM project.
- Gemma3: An awesome pretrained MLLM by Google.