diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 160e7372126..4d9a4f97cbe 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -55,8 +55,6 @@
title: Detoxifying a Language Model
- local: multi_adapter_rl
title: Multi Adapter RLHF
- - local: training_vlm_sft
- title: Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)
title: Examples
- sections:
- sections: # Sorted alphabetically
diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md
index 5d0c0089b6e..40987bbee2b 100644
--- a/docs/source/dataset_formats.md
+++ b/docs/source/dataset_formats.md
@@ -1033,7 +1033,7 @@ dataset = dataset.map(merge_completions_and_labels, remove_columns=["completions
## Vision datasets
-Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently.
+Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently.
A conversational vision dataset differs from a standard conversational dataset in two key ways:
@@ -1061,4 +1061,3 @@ An example of a conversational vision dataset is the [openbmb/RLAIF-V-Dataset](h
width="100%"
height="560px"
>
-
diff --git a/docs/source/sft_trainer.md b/docs/source/sft_trainer.md
index 39c79cda2ae..b2ceef53bd6 100644
--- a/docs/source/sft_trainer.md
+++ b/docs/source/sft_trainer.md
@@ -13,7 +13,7 @@ This post-training method was contributed by [Younes Belkada](https://huggingfac
This example demonstrates how to train a language model using the [`SFTTrainer`] from TRL. We train a [Qwen 3 0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) model on the [Capybara dataset](https://huggingface.co/datasets/trl-lib/Capybara), a compact, diverse multi-turn dataset to benchmark reasoning and generalization.
```python
-from trl import SFTTrainer, SFTConfig
+from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
trainer = SFTTrainer(
@@ -91,7 +91,7 @@ This section breaks down how SFT works in practice, covering the key steps: **pr
### Preprocessing and tokenization
During training, each example is expected to contain a **text field** or a **(prompt, completion)** pair, depending on the dataset format. For more details on the expected formats, see [Dataset formats](dataset_formats).
-The `SFTTrainer` tokenizes each input using the model's tokenizer. If both prompt and completion are provided separately, they are concatenated before tokenization.
+The [`SFTTrainer`] tokenizes each input using the model's tokenizer. If both prompt and completion are provided separately, they are concatenated before tokenization.
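As a rough illustration of the concatenation step, the sketch below joins a prompt and a completion and then tokenizes the result once. The whitespace tokenizer is a stand-in assumption so the snippet runs without downloading a model; the real trainer uses the model's own tokenizer, and details such as end-of-sequence tokens are not shown.

```python
def toy_tokenize(text):
    """Stand-in tokenizer: split on whitespace (not the model's real tokenizer)."""
    return text.split()


# Hypothetical prompt-completion record
example = {"prompt": "The sky is", "completion": " blue."}

# Concatenate first, then tokenize the full sequence in one pass
full_text = example["prompt"] + example["completion"]
tokens = toy_tokenize(full_text)

print(full_text)
print(tokens)
```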
### Computing the loss
@@ -241,7 +241,7 @@ Unsloth is an open‑source framework for fine‑tuning and reinforcement learni
This example shows how to transform the [Qwen 3 0.6B Base](https://huggingface.co/Qwen/Qwen3-0.6B-Base) model into an instruction-following model using the [Capybara dataset](https://huggingface.co/datasets/trl-lib/Capybara) and a chat template from [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B). The SFT Trainer automatically handles tokenizer updates and special token configuration.
```python
-from trl import SFTTrainer, SFTConfig
+from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
trainer = SFTTrainer(
@@ -280,122 +280,41 @@ Alternatively, use the structured conversation format (recommended):
## Tool Calling with SFT
-The SFT trainer fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include:
+The [`SFTTrainer`] fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include:
* The conversation messages, including any tool calls (`tool_calls`) and tool responses (`tool` role messages)
* The list of available tools in the `tools` column, typically provided as JSON schemas
For details on the expected dataset structure, see the [Dataset Format — Tool Calling](dataset_formats#tool-calling) section.
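A minimal sketch of what one such example might look like, assuming the common chat-message layout with a `tool_calls` entry on the assistant turn followed by a `tool` role reply. The `get_weather` function, its schema, and all values are hypothetical and only illustrate the two columns listed above.

```python
# Hypothetical tool schema in JSON Schema style; not taken from a real dataset
get_weather_schema = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Return the current temperature for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

example = {
    "messages": [
        {"role": "user", "content": "What is the weather in Paris?"},
        {
            "role": "assistant",
            "tool_calls": [
                {"type": "function", "function": {"name": "get_weather", "arguments": {"city": "Paris"}}}
            ],
        },
        {"role": "tool", "name": "get_weather", "content": "18"},
        {"role": "assistant", "content": "It is currently 18 degrees in Paris."},
    ],
    "tools": [get_weather_schema],
}

# Sanity check: every tool the assistant calls is declared in the `tools` column
declared = {tool["function"]["name"] for tool in example["tools"]}
called = {
    call["function"]["name"]
    for message in example["messages"]
    for call in message.get("tool_calls", [])
}
assert called <= declared
```

A check like the last one is cheap to run over a whole dataset before training and catches examples whose tool calls reference undeclared tools.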
-## Extending `SFTTrainer` for Vision Language Models
+## Training Vision Language Models
-`SFTTrainer` does not yet inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py), which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.
-
-### Preparing the Data
-
-The data format is flexible, provided it is compatible with the custom collator that we will define later. A common approach is to use conversational data. Given that the data includes both text and images, the format needs to be adjusted accordingly. Below is an example of a conversational data format involving both text and images:
-
-```python
-images = ["obama.png"]
-messages = [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "Who is this?"},
- {"type": "image"}
- ]
- },
- {
- "role": "assistant",
- "content": [
- {"type": "text", "text": "Barack Obama"}
- ]
- },
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "What is he famous for?"}
- ]
- },
- {
- "role": "assistant",
- "content": [
- {"type": "text", "text": "He is the 44th President of the United States."}
- ]
- }
-]
-```
-
-To illustrate how this data format will be processed using the LLaVA model, you can use the following code:
-
-```python
-from transformers import AutoProcessor
-
-processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
-print(processor.apply_chat_template(messages, tokenize=False))
-```
-
-The output will be formatted as follows:
-
-```txt
-Who is this? ASSISTANT: Barack Obama USER: What is he famous for? ASSISTANT: He is the 44th President of the United States.
-```
-
-
-
-### A custom collator for processing multi-modal data
-
-Unlike the default behavior of [`SFTTrainer`], processing multi-modal data is done on the fly during the data collation process. To do this, you need to define a custom collator that processes both the text and images. This collator must take a list of examples as input (see the previous section for an example of the data format) and return a batch of processed data. Below is an example of such a collator:
-
-```python
-def collate_fn(examples):
- # Get the texts and images, and apply the chat template
- texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
- images = [example["images"][0] for example in examples]
-
- # Tokenize the texts and process the images
- batch = processor(images=images, text=texts, return_tensors="pt", padding=True)
-
- # The labels are the input_ids, and we mask the padding tokens in the loss computation
- labels = batch["input_ids"].clone()
- labels[labels == processor.tokenizer.pad_token_id] = -100
- batch["labels"] = labels
-
- return batch
-```
-
-We can verify that the collator works as expected by running the following code:
+[`SFTTrainer`] fully supports training Vision-Language Models (VLMs). To train a VLM, you need to provide a dataset with an additional `images` column containing the images to be processed. For more information on the expected dataset structure, see the [Dataset Format — Vision Dataset](dataset_formats#vision-datasets) section.
+An example of such a dataset is the [LLaVA Instruct Mix](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).
```python
+from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
-dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train")
-examples = [dataset[0], dataset[1]] # Just two examples for the sake of the example
-collated_data = collate_fn(examples)
-print(collated_data.keys()) # dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'labels'])
+trainer = SFTTrainer(
+ model="Qwen/Qwen2.5-VL-3B-Instruct",
+ args=SFTConfig(max_length=None),
+ train_dataset=load_dataset("trl-lib/llava-instruct-mix", split="train"),
+)
+trainer.train()
```
-### Training the vision-language model
+
-Now that we have prepared the data and defined the collator, we can proceed with training the model. To ensure that the data is not processed as text-only, we need to set a couple of arguments in the [`SFTConfig`], specifically `remove_unused_columns` and `skip_prepare_dataset` to `True` to avoid the default processing of the dataset. Below is an example of how to set up the `SFTTrainer`.
+For VLMs, truncating may remove image tokens, leading to errors during training. To avoid this, set `max_length=None` in the [`SFTConfig`]. This allows the model to process the full sequence length without truncating image tokens.
```python
-training_args.remove_unused_columns = False
-training_args.dataset_kwargs = {"skip_prepare_dataset": True}
-
-trainer = SFTTrainer(
- model=model,
- args=training_args,
- data_collator=collate_fn,
- train_dataset=train_dataset,
- processing_class=processor,
-)
+SFTConfig(max_length=None, ...)
```
-A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset can be found in the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py).
+Only use `max_length` when you've verified that truncation won't remove image tokens for the entire dataset.
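To see why truncation can break VLM inputs, consider the toy sketch below. It assumes, purely for illustration, that each image expands into a run of placeholder token ids inside the sequence and that all of them must survive for the input to stay valid. The id `IMAGE_TOKEN_ID`, the token counts, and the helper names are hypothetical and not part of TRL.

```python
IMAGE_TOKEN_ID = -1  # hypothetical placeholder id, chosen only for this illustration


def build_sequence(num_text_before, num_image_tokens, num_text_after):
    """Build a toy token sequence: text, image placeholders, then more text."""
    return (
        list(range(num_text_before))
        + [IMAGE_TOKEN_ID] * num_image_tokens
        + list(range(num_text_after))
    )


def truncation_is_safe(input_ids, max_length):
    """Return True if cutting to max_length keeps every image placeholder."""
    if max_length is None:  # no truncation at all
        return True
    kept = input_ids[:max_length]
    return kept.count(IMAGE_TOKEN_ID) == input_ids.count(IMAGE_TOKEN_ID)


ids = build_sequence(num_text_before=10, num_image_tokens=64, num_text_after=50)

print(len(ids))
print(truncation_is_safe(ids, max_length=32))    # cuts through the image placeholders
print(truncation_is_safe(ids, max_length=128))   # longer than the sequence, nothing is cut
print(truncation_is_safe(ids, max_length=None))  # truncation disabled
```

Running a check of this kind over tokenized examples is one way to verify that a chosen `max_length` is safe for the entire dataset.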
-* [Experiment tracking](https://wandb.ai/huggingface/trl/runs/2b2c5l7s)
-* [Trained model](https://huggingface.co/HuggingFaceH4/sft-llava-1.5-7b-hf)
+
## SFTTrainer
@@ -407,3 +326,11 @@ A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vs
## SFTConfig
[[autodoc]] SFTConfig
+
+## DataCollatorForLanguageModeling
+
+[[autodoc]] trainer.sft_trainer.DataCollatorForLanguageModeling
+
+## DataCollatorForVisionLanguageModeling
+
+[[autodoc]] trainer.sft_trainer.DataCollatorForVisionLanguageModeling
diff --git a/docs/source/training_vlm_sft.md b/docs/source/training_vlm_sft.md
deleted file mode 100644
index a5c853614f9..00000000000
--- a/docs/source/training_vlm_sft.md
+++ /dev/null
@@ -1,380 +0,0 @@
-# Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)
-
-
-
-## Overview
-
-This guide walks you through the process of fine-tuning a multimodal language model (e.g., **Gemma 3**) using **Supervised Fine-Tuning (SFT)**. We cover two cases:
-
-- **Single Image + Text**
-- **Multi-Image + Text**
-
-This guide serves as a **detailed walkthrough** and complements the existing [VLM SFT script](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py). If you're already familiar with the concepts, you can use the script directly.
-
-We demonstrate the fine-tuning process using two datasets, but these principles extend to other **Vision-Language Models (VLMs)** and datasets.
-
-## Understanding the Datasets
-
-To address both **Single Image + Text** and **Multi-Image + Text** scenarios, we use two datasets that are well-suited for this task.
-
-### HuggingFaceH4/llava-instruct-mix-vsft Dataset (Image + Text)
-
-This dataset is a reformatted version of [LLaVA Instruct Mix](https://huggingface.co/datasets/theblackcat102/llava-instruct-mix). It consists of conversations where a user provides both **text** and a **single image** as input.
-
-The model (referred to as the **"assistant"**) responds based on both the **visual and textual information** shared by the user. This dataset is particularly useful for training multimodal models to **understand and generate responses based on images and text**.
-
-
-
-### FanqingM/MMIU-Benchmark Dataset (Multi-Image + Text)
-
-The **FanqingM/MMIU-Benchmark** dataset consists of:
-
-- **Context:** Included in the system prompt.
-- **Question:** Provided as part of the user's input.
-- **Series of Images:** Multiple images related to the question.
-- **Answer:** The model's expected response.
-
-This dataset is designed for tasks where the model must reason over multiple images to generate an informed response based on both visual and textual inputs.
-
-
-
-## Developing a Fine-Tuning Script for Multimodal SFT
-
-In this section, we build the script needed to fine-tune a multimodal model for both **Single Image + Text** and **Multi-Image + Text** use cases.
-
-### Setting Up the Environment
-
-Before fine-tuning, we need to install the required dependencies. Let's start by setting up the environment:
-
-```bash
-# Install the required libraries. Further details: https://huggingface.co/docs/trl/installation
-pip install -U -q trl bitsandbytes peft hf_xet tensorboard
-```
-
-Once all dependencies are installed, we need to log in to the **Hugging Face Hub**. Since **Gemma 3** is a gated model, access permissions are required.
-
-If you haven’t requested access yet, visit the [Model Card](https://huggingface.co/google/gemma-3-4b-it) and request it.
-
-To log in, you’ll need to generate an [access token](https://huggingface.co/settings/tokens) from your Hugging Face account.
-
-```bash
-huggingface-cli login
-```
-
-### **Loading the Data**
-
-As mentioned earlier, we will cover two possible use cases. While the specific procedure may vary based on the dataset, the core principles remain consistent.
-
-This guide supports both use cases, so refer to the **Single Image + Text** or **Multi-Image + Text** sections depending on your specific scenario.
-
-#### **Single Image + Text**
-
-
-
-In this case, each sample in a batch consists of a **single image paired with text**. Since the dataset is already formatted for supervised fine-tuning (SFT), we can directly load it using `load_dataset`.
-
-```python
-from datasets import load_dataset
-
-dataset_name = "HuggingFaceH4/llava-instruct-mix-vsft"
-
-# Load Dataset
-dataset = load_dataset(dataset_name)
-```
-
-#### **Multi-Image + Text (or Interleaving)**
-
-
-
-Gemma 3 also supports **Multi-Image + Text** scenarios, where:
-
-- The model receives a **list of images** alongside a user message.
-- The model processes **interleaved images and text** within a conversation.
-
-For this dataset, some preprocessing is required before training.
-
-```python
-from datasets import load_dataset
-
-dataset_name = "FanqingM/MMIU-Benchmark"
-
-# Load Dataset
-dataset = load_dataset(dataset_name)
-```
-
-After loading the dataset, we need to preprocess and format it into a conversational structure. Here’s an example of how the data might look:
-
-```python
-{"role": "system", "content": [{"type": "text", "text": "You are a judge in a photography competition, and now you are given the four images. Please examine the details and tell which one of them is most likely to be a real photograph.\nSelect from the following choices.\nA: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},
-{"role": "user", "content": images_list + [{"type": "text", "text": "Which image is most likely to be a real photograph?"}]},
-{"role": "assistant", "content": [{"type": "text", "text": "A: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},
-```
-
-Here, `images_list` is a list of images:
-
-```python
-images_list = [
- {"type": "image", "image": },
- {"type": "image", "image": },
- {"type": "image", "image": },
- {"type": "image", "image": },
- {"type": "image", "image": },
-]
-```
-
-This structure can be translated into code like this:
-
-```python
-import os
-import zipfile
-import io
-from datasets import DatasetDict
-from huggingface_hub import hf_hub_download, list_repo_files
-from PIL import Image
-
-dataset_train_split = "test"
-
-def format_data(samples: dict[str, any]) -> dict[str, list]:
- formatted_samples = {"messages": []}
- for cont in range(len(samples["question"])):
- images = []
- for img_path in samples["input_image_path"][cont]:
- try:
- with open(img_path, "rb") as f:
- img_bytes = f.read()
- image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
- images.append({"type": "image", "image": image})
- except Exception as e:
- print(f"Error processing image {img_path}: {e}")
- continue
-
- formatted_samples["messages"].append(
- [
- {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
- {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
- {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
- ]
- )
- return formatted_samples
-
-# For multi-image example
-def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict:
- all_files = list_repo_files(dataset_name, repo_type="dataset")
- zip_files = [f for f in all_files if f.endswith(".zip")]
-
- for zip_filename in zip_files:
- zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
- extract_folder = zip_filename.replace(".zip", "")
- os.makedirs(extract_folder, exist_ok=True)
-
- with zipfile.ZipFile(zip_path, "r") as zip_ref:
- zip_ref.extractall(extract_folder)
-
- dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
- return dataset
-
-dataset = prepare_dataset(dataset, dataset_name, dataset_train_split)
-```
-
-With this, your **Multi-Image + Text** dataset is now prepared for training.
-
-### **Preparing for Training**
-
-We start by loading the model and processor. In this example, we use `google/gemma-3-4b-it`, but the same process applies to its other variants and similar models.
-
-To optimize memory usage, we configure `BitsAndBytes` to load the quantized version of the model.
-
-```python
-import torch
-from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
-
-model_id = "google/gemma-3-4b-it"
-
-# BitsAndBytesConfig int-4 config
-bnb_config = BitsAndBytesConfig(
- load_in_4bit=True,
- bnb_4bit_use_double_quant=True,
- bnb_4bit_quant_type="nf4",
- bnb_4bit_compute_dtype=torch.bfloat16,
- bnb_4bit_quant_storage=torch.bfloat16,
-)
-
-# Load model and tokenizer
-model = AutoModelForImageTextToText.from_pretrained(
- model_id,
- device_map="auto",
- torch_dtype=torch.bfloat16,
- attn_implementation="eager", # Important (Ref: https://github.com/huggingface/transformers/blob/c15a7adb283fa984a40558c7fe7bed30ae975cdd/src/transformers/models/gemma3/modeling_gemma3.py#L934)
- quantization_config=bnb_config
-)
-processor = AutoProcessor.from_pretrained(model_id)
-processor.tokenizer.padding_side = "right"
-```
-
-Next, we set up [Quantized Low-Rank Adaptation (QLoRA)](https://huggingface.co/papers/2305.14314), an efficient fine-tuning technique for Large Language Models (LLMs) and Vision-Language Models (VLMs).
-
-```python
-from peft import LoraConfig, get_peft_model
-
-# Configure QLoRA
-peft_config = LoraConfig(
- lora_alpha=16,
- lora_dropout=0.05,
- r=16,
- bias="none",
- target_modules="all-linear",
- task_type="CAUSAL_LM",
- modules_to_save=[
- "lm_head",
- "embed_tokens",
- ],
-)
-```
-
-With QLoRA now set up, we need to define the training arguments for SFT. The [`SFTConfig`] class simplifies this process, providing an easy way to adjust parameters based on our specific needs.
-
-```python
-from trl import SFTConfig
-
-training_args = SFTConfig(
- output_dir="gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft", # Directory to save the model and push to the Hub. Use a specific repository id (e.g., gemma-3-4b-it-trl-sft-MMIU-Benchmark for multi-image datasets).
- num_train_epochs=1, # Set the number of epochs to train the model.
- per_device_train_batch_size=8, # Batch size for each device (e.g., GPU) during training. multi-image -> per_device_train_batch_size=1
- gradient_accumulation_steps=4, # Number of steps before performing a backward/update pass to accumulate gradients. multi-image -> gradient_accumulation_steps=1
- gradient_checkpointing=True, # Enable gradient checkpointing to reduce memory usage during training.
- optim="adamw_torch_fused", # Use the fused AdamW optimizer for better performance.
- save_strategy="epoch", # Save checkpoints at the end of each epoch.
- learning_rate=2e-05, # Learning rate for training.
- bf16=True, # Enable bfloat16 precision for training to save memory and speed up computations.
- push_to_hub=True, # Automatically push the fine-tuned model to Hugging Face Hub after training.
- report_to="tensorboard", # Automatically report metrics to tensorboard.
- gradient_checkpointing_kwargs={"use_reentrant": False}, # Set gradient checkpointing to non-reentrant to avoid issues.
- dataset_kwargs={"skip_prepare_dataset": True}, # Skip dataset preparation to handle preprocessing manually.
- remove_unused_columns=False, # Ensure unused columns are not removed in the collator (important for batch processing).
-)
-```
-
-The `collate_fn` is responsible for processing and preparing individual examples to form a batch.
-
-Each example in the batch undergoes the following steps:
-
-1. The **chat template** is applied to the text.
-2. The **processor tokenizes** both `texts` and `images`, encoding them into tensors.
-3. The **labels** for training are set as the `input_ids` of the example.
-4. Certain **special tokens** are **masked (ignored)** during loss computation:
- - `pad_token_id`
- - ``
- - `` (corresponding to ID `262144`)
-
-This process is similar across different dataset types, with a minor variation in how images are handled:
-
-- **Single Image + Text** → A **list of images** is directly processed.
-- **Multi-Image + Text** → A **list of lists of images** is used, where each batch element contains multiple images.
-
-```python
-from PIL import Image
-
-# For multi-image cases
-def process_vision_info(messages: list[dict]) -> list[Image.Image]:
- image_inputs = []
- for msg in messages:
- content = msg.get("content", [])
- if not isinstance(content, list):
- content = [content]
-
- for element in content:
- if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
- if "image" in element:
- image = element["image"]
- else:
- image = element
- if image is not None:
- image = Image.open(io.BytesIO(image["bytes"]))
- image_inputs.append(image.convert("RGB"))
- return image_inputs
-
-def collate_fn(examples):
- texts = [processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip() for example in examples]
- if "images" in examples[0]: # single-image
- images = [
- [img.convert("RGB") for img in example["images"]]
- for example in examples
- ]
- else: # multi-image
- images = [process_vision_info(example["messages"]) for example in examples]
-
- # Tokenize the texts and process the images
- batch = processor(
- images=images, text=texts, return_tensors="pt", padding=True
- ) # Encode texts and images into tensors
-
- # The labels are the input_ids, and we mask the padding tokens in the loss computation
- labels = batch["input_ids"].clone() # Clone input IDs for labels
- # Mask image tokens
- image_token_id = [
- processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])
- ]
- # Mask tokens for not being used in the loss computation
- labels[labels == processor.tokenizer.pad_token_id] = -100
- labels[labels == image_token_id] = -100
- labels[labels == 262144] = -100
-
- batch["labels"] = labels
- return batch # Return the prepared batch
-```
-
-### **Training the Model**
-
-With all the components set up, we now configure the `SFTTrainer` using the previously defined settings and start the training process.
-
-``` python
-# Training
-from trl import SFTTrainer
-
-trainer = SFTTrainer(
- model=model,
- args=training_args,
- data_collator=collate_fn,
- train_dataset=dataset["train"], # multi-image -> train_dataset=dataset["test"],
- processing_class=processor,
- peft_config=peft_config,
-)
-
-trainer.train()
-
-# Save the final model
-trainer.save_model()
-```
-
-We save the fine-tuned model to the Hub, making it easily accessible for future use. Additionally, TRL automatically logs the training results to **Weights & Biases (Wandb)** or **TensorBoard**, depending on the chosen configuration.
-
-
-### Results
-
-During and after training, we can inspect the results using **Weights & Biases (Wandb)** or **TensorBoard**. For example:
-
-* [**gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft (Single Image+Text)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft)
-
-* [**gemma-3-4b-it-trl-sft-MMIU-Benchmark (Multi-Images+Text or Interleaving)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-MMIU-Benchmark)
-
-## Limitations
-
-Currently, fine-tuning Gemma has some [known limitations](https://github.com/huggingface/trl/issues/3121). We recommend following the procedure outlined in this guide to ensure the best results.
-
-## References
-
-For further reading and complementary resources, check out the following:
-
-- [Fine-Tuning Vision-Language Models with QLoRA](https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora)
-- [Fine-Tuning a Vision Language Model (Qwen2-VL-7B) with the Hugging Face Ecosystem (TRL)](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl)
-
diff --git a/examples/datasets/llava_instruct_mix.py b/examples/datasets/llava_instruct_mix.py
new file mode 100644
index 00000000000..a819ae38b5a
--- /dev/null
+++ b/examples/datasets/llava_instruct_mix.py
@@ -0,0 +1,107 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ast
+from dataclasses import dataclass, field
+from typing import Optional
+
+from datasets import load_dataset
+from huggingface_hub import ModelCard
+from transformers import HfArgumentParser
+
+
+@dataclass
+class ScriptArguments:
+ r"""
+ Arguments for the script.
+
+ Args:
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether to push the dataset to the Hugging Face Hub.
+ repo_id (`str`, *optional*, defaults to `"trl-lib/llava-instruct-mix"`):
+ Hugging Face repository ID to push the dataset to.
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
+ Number of workers to use for dataset processing.
+ """
+
+ push_to_hub: bool = field(
+ default=False,
+ metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
+ )
+ repo_id: str = field(
+ default="trl-lib/llava-instruct-mix",
+ metadata={"help": "Hugging Face repository ID to push the dataset to."},
+ )
+ dataset_num_proc: Optional[int] = field(
+ default=None,
+ metadata={"help": "Number of workers to use for dataset processing."},
+ )
+
+
+def process_example(example):
+ messages = []
+ for message in ast.literal_eval(example["conversations"]):
+ content = message["value"]
+ content = content.replace("<image>", "").strip()
+ role = "user" if message["from"] == "human" else "assistant"
+ messages.append({"role": role, "content": content})
+ return {"messages": messages, "images": [example["image"]]}
+
+
+def filter_long_examples(example):
+ total_length = sum(len(msg["content"]) for msg in example["messages"])
+ return total_length <= 1000
+
+
+model_card = ModelCard("""
+---
+tags: [trl]
+---
+
+# LLaVA Instruct Mix
+
+## Summary
+
+The LLaVA Instruct Mix dataset is a processed version of [LLaVA Instruct Mix](https://huggingface.co/datasets/theblackcat102/llava-instruct-mix).
+
+## Data Structure
+
+- **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational)
+- **Type**: [Language-modeling](https://huggingface.co/docs/trl/main/dataset_formats#language-modeling)
+
+Columns:
+- `"images"`: The image associated with the text.
+- `"messages"`: A list of messages in the conversation.
+
+This structure allows models to learn from the context of the conversation, enhancing their understanding of how to generate descriptive text based on visual inputs.
+
+## Generation script
+
+The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/llava_instruct_mix.py).
+""")
+
+if __name__ == "__main__":
+ parser = HfArgumentParser(ScriptArguments)
+ script_args = parser.parse_args_into_dataclasses()[0]
+
+ dataset = load_dataset("theblackcat102/llava-instruct-mix")
+
+ dataset = dataset.map(
+ process_example, remove_columns=["conversations", "image"], num_proc=script_args.dataset_num_proc
+ )
+ dataset = dataset.filter(filter_long_examples, num_proc=script_args.dataset_num_proc)
+
+ if script_args.push_to_hub:
+ dataset.push_to_hub(script_args.repo_id, num_proc=script_args.dataset_num_proc)
+ model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
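The conversion done by `process_example` and `filter_long_examples` can be tried on a hand-written record without downloading anything. The sketch below copies the two functions and assumes that the raw `conversations` column is a stringified Python list with `from`/`value` keys and that the placeholder being stripped is `<image>`. The record itself is made up, and a plain string stands in for the image object.

```python
import ast


def process_example(example):
    messages = []
    # `conversations` is assumed to be a string holding a Python list literal
    for message in ast.literal_eval(example["conversations"]):
        content = message["value"].replace("<image>", "").strip()
        role = "user" if message["from"] == "human" else "assistant"
        messages.append({"role": role, "content": content})
    return {"messages": messages, "images": [example["image"]]}


def filter_long_examples(example):
    # Keep only conversations whose total character count is at most 1000
    total_length = sum(len(msg["content"]) for msg in example["messages"])
    return total_length <= 1000


# Hypothetical raw record; a string stands in for the real image object
raw = {
    "conversations": str(
        [
            {"from": "human", "value": "<image>\nWhat is in the picture?"},
            {"from": "gpt", "value": "A cat sleeping on a sofa."},
        ]
    ),
    "image": "placeholder-image",
}

processed = process_example(raw)
print(processed["messages"])
print(filter_long_examples(processed))
```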
diff --git a/examples/scripts/sft_video_llm.py b/examples/scripts/sft_video_llm.py
index 0cd926fe117..aca691d5a5d 100644
--- a/examples/scripts/sft_video_llm.py
+++ b/examples/scripts/sft_video_llm.py
@@ -178,7 +178,6 @@ class CustomScriptArguments(ScriptArguments):
# Configure training args
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
- training_args.dataset_kwargs = {"skip_prepare_dataset": True}
# Load dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config, split="train")
diff --git a/examples/scripts/sft_vlm.py b/examples/scripts/sft_vlm.py
index 53618babc27..dd961e48b04 100644
--- a/examples/scripts/sft_vlm.py
+++ b/examples/scripts/sft_vlm.py
@@ -23,27 +23,37 @@
pip install pillow
# Tested on 8x H100 GPUs
-accelerate launch
- --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
+accelerate launch \
+ --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
examples/scripts/sft_vlm.py \
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
--model_name_or_path llava-hf/llava-1.5-7b-hf \
- --per_device_train_batch_size 8 \
--gradient_accumulation_steps 8 \
- --output_dir sft-llava-1.5-7b-hf \
- --torch_dtype bfloat16 \
- --gradient_checkpointing
+ --output_dir LLaVA-1.5-7B-SFT \
+ --torch_dtype bfloat16
For LLaVA-NeXT, use: (requires transformers>=4.45)
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf
For meta-llama/Llama-3.2-11B-Vision-Instruct, use: (requires transformers>=4.45.1)
--model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct
+
+accelerate launch \
+ --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
+ examples/scripts/sft_vlm.py \
+ --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
+ --model_name_or_path HuggingFaceTB/SmolVLM-Instruct \
+ --per_device_train_batch_size 1 \
+ --gradient_accumulation_steps 1 \
+ --output_dir SmolVLM-SFT \
+ --torch_dtype bfloat16 \
+ --use_peft \
+ --lora_target_modules down_proj o_proj k_proj q_proj gate_proj up_proj v_proj
"""
import torch
from datasets import load_dataset
-from transformers import AutoModelForImageTextToText, AutoProcessor, LlavaForConditionalGeneration
+from transformers import AutoModelForImageTextToText
from trl import (
ModelConfig,
@@ -61,8 +71,7 @@
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
- training_args.remove_unused_columns = False
- training_args.dataset_kwargs = {"skip_prepare_dataset": True}
+ training_args.max_length = None
################
# Model, Tokenizer & Processor
@@ -78,38 +87,11 @@
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
- processor = AutoProcessor.from_pretrained(
- model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
- )
model = AutoModelForImageTextToText.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
)
- ################
- # Create a data collator to encode text and image pairs
- ################
- def collate_fn(examples):
- # Get the texts and images, and apply the chat template
- texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
- images = [example["images"] for example in examples]
- if isinstance(model, LlavaForConditionalGeneration):
- # LLava1.5 does not support multiple images
- images = [image[0] for image in images]
-
- # Tokenize the texts and process the images
- batch = processor(images=images, text=texts, return_tensors="pt", padding=True)
-
- # The labels are the input_ids, and we mask the padding tokens in the loss computation
- labels = batch["input_ids"].clone()
- labels[labels == processor.tokenizer.pad_token_id] = -100 #
- # Ignore the image token index in the loss computation (model specific)
- image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
- labels[labels == image_token_id] = -100
- batch["labels"] = labels
-
- return batch
-
################
# Dataset
################
@@ -121,10 +103,8 @@ def collate_fn(examples):
trainer = SFTTrainer(
model=model,
args=training_args,
- data_collator=collate_fn,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
- processing_class=processor,
peft_config=get_peft_config(model_args),
)
@@ -134,5 +114,3 @@ def collate_fn(examples):
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)
- if trainer.accelerator.is_main_process:
- processor.push_to_hub(training_args.hub_model_id)
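Review note: the `collate_fn` removed above built `labels` by cloning `input_ids` and overwriting padding and image-placeholder positions with `-100`, the default `ignore_index` of PyTorch's cross-entropy loss. The sketch below reproduces only that masking step on plain Python lists so it can be checked without a processor or a model. The token IDs (`0` for padding, `32000` for the image placeholder) and the helper name `mask_labels` are illustrative assumptions, not values or names taken from the script, and the sketch makes no claim about how `SFTTrainer` implements this internally.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_labels(input_ids, pad_token_id, image_token_id):
    """Copy input_ids into labels, replacing pad and image tokens with IGNORE_INDEX."""
    ignored = {pad_token_id, image_token_id}
    return [
        [IGNORE_INDEX if token in ignored else token for token in row]
        for row in input_ids
    ]


# Hypothetical batch: 0 is the pad ID and 32000 the image-placeholder ID (assumptions).
batch = [
    [5, 32000, 17, 42, 0, 0],
    [5, 32000, 99, 0, 0, 0],
]
labels = mask_labels(batch, pad_token_id=0, image_token_id=32000)
print(labels)
```

The input batch is left untouched, which mirrors the `.clone()` call in the removed collator: the model still sees the pad and image tokens, and only the loss ignores them.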
diff --git a/examples/scripts/sft_vlm_gemma3.py b/examples/scripts/sft_vlm_gemma3.py
index fcda0c73863..6cb94bee9fa 100644
--- a/examples/scripts/sft_vlm_gemma3.py
+++ b/examples/scripts/sft_vlm_gemma3.py
@@ -20,7 +20,7 @@
# ///
"""
-Train Gemma-3 on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image).
+Train Gemma 3 on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image).
accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
@@ -28,14 +28,13 @@
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
--model_name_or_path google/gemma-3-4b-it \
--per_device_train_batch_size 1 \
- --gradient_accumulation_steps 1 \
- --output_dir gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft \
+ --output_dir Gemma-3-4B-SFT-MMIU \
--torch_dtype bfloat16 \
--use_peft \
--lora_target_modules all-linear \
--attn_implementation eager
-Train Gemma-3 on the FanqingM/MMIU-Benchmark dataset (multi-image).
+Train Gemma 3 on the FanqingM/MMIU-Benchmark dataset (multi-image).
accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
@@ -44,8 +43,7 @@
--dataset_train_split test \
--model_name_or_path google/gemma-3-4b-it \
--per_device_train_batch_size 1 \
- --gradient_accumulation_steps 1 \
- --output_dir gemma-3-4b-it-trl-sft-MMIU-Benchmark \
+ --output_dir Gemma-3-4B-SFT-MMIU \
--torch_dtype bfloat16 \
--use_peft \
--lora_target_modules all-linear
@@ -60,7 +58,7 @@
from datasets import DatasetDict, load_dataset
from huggingface_hub import hf_hub_download, list_repo_files
from PIL import Image
-from transformers import AutoModelForImageTextToText, AutoProcessor
+from transformers import AutoModelForImageTextToText
from trl import (
ModelConfig,
@@ -119,7 +117,7 @@ def format_data(samples: dict[str, any]) -> dict[str, list]:
# For multi-image example
-def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict:
+def prepare_dataset(dataset: DatasetDict, dataset_name: str) -> DatasetDict:
all_files = list_repo_files(dataset_name, repo_type="dataset")
zip_files = [f for f in all_files if f.endswith(".zip")]
@@ -139,8 +137,7 @@ def main():
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
- training_args.remove_unused_columns = False
- training_args.dataset_kwargs = {"skip_prepare_dataset": True}
+ training_args.max_length = None
################
# Model, Tokenizer & Processor
@@ -156,50 +153,16 @@ def main():
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
- processor = AutoProcessor.from_pretrained(
- model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
- )
- processor.tokenizer.padding_side = "right"
-
model = AutoModelForImageTextToText.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
)
- def collate_fn(examples):
- texts = [
- processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip()
- for example in examples
- ]
- if "images" in examples[0]: # single-image
- images = [[img.convert("RGB") for img in example["images"]] for example in examples]
- else: # multi-image
- images = [process_vision_info(example["messages"]) for example in examples]
-
- # Tokenize the texts and process the images
- batch = processor(
- images=images, text=texts, return_tensors="pt", padding=True
- ) # Encode texts and images into tensors
-
- # The labels are the input_ids, and we mask the padding tokens in the loss computation
- labels = batch["input_ids"].clone() # Clone input IDs for labels
- # Mask image tokens
- image_token_id = [
- processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])
- ]
- # Mask tokens for not being used in the loss computation
- labels[labels == processor.tokenizer.pad_token_id] = -100
- labels[labels == image_token_id] = -100
- labels[labels == 262144] = -100
-
- batch["labels"] = labels
- return batch # Return the prepared batch
-
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
if script_args.dataset_name == "FanqingM/MMIU-Benchmark":
- dataset = prepare_dataset(dataset, script_args.dataset_name, script_args.dataset_train_split)
+ dataset = prepare_dataset(dataset, script_args.dataset_name)
################
# Training
@@ -207,10 +170,8 @@ def collate_fn(examples):
trainer = SFTTrainer(
model=model,
args=training_args,
- data_collator=collate_fn,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
- processing_class=processor,
peft_config=get_peft_config(model_args),
)
@@ -220,8 +181,6 @@ def collate_fn(examples):
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)
- if trainer.accelerator.is_main_process:
- processor.push_to_hub(training_args.hub_model_id)
if __name__ == "__main__":
diff --git a/examples/scripts/sft_vlm_smol_vlm.py b/examples/scripts/sft_vlm_smol_vlm.py
deleted file mode 100644
index d6ff2da1974..00000000000
--- a/examples/scripts/sft_vlm_smol_vlm.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# /// script
-# dependencies = [
-# "trl @ git+https://github.com/huggingface/trl.git",
-# "Pillow>=9.4.0",
-# ]
-# ///
-
-"""
-pip install pillow
-
-# Tested on 8x H100 GPUs
-accelerate launch
- --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
- sft_vlm_smol_vlm.py \
- --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
- --model_name_or_path HuggingFaceTB/SmolVLM-Instruct \
- --per_device_train_batch_size 1 \
- --gradient_accumulation_steps 1 \
- --output_dir sft-smol-vlm-hf \
- --torch_dtype bfloat16 \
- --gradient_checkpointing \
- --use_peft \
- --lora_target_modules down_proj, o_proj, k_proj, q_proj, gate_proj, up_proj, v_proj
-
-For LLaVA-NeXT, use: (requires transformers>=4.45)
- --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf
-
-For meta-llama/Llama-3.2-11B-Vision-Instruct, use: (requires transformers>=4.45.1)
- --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct
-"""
-
-import torch
-from datasets import load_dataset
-from transformers import (
- AutoModelForImageTextToText,
- AutoProcessor,
- Idefics3ForConditionalGeneration,
- LlavaForConditionalGeneration,
-)
-
-from trl import (
- ModelConfig,
- ScriptArguments,
- SFTConfig,
- SFTTrainer,
- TrlParser,
- get_kbit_device_map,
- get_peft_config,
- get_quantization_config,
-)
-
-
-if __name__ == "__main__":
- parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
- script_args, training_args, model_args = parser.parse_args_and_config()
- training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
- training_args.remove_unused_columns = False
- training_args.dataset_kwargs = {"skip_prepare_dataset": True}
-
- ################
- # Model, Tokenizer & Processor
- ################
- torch_dtype = (
- model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
- )
- quantization_config = get_quantization_config(model_args)
- model_kwargs = dict(
- revision=model_args.model_revision,
- attn_implementation=model_args.attn_implementation,
- torch_dtype=torch_dtype,
- device_map=get_kbit_device_map() if quantization_config is not None else None,
- quantization_config=quantization_config,
- )
- processor = AutoProcessor.from_pretrained(
- model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
- )
-
- model = AutoModelForImageTextToText.from_pretrained(
- model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
- )
-
- ################
- # Create a data collator to encode text and image pairs
- ################
- def collate_fn(examples):
- # Get the texts and images, and apply the chat template
- texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
- images = [example["images"] for example in examples]
- if isinstance(model, LlavaForConditionalGeneration):
- # LLava1.5 does not support multiple images
- images = [image[0] for image in images]
-
- # Tokenize the texts and process the images
- batch = processor(images=images, text=texts, return_tensors="pt", padding=True)
-
- # The labels are the input_ids, and we mask the padding tokens in the loss computation
- labels = batch["input_ids"].clone()
- labels[labels == processor.tokenizer.pad_token_id] = -100 #
- # Ignore the image token index in the loss computation (model specific)
- if isinstance(model, Idefics3ForConditionalGeneration):
- image_token_id = processor.tokenizer.additional_special_tokens_ids[
- processor.tokenizer.additional_special_tokens.index("<image>")
- ]
- else:
- image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
- labels[labels == image_token_id] = -100
- batch["labels"] = labels
-
- return batch
-
- ################
- # Dataset
- ################
- dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
-
- ################
- # Training
- ################
- trainer = SFTTrainer(
- model=model,
- args=training_args,
- data_collator=collate_fn,
- train_dataset=dataset[script_args.dataset_train_split],
- eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
- processing_class=processor,
- peft_config=get_peft_config(model_args),
- )
-
- trainer.train()
-
- # Save and push to hub
- trainer.save_model(training_args.output_dir)
- if training_args.push_to_hub:
- trainer.push_to_hub(dataset_name=script_args.dataset_name)
- if trainer.accelerator.is_main_process:
- processor.push_to_hub(training_args.hub_model_id)
diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py
index 27f51d6293f..55b5e7105c9 100644
--- a/scripts/generate_tiny_models.py
+++ b/scripts/generate_tiny_models.py
@@ -16,9 +16,11 @@
# `trl-internal-testing` organization.
# This script is meant to be run when adding new tiny model to the TRL library.
+import torch
from huggingface_hub import HfApi, ModelCard
from torch import nn
from transformers import (
+ AutoConfig,
AutoProcessor,
AutoTokenizer,
BartConfig,
@@ -35,7 +37,6 @@
FalconMambaForCausalLM,
Gemma2Config,
Gemma2ForCausalLM,
- Gemma3Config,
Gemma3ForConditionalGeneration,
GemmaConfig,
GemmaForCausalLM,
@@ -47,18 +48,17 @@
GptOssForCausalLM,
Idefics2Config,
Idefics2ForConditionalGeneration,
+ Idefics3ForConditionalGeneration,
+ InternVLForConditionalGeneration,
LlamaConfig,
LlamaForCausalLM,
LlamaForSequenceClassification,
- LlavaConfig,
LlavaForConditionalGeneration,
- LlavaNextConfig,
LlavaNextForConditionalGeneration,
MistralConfig,
MistralForCausalLM,
OPTConfig,
OPTForCausalLM,
- PaliGemmaConfig,
PaliGemmaForConditionalGeneration,
Phi3Config,
Phi3ForCausalLM,
@@ -74,7 +74,6 @@
Qwen3ForSequenceClassification,
Qwen3MoeConfig,
Qwen3MoeForCausalLM,
- SmolVLMConfig,
SmolVLMForConditionalGeneration,
T5Config,
T5ForConditionalGeneration,
@@ -98,7 +97,7 @@
api = HfApi()
-def push_to_hub(model, tokenizer, prefix=None, suffix=None):
+def push_to_hub(model, tokenizer, prefix=None, suffix=None, force=False):
model_class_name = model.__class__.__name__
content = MODEL_CARD.format(model_class_name=model_class_name)
model_card = ModelCard(content)
@@ -108,7 +107,7 @@ def push_to_hub(model, tokenizer, prefix=None, suffix=None):
if suffix is not None:
repo_id += f"-{suffix}"
- if api.repo_exists(repo_id):
+ if api.repo_exists(repo_id) and not force:
print(f"Model {repo_id} already exists, skipping")
else:
model.push_to_hub(repo_id)
@@ -179,7 +178,7 @@ def init_weights_tiny_model(model):
revision = "refs/pr/14" if model_id == "Qwen/Qwen3-8B" else "main" # chat template with {% generation %}
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
config = config_class(
- vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
+ vocab_size=len(tokenizer.vocab),
hidden_size=8,
num_attention_heads=4,
num_key_value_heads=2,
@@ -197,7 +196,7 @@ def init_weights_tiny_model(model):
]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = config_class(
- vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
+ vocab_size=len(tokenizer.vocab),
hidden_size=8,
num_attention_heads=4,
num_key_value_heads=2,
@@ -214,7 +213,7 @@ def init_weights_tiny_model(model):
# Two slightly bigger models, required for vLLM testing
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-32B-Instruct")
config = Qwen2Config(
- vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
+ vocab_size=len(tokenizer.vocab),
hidden_size=128, # increase hidden size so that hidden_size // num_attention_heads = 32, required for vLLM
num_attention_heads=4,
num_key_value_heads=2,
@@ -226,7 +225,7 @@ def init_weights_tiny_model(model):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
config = Qwen3Config(
- vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
+ vocab_size=len(tokenizer.vocab),
hidden_size=128, # increase hidden size so that hidden_size // num_attention_heads = 32, required for vLLM
num_attention_heads=4,
num_key_value_heads=2,
@@ -244,7 +243,7 @@ def init_weights_tiny_model(model):
]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = config_class(
- vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
+ vocab_size=len(tokenizer.vocab),
hidden_size=8,
num_attention_heads=4,
num_key_value_heads=2,
@@ -263,7 +262,7 @@ def init_weights_tiny_model(model):
]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = config_class(
- vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
+ vocab_size=len(tokenizer.vocab),
d_model=16,
encoder_layers=2,
decoder_layers=2,
@@ -279,51 +278,43 @@ def init_weights_tiny_model(model):
# Vision Language Models
-for model_id, config_class, model_class in [
- ("google/gemma-3-4b-it", Gemma3Config, Gemma3ForConditionalGeneration),
- ("google/paligemma-3b-pt-224", PaliGemmaConfig, PaliGemmaForConditionalGeneration),
- ("HuggingFaceM4/idefics2-8b", Idefics2Config, Idefics2ForConditionalGeneration),
- ("HuggingFaceTB/SmolVLM2-2.2B-Instruct", SmolVLMConfig, SmolVLMForConditionalGeneration),
- ("llava-hf/llava-1.5-7b-hf", LlavaConfig, LlavaForConditionalGeneration),
- ("llava-hf/llava-v1.6-mistral-7b-hf", LlavaNextConfig, LlavaNextForConditionalGeneration),
- ("Qwen/Qwen2-VL-2B-Instruct", Qwen2VLConfig, Qwen2VLForConditionalGeneration),
- ("Qwen/Qwen2.5-VL-3B-Instruct", Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration),
+for model_id, model_class in [
+ ("google/gemma-3-4b-it", Gemma3ForConditionalGeneration),
+ ("google/paligemma-3b-pt-224", PaliGemmaForConditionalGeneration),
+ ("HuggingFaceM4/idefics2-8b", Idefics2ForConditionalGeneration),
+ ("HuggingFaceM4/Idefics3-8B-Llama3", Idefics3ForConditionalGeneration),
+ ("HuggingFaceTB/SmolVLM2-2.2B-Instruct", SmolVLMForConditionalGeneration),
+ ("llava-hf/llava-1.5-7b-hf", LlavaForConditionalGeneration),
+ ("llava-hf/llava-v1.6-mistral-7b-hf", LlavaNextForConditionalGeneration),
+ ("OpenGVLab/InternVL3-8B-hf", InternVLForConditionalGeneration),
+ ("Qwen/Qwen2-VL-2B-Instruct", Qwen2VLForConditionalGeneration),
+ ("Qwen/Qwen2.5-VL-3B-Instruct", Qwen2_5_VLForConditionalGeneration),
]:
processor = AutoProcessor.from_pretrained(model_id)
- kwargs = {}
- text_kwargs = {}
- vision_kwargs = {}
- if config_class == PaliGemmaConfig:
- kwargs["projection_dim"] = 8
- if config_class in [LlavaConfig, LlavaNextConfig, PaliGemmaConfig]:
- vision_kwargs["projection_dim"] = 8
- if config_class in [LlavaConfig, LlavaNextConfig]:
- vision_kwargs["image_size"] = 336
- vision_kwargs["patch_size"] = 14
- if config_class in [Qwen2VLConfig, Qwen2_5_VLConfig]:
- kwargs["vision_start_token_id"] = 151652
- text_kwargs["rope_scaling"] = {"type": "mrope", "mrope_section": [1]}
- vision_kwargs["depth"] = 4
- vision_kwargs["embed_dim"] = 64
+ config = AutoConfig.from_pretrained(model_id)
+
+ config.text_config.num_hidden_layers = 2
+ config.text_config.hidden_size = 16
+ config.text_config.num_attention_heads = 4
+ config.text_config.num_key_value_heads = 2
+
+ config.vision_config.num_hidden_layers = 2
+ config.vision_config.hidden_size = 16
+ config.vision_config.num_attention_heads = 4
+ config.vision_config.num_key_value_heads = 2
+
+ if isinstance(config, Qwen2VLConfig):
+ config.vision_config.depth = 2
+
+ if isinstance(config, (Qwen2VLConfig, Qwen2_5_VLConfig)):
+ config.text_config.rope_scaling["mrope_section"] = [2]
+
+ if isinstance(config, Qwen2_5_VLConfig):
+ config.vision_config.out_hidden_size = 16
+
+ if isinstance(config, Idefics2Config):
+ config.perceiver_config.hidden_size = 16
+
+ model = model_class(config).to(dtype=torch.bfloat16)
- config = config_class(
- text_config=dict(
- vocab_size=processor.tokenizer.vocab_size + len(processor.tokenizer.added_tokens_encoder),
- hidden_size=8,
- num_attention_heads=4,
- num_key_value_heads=2,
- num_hidden_layers=2,
- intermediate_size=32,
- **text_kwargs,
- ),
- vision_config=dict(
- hidden_size=16,
- num_attention_heads=4,
- num_hidden_layers=2,
- intermediate_size=32,
- **vision_kwargs,
- ),
- **kwargs,
- )
- model = model_class(config)
push_to_hub(model, processor, "tiny")
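Review note: the new `force` parameter changes the skip condition from "repo exists" to "repo exists and `force` is off". The sketch below isolates that branch so its behaviour can be checked offline. `FakeApi` and `push_if_needed` are hypothetical stand-ins written for this note; they are not part of `huggingface_hub` or of the script, and no network call is made.

```python
class FakeApi:
    """Hypothetical stand-in for HfApi that only remembers which repos exist."""

    def __init__(self, existing):
        self.existing = set(existing)
        self.pushed = []

    def repo_exists(self, repo_id):
        return repo_id in self.existing


def push_if_needed(api, repo_id, force=False):
    """Mirror the diff's condition: skip only when the repo exists and force is False."""
    if api.repo_exists(repo_id) and not force:
        return "skipped"
    api.pushed.append(repo_id)  # the real script pushes model, tokenizer, and card here
    return "pushed"


api = FakeApi(existing={"org/tiny-A"})
print(push_if_needed(api, "org/tiny-A"))              # existing repo, no force
print(push_if_needed(api, "org/tiny-A", force=True))  # existing repo, forced
print(push_if_needed(api, "org/tiny-B"))              # new repo
```

Since the call in the loop above does not pass `force`, existing tiny models would still be skipped by default; regenerating one appears to require passing `force=True` explicitly.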
diff --git a/scripts/generate_zen_image_dataset.py b/scripts/generate_zen_image_dataset.py
index 5795c8d8928..378ed16a5c3 100644
--- a/scripts/generate_zen_image_dataset.py
+++ b/scripts/generate_zen_image_dataset.py
@@ -15,7 +15,7 @@
from dataclasses import dataclass, field
import numpy as np
-from datasets import Dataset, Features, Image, Sequence, Value
+from datasets import Dataset, Features, Image, List, Sequence, Value
from transformers import HfArgumentParser
@@ -75,9 +75,9 @@ def main(test_size, push_to_hub, repo_id):
"If the implementation is easy to explain, it may be a good idea.",
"Namespaces are one honking great idea -- let's do more of those!",
],
- "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes],
+ "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes],
}
- standard_language_modeling_dataset = Dataset.from_dict(data, features=Features(text=Value("string"), image=Image()))
+ standard_language_modeling_dataset = Dataset.from_dict(data, features=Features(text=Value("string"), images=List(Image())))
standard_language_modeling_dataset = standard_language_modeling_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
standard_language_modeling_dataset.push_to_hub(repo_id, config_name="standard_language_modeling")
@@ -105,9 +105,9 @@ def main(test_size, push_to_hub, repo_id):
"If the implementation is easy",
"Namespaces are one honking great",
],
- "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes],
+ "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes],
}
- standard_prompt_only_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), image=Image()))
+ standard_prompt_only_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), images=List(Image())))
standard_prompt_only_dataset = standard_prompt_only_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
standard_prompt_only_dataset.push_to_hub(repo_id, config_name="standard_prompt_only")
@@ -156,9 +156,9 @@ def main(test_size, push_to_hub, repo_id):
" to explain, it may be a good idea.",
" idea -- let's do more of those!",
],
- "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes],
+ "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes],
}
- standard_prompt_completion_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), completion=Value("string"), image=Image()))
+ standard_prompt_completion_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), completion=Value("string"), images=List(Image())))
standard_prompt_completion_dataset = standard_prompt_completion_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
standard_prompt_completion_dataset.push_to_hub(repo_id, config_name="standard_prompt_completion")
@@ -228,9 +228,9 @@ def main(test_size, push_to_hub, repo_id):
" it's probably magic.",
" watermelon -- let's plant some!",
],
- "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes],
+ "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes],
}
- standard_preference_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), chosen=Value("string"), rejected=Value("string"), image=Image()))
+ standard_preference_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), chosen=Value("string"), rejected=Value("string"), images=List(Image())))
standard_preference_dataset = standard_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
standard_preference_dataset.push_to_hub(repo_id, config_name="standard_preference")
@@ -279,9 +279,9 @@ def main(test_size, push_to_hub, repo_id):
"If the implementation is easy it's probably magic.",
"Namespaces are one honking great watermelon -- let's plant some!",
],
- "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes],
+ "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes],
}
- standard_implicit_prompt_preference_dataset = Dataset.from_dict(data, features=Features(chosen=Value("string"), rejected=Value("string"), image=Image()))
+ standard_implicit_prompt_preference_dataset = Dataset.from_dict(data, features=Features(chosen=Value("string"), rejected=Value("string"), images=List(Image())))
{'prompt': Value(dtype='string'), 'completions': Sequence(feature=Value(dtype='string'), length=-1), 'labels': Sequence(feature=Value(dtype='bool'), length=-1)}
standard_implicit_prompt_preference_dataset = standard_implicit_prompt_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
@@ -332,9 +332,9 @@ def main(test_size, push_to_hub, repo_id):
" watermelon -- let's plant some!",
],
"label": [True, False, False, True, True, False, True, False, True, True, False, True, True, False, True, False, True, False, False],
- "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes],
+ "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes],
}
- standard_unpaired_preference_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), completion=Value("string"), label=Value("bool"), image=Image()))
+ standard_unpaired_preference_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), completion=Value("string"), label=Value("bool"), images=List(Image())))
standard_unpaired_preference_dataset = standard_unpaired_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
standard_unpaired_preference_dataset.push_to_hub(repo_id, config_name="standard_unpaired_preference")
@@ -403,9 +403,9 @@ def main(test_size, push_to_hub, repo_id):
[True, True],
[False]
],
- "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes],
+ "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes],
}
- standard_stepwise_supervision_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), completions=Sequence(Value("string")), labels=Sequence(Value("bool")), image=Image()))
+ standard_stepwise_supervision_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), completions=Sequence(Value("string")), labels=Sequence(Value("bool")), images=List(Image())))
standard_stepwise_supervision_dataset = standard_stepwise_supervision_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
standard_stepwise_supervision_dataset.push_to_hub(repo_id, config_name="standard_stepwise_supervision")
@@ -433,9 +433,9 @@ def main(test_size, push_to_hub, repo_id):
[{"role": "user", "content": "What does it mean if the implementation is easy to explain?"}, {"role": "assistant", "content": "It means it may be a good idea."}],
[{"role": "user", "content": "Any great ideas?"}, {"role": "assistant", "content": "Namespaces are one honking great idea."}],
],
- "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes],
+ "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes],
}
- conversational_language_modeling_dataset = Dataset.from_dict(data, features=Features(messages=Message, image=Image()))
+ conversational_language_modeling_dataset = Dataset.from_dict(data, features=Features(messages=Message, images=List(Image())))
conversational_language_modeling_dataset = conversational_language_modeling_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
conversational_language_modeling_dataset.push_to_hub(repo_id, config_name="conversational_language_modeling")
@@ -463,9 +463,9 @@ def main(test_size, push_to_hub, repo_id):
[{"role": "user", "content": "What does it mean if the implementation is easy to explain?"}],
[{"role": "user", "content": "Any great ideas?"}],
],
- "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes],
+ "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes],
}
- conversational_prompt_only_dataset = Dataset.from_dict(data, features=Features(prompt=Message, image=Image()))
+ conversational_prompt_only_dataset = Dataset.from_dict(data, features=Features(prompt=Message, images=List(Image())))
conversational_prompt_only_dataset = conversational_prompt_only_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
conversational_prompt_only_dataset.push_to_hub(repo_id, config_name="conversational_prompt_only")
@@ -514,9 +514,9 @@ def main(test_size, push_to_hub, repo_id):
[{"role": "assistant", "content": "It means it may be a good idea."}],
[{"role": "assistant", "content": "Namespaces are one honking great idea."}],
],
- "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes],
+ "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes],
}
- conversational_prompt_completion_dataset = Dataset.from_dict(data, features=Features(prompt=Message, completion=Message, image=Image()))
+ conversational_prompt_completion_dataset = Dataset.from_dict(data, features=Features(prompt=Message, completion=Message, images=List(Image())))
conversational_prompt_completion_dataset = conversational_prompt_completion_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
conversational_prompt_completion_dataset.push_to_hub(repo_id, config_name="conversational_prompt_completion")
@@ -586,9 +586,9 @@ def main(test_size, push_to_hub, repo_id):
[{"role": "assistant", "content": "It means it's a bad idea."}],
[{"role": "assistant", "content": "Recursion."}],
],
- "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes],
+ "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes],
}
- conversational_preference_dataset = Dataset.from_dict(data, features=Features(prompt=Message, chosen=Message, rejected=Message, image=Image()))
+ conversational_preference_dataset = Dataset.from_dict(data, features=Features(prompt=Message, chosen=Message, rejected=Message, images=List(Image())))
conversational_preference_dataset = conversational_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
conversational_preference_dataset.push_to_hub(repo_id, config_name="conversational_preference")
@@ -637,9 +637,9 @@ def main(test_size, push_to_hub, repo_id):
[{"role": "user", "content": "What does it mean if the implementation is easy to explain?"}, {"role": "assistant", "content": "It means it's a bad idea."}],
[{"role": "user", "content": "Any great ideas?"}, {"role": "assistant", "content": "Recursion."}],
],
- "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes],
+ "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes],
}
- conversational_implicit_prompt_preference_dataset = Dataset.from_dict(data, features=Features(chosen=Message, rejected=Message, image=Image()))
+ conversational_implicit_prompt_preference_dataset = Dataset.from_dict(data, features=Features(chosen=Message, rejected=Message, images=List(Image())))
conversational_implicit_prompt_preference_dataset = conversational_implicit_prompt_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
conversational_implicit_prompt_preference_dataset.push_to_hub(repo_id, config_name="conversational_implicit_prompt_preference")
@@ -689,9 +689,9 @@ def main(test_size, push_to_hub, repo_id):
[{'role': 'assistant', 'content': 'Namespaces are one honking great idea.'}],
],
"label": [True, True, True, False, True, True, True, False, True, False, True, False, True, False, False, True, True, True, True],
- "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes],
+ "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes],
}
- conversational_unpaired_preference_dataset = Dataset.from_dict(data, features=Features(prompt=Message, completion=Message, label=Value("bool"), image=Image()))
+ conversational_unpaired_preference_dataset = Dataset.from_dict(data, features=Features(prompt=Message, completion=Message, label=Value("bool"), images=List(Image())))
conversational_unpaired_preference_dataset = conversational_unpaired_preference_dataset.train_test_split(test_size=test_size, shuffle=False)
if push_to_hub:
conversational_unpaired_preference_dataset.push_to_hub(repo_id, config_name="conversational_unpaired_preference")
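Reviewer note: each vision config in this script now stores a list of images per row under `images`, where it previously stored a single image under `image`. The sketch below shows the same migration applied to one existing row. It assumes rows are plain dictionaries; `to_images_column` is a hypothetical helper for illustration and is not part of TRL or `datasets`.

```python
def to_images_column(row: dict) -> dict:
    """Return a copy of `row` with a single `image` moved into an `images` list."""
    if "image" not in row:
        return dict(row)  # nothing to migrate
    # Copy every column except the old single-image one.
    migrated = {key: value for key, value in row.items() if key != "image"}
    # Wrap the image in a one-element list, matching the new column layout.
    migrated["images"] = [row["image"]]
    return migrated


# The string "img_0" stands in for real image data in this illustration.
row = {"prompt": [{"role": "user", "content": "Any great ideas?"}], "image": "img_0"}
new_row = to_images_column(row)
```

With a `datasets.Dataset`, a function of this shape could presumably be passed to `Dataset.map` together with `remove_columns=["image"]`, but that usage is an assumption and is not exercised by this PR.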
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index 512d94ed6b1..bffb629840b 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -1417,7 +1417,7 @@ def test_train_with_iterable_dataset(self):
class DPOVisionTrainerTester(TrlTestCase):
@parameterized.expand(
[
- # ("trl-internal-testing/tiny-Idefics2ForConditionalGeneration",), # device issue from transformers
+ # ("trl-internal-testing/tiny-Idefics2ForConditionalGeneration",), device issue from transformers, see https://github.com/huggingface/transformers/pull/39975
# ("trl-internal-testing/tiny-PaliGemmaForConditionalGeneration",),
("trl-internal-testing/tiny-LlavaForConditionalGeneration",),
("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",),
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index ead8203b463..1e34346fb1d 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -14,17 +14,11 @@
import pathlib
-import numpy as np
+import pytest
import torch
-from datasets import Dataset, Image, Sequence, load_dataset
+from datasets import Dataset, load_dataset
from parameterized import parameterized
-from transformers import (
- AutoModelForCausalLM,
- AutoProcessor,
- AutoTokenizer,
- LlavaForConditionalGeneration,
- is_vision_available,
-)
+from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_flash_attn, require_peft, require_vision
from transformers.utils import is_peft_available
@@ -37,9 +31,6 @@
if is_peft_available():
from peft import LoraConfig, PeftModel, get_peft_model
-if is_vision_available():
- from PIL import Image as PILImage
-
def formatting_prompts_func(example):
text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
@@ -277,50 +268,6 @@ def setUp(self):
"trl-internal-testing/zen", "standard_prompt_completion"
)
- if is_vision_available():
- self.dummy_vsft_instruction_dataset = Dataset.from_dict(
- {
- "messages": [
- [
- {
- "role": "user",
- "content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
- },
- {
- "role": "assistant",
- "content": [{"type": "text", "text": "It is random noise."}],
- },
- {
- "role": "user",
- "content": [{"type": "text", "text": "Oh ye, you are right, what is 1+1"}],
- },
- {
- "role": "assistant",
- "content": [{"type": "text", "text": "2"}],
- },
- ],
- [
- {
- "role": "user",
- "content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
- },
- {
- "role": "assistant",
- "content": [{"type": "text", "text": "It is random noise."}],
- },
- ],
- ],
- "images": [
- [PILImage.fromarray((np.random.rand(40, 50, 3) * 255).astype("uint8")).convert("RGBA")],
- [PILImage.fromarray((np.random.rand(50, 60, 3) * 255).astype("uint8")).convert("RGBA")],
- ],
- }
- )
- self.dummy_vsft_instruction_dataset.cast_column("images", Sequence(Image()))
- self.dummy_vsft_instruction_dataset = self.dummy_vsft_instruction_dataset.cast_column(
- "images", Sequence(Image())
- )
-
def test_uncorrect_data(self):
 # Should work as SFTTrainer natively supports conversational lm dataset
training_args = SFTConfig(
@@ -501,86 +448,6 @@ def test_no_packing(self):
self.assertEqual(len(trainer.train_dataset["input_ids"]), len(self.conversational_lm_dataset["train"]))
self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"]))
- @require_vision
- def test_skip_prepare_dataset(self):
- training_args = SFTConfig(
- output_dir=self.tmp_dir,
- per_device_train_batch_size=2,
- remove_unused_columns=False,
- dataset_kwargs={"skip_prepare_dataset": True},
- report_to="none",
- )
-
- trainer = SFTTrainer(
- model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
- args=training_args,
- train_dataset=self.dummy_vsft_instruction_dataset,
- )
- self.assertEqual(trainer.train_dataset.features, self.dummy_vsft_instruction_dataset.features)
-
- def test_skip_prepare_dataset_with_no_packing(self):
- training_args = SFTConfig(
- output_dir=self.tmp_dir,
- per_device_train_batch_size=2,
- remove_unused_columns=False,
- packing=False,
- dataset_kwargs={"skip_prepare_dataset": True},
- report_to="none",
- )
-
- trainer = SFTTrainer(
- model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
- args=training_args,
- train_dataset=self.dummy_dataset,
- )
- self.assertEqual(trainer.train_dataset.features, self.dummy_dataset.features)
-
- @require_vision
- def test_llava(self):
- training_args = SFTConfig(
- output_dir=self.tmp_dir,
- remove_unused_columns=False,
- dataset_kwargs={"skip_prepare_dataset": True},
- report_to="none",
- )
- tiny_llava = LlavaForConditionalGeneration.from_pretrained(
- "trl-internal-testing/tiny-LlavaForConditionalGeneration"
- )
- processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlavaForConditionalGeneration")
-
- processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious
-user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's
-questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for
-item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image'
-%}{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if
-add_generation_prompt %}ASSISTANT: {% endif %}"""
-
- def collate_fn(examples):
- # Get the texts and images, and apply the chat template
- texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
- images = [example["images"][0] for example in examples]
-
- # Tokenize the texts and process the images
- batch = processor(images=images, text=texts, return_tensors="pt", padding=True)
-
- # The labels are the input_ids, and we mask the padding tokens in the loss computation
- labels = batch["input_ids"].clone()
- labels[labels == processor.tokenizer.pad_token_id] = -100
- batch["labels"] = labels
-
- return batch
-
- trainer = SFTTrainer(
- model=tiny_llava,
- args=training_args,
- data_collator=collate_fn,
- train_dataset=self.dummy_vsft_instruction_dataset,
- )
-
- trainer.train()
-
- self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
-
# This new tester aims to replace the first one at some point
class SFTTrainerTester2(TrlTestCase):
@@ -1404,3 +1271,92 @@ def test_tag_added_peft(self):
for tag in ["sft", "trl"]:
self.assertIn(tag, trainer.model.model_tags)
+
+ @parameterized.expand(
+ [
+ ("trl-internal-testing/tiny-Gemma3ForConditionalGeneration",),
+ # ("trl-internal-testing/tiny-Idefics2ForConditionalGeneration",), device issue from transformers, see https://github.com/huggingface/transformers/pull/39975
+ # ("trl-internal-testing/tiny-Idefics3ForConditionalGeneration",), device issue from transformers, see https://github.com/huggingface/transformers/pull/39975
+ ("trl-internal-testing/tiny-LlavaForConditionalGeneration",),
+ ("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",),
+ ("trl-internal-testing/tiny-Qwen2VLForConditionalGeneration",),
+ ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),
+ # ("trl-internal-testing/tiny-SmolVLMForConditionalGeneration",), device issue from transformers, see https://github.com/huggingface/transformers/pull/39975
+ ]
+ )
+ @require_vision
+ def test_train_vlm(self, model_id):
+ # Get the dataset
+ dataset = load_dataset("trl-internal-testing/zen-image", "conversational_language_modeling", split="train")
+
+ # Initialize the trainer
+ training_args = SFTConfig(
+ output_dir=self.tmp_dir,
+ max_length=None, # For VLMs, truncating can remove image tokens, leading to errors
+ report_to="none",
+ )
+ trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset)
+
+ # Save the initial parameters to compare them later
+ previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+ # Train the model
+ trainer.train()
+
+ # Check that the training loss is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+ # Check the params have changed
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ # For some reason, these params are not updated. This is probably not related to TRL, but to
+ # the model itself. We should investigate this further, but for now we just skip these params.
+ # fmt: off
+ if (
+ model_id == "trl-internal-testing/tiny-Gemma3ForConditionalGeneration" and "model.vision_tower.vision_model.head" in n or
+ model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or
+ model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or
+ model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or
+ model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n
+ ):
+ # fmt: on
+ continue
+ self.assertFalse(
+ torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
+ )
+
+ # Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing.
+ # To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.
+ @pytest.mark.slow
+ @require_vision
+ def test_train_vlm_gemma_3n(self):
+ # Get the dataset
+ dataset = load_dataset("trl-internal-testing/zen-image", "conversational_language_modeling", split="train")
+
+ # Initialize the trainer
+ training_args = SFTConfig(
+ output_dir=self.tmp_dir,
+ max_length=None,
+ per_device_train_batch_size=1,
+ gradient_checkpointing=True,
+ model_init_kwargs={"torch_dtype": "bfloat16"},
+ report_to="none",
+ )
+ trainer = SFTTrainer(model="google/gemma-3n-E2B-it", args=training_args, train_dataset=dataset)
+
+ # Save the initial parameters to compare them later
+ previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+ # Train the model
+ trainer.train()
+
+ # Check that the training loss is not None
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+ # Check the params have changed
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ if "model.vision_tower.timm_model.conv_stem.bn.weight" in n:
+ # This parameter is not updated, not sure why at this point.
+ continue
+ self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 23b334ecad8..12cc5443757 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -1549,7 +1549,7 @@ def concatenated_forward(
loss_mask = loss_mask[:, -logits_to_keep:]
if logits.shape[:2] != labels.shape[:2]:
- # for llava, the returned logits include the image tokens (placed before the text tokens)
+ # for LLaVA, the returned logits include the image tokens (placed before the text tokens)
seq_len = labels.shape[1]
logits = logits[:, -seq_len:]
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index b237867f913..108547fcd79 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -462,7 +462,7 @@ def reward_func(completions, **kwargs):
and content).
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
- processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`] or `None`, *optional*, defaults to `None`):
Processing class used to process the data. The padding side must be set to "left". If `None`, the
processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
@@ -534,9 +534,9 @@ def __init__(
else:
model_id = model.config._name_or_path
if args.model_init_kwargs is not None:
- raise ValueError(
+ warnings.warn(
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
- "This argument can only be used when the `model` argument is a string."
+ "The `model_init_kwargs` will be ignored."
)
# Some models (SmolVLM/Idefics3) don't support `logits_to_keep` argument and error out if we pass it
diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py
index 9e47ccf52f3..b36256ac64e 100644
--- a/trl/trainer/sft_config.py
+++ b/trl/trainer/sft_config.py
@@ -49,7 +49,8 @@ class SFTConfig(TrainingArguments):
Name of the column that contains text data in the dataset.
dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
- `skip_prepare_dataset`.
+ `skip_prepare_dataset`. When the model is a VLM, `skip_prepare_dataset` is automatically treated as `True`
+ regardless of the provided value, since preprocessing is done on the fly.
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
eos_token (`str` or `None`, *optional*, defaults to `None`):
@@ -159,7 +160,9 @@ class SFTConfig(TrainingArguments):
default=None,
metadata={
"help": "Dictionary of optional keyword arguments for the dataset preparation. The only supported key is "
- "`skip_prepare_dataset`."
+ "`skip_prepare_dataset`. When the model is a VLM, `skip_prepare_dataset` is automatically treated as "
+ "`True` regardless of the provided value, since preprocessing (tokenization and image processing) is "
+ "done on the fly."
},
)
dataset_num_proc: Optional[int] = field(
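Reviewer note: the rule documented above (a VLM always skips dataset preparation, whatever the user passed) can be sketched as a small pure function. This is an illustration under assumptions: `resolve_skip_prepare_dataset` is a hypothetical name, and in the PR the check is written inline in `SFTTrainer.__init__` using `self._is_vlm`.

```python
from typing import Any, Optional


def resolve_skip_prepare_dataset(dataset_kwargs: Optional[dict[str, Any]], is_vlm: bool) -> bool:
    """Return True when dataset preparation should be skipped."""
    # An explicit user request is honoured only if `dataset_kwargs` was provided.
    requested = bool(dataset_kwargs and dataset_kwargs.get("skip_prepare_dataset", False))
    # For a VLM the provided value is ignored: preparation is always skipped.
    return requested or is_vlm


text_default = resolve_skip_prepare_dataset(None, is_vlm=False)
text_skip = resolve_skip_prepare_dataset({"skip_prepare_dataset": True}, is_vlm=False)
vlm_override = resolve_skip_prepare_dataset({"skip_prepare_dataset": False}, is_vlm=True)
```

The last call is the case worth noticing: passing `skip_prepare_dataset=False` with a VLM has no effect.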
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 41dcae3f1fc..e1d53f1d7e7 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -23,11 +23,12 @@
import torch
import torch.nn as nn
+import transformers
from accelerate import PartialState
from datasets import Dataset, IterableDataset
from transformers import (
- AutoModelForCausalLM,
- AutoTokenizer,
+ AutoConfig,
+ AutoProcessor,
BaseImageProcessor,
DataCollator,
FeatureExtractionMixin,
@@ -259,6 +260,130 @@ def _convert_seq_lengths_to_position_ids(batch_seq_lengths: list[list[int]]) ->
return list(position_ids.split(example_lengths))
+@dataclass
+class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
+ """
+ Data collator for vision-language modeling tasks.
+
+ Unlike text-only datasets, where the collator typically receives pre-tokenized inputs ready for batching,
+ vision-language data processing involves converting images into pixel values. This conversion is disk-intensive,
+ making upfront preprocessing of the entire dataset impractical. Therefore, this collator performs tokenization and
+ image processing on-the-fly to efficiently prepare batches.
+
+ Each input example should be a dictionary containing at least:
+ - An `"images"` key holding the image data.
+ - Either a `"messages"` key for conversational inputs or a `"text"` key for standard text inputs.
+
+ The collator outputs a dictionary including:
+ - `"input_ids"`: Tensor of token IDs.
+ - `"attention_mask"`: Tensor indicating attention mask.
+ - `"pixel_values"`: Tensor representing image pixel values.
+ - `"labels"`: Tensor for training labels.
+
+ Additional keys may be present depending on the processor, such as `"image_grid_thw"`.
+
+ Args:
+ processor (`ProcessorMixin`):
+ The processor used to tokenize text and process images. It must be a subclass of `ProcessorMixin`
+ and include a `tokenizer` with a defined `pad_token_id`.
+ max_length (`int` or `None`, *optional*, defaults to `None`):
+ Maximum sequence length for input tokens. If `None`, no truncation is applied.
+ pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
+ If set, the sequences will be padded to a multiple of this value.
+ dataset_text_field (`str`, *optional*, defaults to `"text"`):
+ Name of the column that contains text data in the dataset. This parameter is only relevant for the
+ [standard dataset format](dataset_formats#standard).
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
+ The tensor type to return. Currently, only `"pt"` (PyTorch tensors) is supported.
+
+ Example:
+ ```python
+ >>> from PIL import Image
+ >>> from trl import DataCollatorForVisionLanguageModeling
+ >>> from transformers import AutoProcessor
+ >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+ >>> collator = DataCollatorForVisionLanguageModeling(processor)
+ >>> examples = [
+ ... {"images": [Image.open("image_0.png")], "messages": [{"role": "user", "content": "What is this?"}]},
+ ... {"images": [Image.open("image_1.png")], "messages": [{"role": "user", "content": "Describe this image."}]}
+ ... ]
+ >>> collator(examples)
+ {'input_ids': tensor([[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198,
+ 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 3838, 374,
+ 419, 30, 151645, 198],
+ [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198,
+ 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 74785, 419,
+ 2168, 13, 151645, 198]]),
+ 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
+ 'pixel_values': tensor([[-0.9893, 0.1785, 1.5362, ..., -0.0582, 0.8661, -0.2431],
+ [-0.2302, 0.9522, -1.1061, ..., 0.0555, 1.3354, -0.6412],
+ [ 1.2150, 0.9084, 0.7041, ..., 0.2404, -0.8403, -0.5133],
+ ...,
+ [ 0.6895, 0.2807, 0.2515, ..., -0.2004, -1.2100, 0.0555],
+ [ 0.8209, -0.9748, 1.5654, ..., 1.6055, -0.4706, 0.5817],
+ [-1.0915, 0.4559, 0.9230, ..., 0.5106, 0.0982, -0.1720]]),
+ 'image_grid_thw': tensor([[1, 4, 4],
+ [1, 4, 4]]),
+ 'labels': tensor([[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198,
+ 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 3838, 374,
+ 419, 30, 151645, 198],
+ [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198,
+ 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 74785, 419,
+ 2168, 13, 151645, 198]])}
+ ```
+ """
+
+ processor: ProcessorMixin
+ max_length: Optional[int] = None
+ pad_to_multiple_of: Optional[int] = None
+ dataset_text_field: str = "text"
+ return_tensors: str = "pt"
+
+ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
+ images = [example["images"] for example in examples]
+
+ if "messages" in examples[0]: # conversational case
+ for example in examples:
+ image_included = False
+ for message in example["messages"]:
+ if message["role"] == "user":
+ if isinstance(message["content"], str) and not image_included:
+ message["content"] = [{"type": "image"}, {"type": "text", "text": message["content"]}]
+ image_included = True
+ elif isinstance(message["content"], str) and image_included:
+ message["content"] = [{"type": "text", "text": message["content"]}]
+ if message["role"] == "assistant":
+ if isinstance(message["content"], str):
+ message["content"] = [{"type": "text", "text": message["content"]}]
+ messages = [example["messages"] for example in examples]
+ texts = self.processor.apply_chat_template(messages, images=images)
+ elif self.dataset_text_field in examples[0]: # standard case
+ texts = [example[self.dataset_text_field] for example in examples]
+ else:
+ raise KeyError(
+ "The input examples must contain either 'messages' for conversational data or "
+ f"'{self.dataset_text_field}' for standard data."
+ )
+
+ output = self.processor(
+ images=images,
+ text=texts,
+ padding=True,
+ pad_to_multiple_of=self.pad_to_multiple_of,
+ truncation=self.max_length is not None,
+ max_length=self.max_length,
+ return_tensors=self.return_tensors,
+ add_special_tokens=False, # to avoid adding the BOS twice, see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
+ )
+ labels = output["input_ids"].clone()
+ labels[labels == self.processor.tokenizer.pad_token_id] = -100
+ # We mask only padding tokens (-100) in the labels. Vision tokens are left unchanged because their handling in
+ # loss computation has to be done by the model, and masking them here would be infeasible in practice as vision
+ # token definitions vary across architectures.
+ output["labels"] = labels
+ return output
+
+
class SFTTrainer(Trainer):
"""
Trainer for Supervised Fine-Tuning (SFT) method.
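Reviewer note: the conversational branch of `DataCollatorForVisionLanguageModeling.torch_call` does two things that are easy to check in isolation: it rewrites string message contents into typed content lists (adding one image placeholder to the first string-valued user message), and it replaces padding positions in the labels with `-100`. The sketch below reproduces both steps in plain Python so they can be reasoned about without a processor. The names `normalize_messages` and `mask_padding` are hypothetical. Unlike the collator, which edits the messages in place, the sketch works on a deep copy; that difference is a choice made for the illustration only.

```python
import copy


def normalize_messages(messages: list) -> list:
    """Convert string contents of user/assistant messages to typed content lists."""
    messages = copy.deepcopy(messages)  # the collator in this PR mutates in place instead
    image_included = False
    for message in messages:
        content = message["content"]
        # Only user and assistant messages with plain-string content are rewritten.
        if message["role"] not in ("user", "assistant") or not isinstance(content, str):
            continue
        parts = [{"type": "text", "text": content}]
        if message["role"] == "user" and not image_included:
            parts.insert(0, {"type": "image"})  # single placeholder, first user turn only
            image_included = True
        message["content"] = parts
    return messages


def mask_padding(input_ids: list, pad_token_id: int, ignore_index: int = -100) -> list:
    """Build labels from input ids, replacing padding positions with `ignore_index`."""
    return [[ignore_index if tok == pad_token_id else tok for tok in row] for row in input_ids]


conversation = [
    {"role": "user", "content": "What is in this image?"},
    {"role": "assistant", "content": "It is random noise."},
    {"role": "user", "content": "What is 1+1?"},
]
normalized = normalize_messages(conversation)
labels = mask_padding([[5, 6, 7], [8, 0, 0]], pad_token_id=0)
```

Every occurrence of the pad id is masked, not only trailing ones, which matches the `labels[labels == pad_token_id] = -100` line in the collator.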
@@ -303,9 +428,10 @@ class SFTTrainer(Trainer):
The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`):
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`] or `None`, *optional*, defaults to `None`):
Processing class used to process the data. If `None`, the processing class is loaded from the model's name
- with [`~transformers.AutoTokenizer.from_pretrained`].
+ with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set.
+ If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default.
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
in [here](https://huggingface.co/docs/transformers/main_classes/callback).
@@ -343,9 +469,7 @@ def __init__(
data_collator: Optional[DataCollator] = None, # type: ignore
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
- processing_class: Optional[
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
- ] = None,
+ processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None,
compute_loss_func: Optional[Callable] = None,
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
@@ -356,9 +480,9 @@ def __init__(
formatting_func: Optional[Callable[[dict], str]] = None,
):
# Args
- model_id = model if isinstance(model, str) else model.config._name_or_path
if args is None:
- model_name = model_id.split("/")[-1]
+ model_name = model if isinstance(model, str) else model.config._name_or_path
+ model_name = model_name.split("/")[-1]
args = SFTConfig(f"{model_name}-SFT")
elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
dict_args = args.to_dict()
@@ -366,29 +490,56 @@ def __init__(
dict_args.pop("push_to_hub_token")
args = SFTConfig(**dict_args)
- # Handle the tokenizer
+ # Model
+ model_init_kwargs = args.model_init_kwargs or {}
+ if isinstance(model, str):
+ model_id = model
+ torch_dtype = model_init_kwargs.get("torch_dtype")
+ if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
+ pass # torch_dtype is already a torch.dtype or "auto" or None
+ elif isinstance(torch_dtype, str) and torch_dtype in ["bfloat16", "float16", "float32"]:
+ torch_dtype = getattr(torch, torch_dtype)
+ model_init_kwargs["torch_dtype"] = torch_dtype
+ else:
+ raise ValueError(
+ "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
+ f"a valid `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
+ )
+ config = AutoConfig.from_pretrained(model_id)
+ architecture = getattr(transformers, config.architectures[0])
+ model = architecture.from_pretrained(model_id, **model_init_kwargs)
+ else:
+ model_id = model.config._name_or_path
+ if args.model_init_kwargs is not None:
+ warnings.warn(
+ "You passed `model_init_kwargs` to the `SFTConfig`, but your model is already instantiated. "
+ "The `model_init_kwargs` will be ignored."
+ )
+
+ # Processing class
if processing_class is None:
- processing_class = AutoTokenizer.from_pretrained(model_id)
+ processing_class = AutoProcessor.from_pretrained(model_id)
+
+ # Handle pad token for processors or tokenizers
+ if isinstance(processing_class, ProcessorMixin):
+ tokenizer = processing_class.tokenizer
+ self._is_vlm = True
+ elif isinstance(processing_class, PreTrainedTokenizerBase):
+ tokenizer = processing_class
+ self._is_vlm = False
+ else:
+ raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
if args.eos_token is not None:
eos_token = args.eos_token
- eos_token_id = processing_class.convert_tokens_to_ids(eos_token)
+ eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)
if eos_token_id is None:
raise ValueError(
f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given "
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists "
"in the vocabulary before using it as an EOS token."
)
- processing_class.eos_token_id = eos_token_id
-
- # Model
- if args.model_init_kwargs is not None and not isinstance(model, str):
- warnings.warn(
- "You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
- "The `model_init_kwargs` will be ignored."
- )
- if isinstance(model, str):
- model = self._create_model_from_path(model, args)
+ tokenizer.eos_token_id = eos_token_id
if args.chat_template_path is not None:
if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
@@ -402,6 +553,33 @@ def __init__(
else:
added_tokens = []
+ # Catch some wrong configurations related to VLMs
+ if self._is_vlm and args.packing:
+ raise ValueError(
+ "Packing is not supported for vision-language models. Please set `packing=False` in the `SFTConfig`."
+ )
+ if self._is_vlm and args.padding_free:
+ raise ValueError(
+ "Padding-free training is not yet supported for vision-language models. Please set "
+ "`padding_free=False` in the `SFTConfig`."
+ )
+ if self._is_vlm and args.completion_only_loss:
+ raise ValueError(
+ "Completion-only loss is not yet supported for vision-language models. Please set "
+ "`completion_only_loss=False` in the `SFTConfig`."
+ )
+ if self._is_vlm and args.assistant_only_loss:
+ raise ValueError(
+ "Assistant-only loss is not yet supported for vision-language models. Please set "
+ "`assistant_only_loss=False` in the `SFTConfig`."
+ )
+ first_example = next(iter(train_dataset))
+ if self._is_vlm and "prompt" in first_example and "completion" in first_example:
+ raise ValueError(
+ "Prompt-completion datasets are not yet supported for vision-language models in `SFTTrainer`. "
+ "Please use a language-modeling type dataset instead."
+ )
+
# PEFT configuration and model wrapping
if peft_config is not None:
if added_tokens:
@@ -462,17 +640,19 @@ def __init__(
"to at least 2."
)
+ # Decide whether to use completion-only loss: if not specified, then it is set to True if the dataset format
+ # is prompt-completion, and False if the dataset format is language modeling.
dataset_sample = next(iter(train_dataset))
if args.completion_only_loss is None:
self.completion_only_loss = "prompt" in dataset_sample
else:
self.completion_only_loss = args.completion_only_loss
- if data_collator is None:
+ if data_collator is None and not self._is_vlm:
# Get the pad token: if not provided, use the one from the processing class or the eos token
# if the processing class does not have a pad token.
- pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
- pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
+ pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
+ pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
if pad_token_id is None:
raise ValueError(
f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
@@ -487,6 +667,13 @@ def __init__(
return_position_ids=use_flash_attention,
pad_to_multiple_of=args.pad_to_multiple_of,
)
+ elif data_collator is None and self._is_vlm:
+ data_collator = DataCollatorForVisionLanguageModeling(
+ processor=processing_class,
+ max_length=args.max_length,
+ pad_to_multiple_of=args.pad_to_multiple_of,
+ dataset_text_field=args.dataset_text_field,
+ )
if args.packing and args.packing_strategy == "bfd" and not use_flash_attention:
warnings.warn(
@@ -504,8 +691,12 @@ def __init__(
)
# Dataset
- preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
- if preprocess_dataset:
+ # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where
+ # preprocessing (e.g., image-to-pixel conversion) is too costly and done on the fly instead.
+ skip_prepare_dataset = (
+ (args.dataset_kwargs is not None and args.dataset_kwargs.get("skip_prepare_dataset", False)) or self._is_vlm
+ )
+ if not skip_prepare_dataset:
if self.completion_only_loss and formatting_func:
raise ValueError(
"A formatting function was provided while `completion_only_loss=True`, which is incompatible. "
@@ -563,29 +754,6 @@ def __init__(
if hasattr(self.model, "add_model_tags"):
self.model.add_model_tags(self._tag_names)
- def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel:
- """Creates a model from a path or model identifier."""
- model_init_kwargs = args.model_init_kwargs or {}
- # Handle torch dtype
- torch_dtype = model_init_kwargs.get("torch_dtype")
- if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
- pass # torch_dtype is already a torch.dtype or "auto" or None
- elif isinstance(torch_dtype, str): # it's a str, but not "auto"
- torch_dtype = getattr(torch, torch_dtype)
- model_init_kwargs["torch_dtype"] = torch_dtype
- else:
- raise ValueError(
- "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
- f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
- )
- # Disable caching if gradient checkpointing is enabled (not supported)
- # if args.gradient_checkpointing:
- # model_init_kwargs["use_cache"] = False
-
- # Create model
- model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
- return model
-
def _prepare_dataset(
self,
dataset: Union[Dataset, IterableDataset],
@@ -770,13 +938,10 @@ def _set_signature_columns_if_needed(self):
# and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the
# dataset. So we need to override the default signature columns to include "completion_mask" as well.
if self._signature_columns is None:
- self._signature_columns = [
- "input_ids",
- "labels",
- "seq_lengths",
- "completion_mask",
- "assistant_masks",
- ]
+ if self._is_vlm:
+ self._signature_columns = ["messages", "images"]
+ else:
+ self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"]
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
"""