diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 160e7372126..4d9a4f97cbe 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -55,8 +55,6 @@ title: Detoxifying a Language Model - local: multi_adapter_rl title: Multi Adapter RLHF - - local: training_vlm_sft - title: Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset) title: Examples - sections: - sections: # Sorted alphabetically diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md index 5d0c0089b6e..40987bbee2b 100644 --- a/docs/source/dataset_formats.md +++ b/docs/source/dataset_formats.md @@ -1033,7 +1033,7 @@ dataset = dataset.map(merge_completions_and_labels, remove_columns=["completions ## Vision datasets -Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently. +Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently. A conversational vision dataset differs from a standard conversational dataset in two key ways: @@ -1061,4 +1061,3 @@ An example of a conversational vision dataset is the [openbmb/RLAIF-V-Dataset](h width="100%" height="560px" > - diff --git a/docs/source/sft_trainer.md b/docs/source/sft_trainer.md index 39c79cda2ae..b2ceef53bd6 100644 --- a/docs/source/sft_trainer.md +++ b/docs/source/sft_trainer.md @@ -13,7 +13,7 @@ This post-training method was contributed by [Younes Belkada](https://huggingfac This example demonstrates how to train a language model using the [`SFTTrainer`] from TRL. 
We train a [Qwen 3 0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) model on the [Capybara dataset](https://huggingface.co/datasets/trl-lib/Capybara), a compact, diverse multi-turn dataset to benchmark reasoning and generalization. ```python -from trl import SFTTrainer, SFTConfig +from trl import SFTConfig, SFTTrainer from datasets import load_dataset trainer = SFTTrainer( @@ -91,7 +91,7 @@ This section breaks down how SFT works in practice, covering the key steps: **pr ### Preprocessing and tokenization During training, each example is expected to contain a **text field** or a **(prompt, completion)** pair, depending on the dataset format. For more details on the expected formats, see [Dataset formats](dataset_formats). -The `SFTTrainer` tokenizes each input using the model's tokenizer. If both prompt and completion are provided separately, they are concatenated before tokenization. +The [`SFTTrainer`] tokenizes each input using the model's tokenizer. If both prompt and completion are provided separately, they are concatenated before tokenization. ### Computing the loss @@ -241,7 +241,7 @@ Unsloth is an open‑source framework for fine‑tuning and reinforcement learni This example shows how to transform the [Qwen 3 0.6B Base](https://huggingface.co/Qwen/Qwen3-0.6B-Base) model into an instruction-following model using the [Capybara dataset](https://huggingface.co/datasets/trl-lib/Capybara) and a chat template from [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B). The SFT Trainer automatically handles tokenizer updates and special token configuration. ```python -from trl import SFTTrainer, SFTConfig +from trl import SFTConfig, SFTTrainer from datasets import load_dataset trainer = SFTTrainer( @@ -280,122 +280,41 @@ Alternatively, use the structured conversation format (recommended): ## Tool Calling with SFT -The SFT trainer fully supports fine-tuning models with _tool calling_ capabilities. 
In this case, each dataset example should include: +The [`SFTTrainer`] fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include: * The conversation messages, including any tool calls (`tool_calls`) and tool responses (`tool` role messages) * The list of available tools in the `tools` column, typically provided as JSON schemas For details on the expected dataset structure, see the [Dataset Format — Tool Calling](dataset_formats#tool-calling) section. -## Extending `SFTTrainer` for Vision Language Models +## Training Vision Language Models -`SFTTrainer` does not yet inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py), which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset. - -### Preparing the Data - -The data format is flexible, provided it is compatible with the custom collator that we will define later. A common approach is to use conversational data. Given that the data includes both text and images, the format needs to be adjusted accordingly. 
Below is an example of a conversational data format involving both text and images: - -```python -images = ["obama.png"] -messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "Who is this?"}, - {"type": "image"} - ] - }, - { - "role": "assistant", - "content": [ - {"type": "text", "text": "Barack Obama"} - ] - }, - { - "role": "user", - "content": [ - {"type": "text", "text": "What is he famous for?"} - ] - }, - { - "role": "assistant", - "content": [ - {"type": "text", "text": "He is the 44th President of the United States."} - ] - } -] -``` - -To illustrate how this data format will be processed using the LLaVA model, you can use the following code: - -```python -from transformers import AutoProcessor - -processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") -print(processor.apply_chat_template(messages, tokenize=False)) -``` - -The output will be formatted as follows: - -```txt -Who is this? ASSISTANT: Barack Obama USER: What is he famous for? ASSISTANT: He is the 44th President of the United States. -``` - - - -### A custom collator for processing multi-modal data - -Unlike the default behavior of [`SFTTrainer`], processing multi-modal data is done on the fly during the data collation process. To do this, you need to define a custom collator that processes both the text and images. This collator must take a list of examples as input (see the previous section for an example of the data format) and return a batch of processed data. 
Below is an example of such a collator: - -```python -def collate_fn(examples): - # Get the texts and images, and apply the chat template - texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples] - images = [example["images"][0] for example in examples] - - # Tokenize the texts and process the images - batch = processor(images=images, text=texts, return_tensors="pt", padding=True) - - # The labels are the input_ids, and we mask the padding tokens in the loss computation - labels = batch["input_ids"].clone() - labels[labels == processor.tokenizer.pad_token_id] = -100 - batch["labels"] = labels - - return batch -``` - -We can verify that the collator works as expected by running the following code: +[`SFTTrainer`] fully supports training Vision-Language Models (VLMs). To train a VLM, you need to provide a dataset with an additional `images` column containing the images to be processed. For more information on the expected dataset structure, see the [Dataset Format — Vision Datasets](dataset_formats#vision-datasets) section. +An example of such a dataset is the [LLaVA Instruct Mix](https://huggingface.co/datasets/trl-lib/llava-instruct-mix). ```python +from trl import SFTConfig, SFTTrainer from datasets import load_dataset -dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train") -examples = [dataset[0], dataset[1]] # Just two examples for the sake of the example -collated_data = collate_fn(examples) -print(collated_data.keys()) # dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'labels']) +trainer = SFTTrainer( + model="Qwen/Qwen2.5-VL-3B-Instruct", + args=SFTConfig(max_length=None), + train_dataset=load_dataset("trl-lib/llava-instruct-mix", split="train"), +) +trainer.train() ``` -### Training the vision-language model + -Now that we have prepared the data and defined the collator, we can proceed with training the model. 
To ensure that the data is not processed as text-only, we need to set a couple of arguments in the [`SFTConfig`], specifically `remove_unused_columns` and `skip_prepare_dataset` to `True` to avoid the default processing of the dataset. Below is an example of how to set up the `SFTTrainer`. +For VLMs, truncating may remove image tokens, leading to errors during training. To avoid this, set `max_length=None` in the [`SFTConfig`]. This allows the model to process the full sequence length without truncating image tokens. ```python -training_args.remove_unused_columns = False -training_args.dataset_kwargs = {"skip_prepare_dataset": True} - -trainer = SFTTrainer( - model=model, - args=training_args, - data_collator=collate_fn, - train_dataset=train_dataset, - processing_class=processor, -) +SFTConfig(max_length=None, ...) ``` -A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset can be found in the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py). +Only use `max_length` when you've verified that truncation won't remove image tokens for the entire dataset. 
-* [Experiment tracking](https://wandb.ai/huggingface/trl/runs/2b2c5l7s) -* [Trained model](https://huggingface.co/HuggingFaceH4/sft-llava-1.5-7b-hf) + ## SFTTrainer @@ -407,3 +326,11 @@ A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vs ## SFTConfig [[autodoc]] SFTConfig + +## DataCollatorForLanguageModeling + +[[autodoc]] trainer.sft_trainer.DataCollatorForLanguageModeling + +## DataCollatorForVisionLanguageModeling + +[[autodoc]] trainer.sft_trainer.DataCollatorForVisionLanguageModeling diff --git a/docs/source/training_vlm_sft.md b/docs/source/training_vlm_sft.md deleted file mode 100644 index a5c853614f9..00000000000 --- a/docs/source/training_vlm_sft.md +++ /dev/null @@ -1,380 +0,0 @@ -# Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset) - -![VLM SFT training procedure](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/training_vlm_sft_training_procedure.png) - -## Overview - -This guide walks you through the process of fine-tuning a multimodal language model (e.g., **Gemma 3**) using **Supervised Fine-Tuning (SFT)**. We cover two cases: - -- **Single Image + Text** -- **Multi-Image + Text** - -This guide serves as a **detailed walkthrough** and complements the existing [VLM SFT script](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py). If you're already familiar with the concepts, you can use the script directly. - -We demonstrate the fine-tuning process using two datasets, but these principles extend to other **Vision-Language Models (VLMs)** and datasets. - -## Understanding the Datasets - -To address both **Single Image + Text** and **Multi-Image + Text** scenarios, we use two datasets that are well-suited for this task. - -### HuggingFaceH4/llava-instruct-mix-vsft Dataset (Image + Text) - -This dataset is a reformatted version of [LLaVA Instruct Mix](https://huggingface.co/datasets/theblackcat102/llava-instruct-mix). 
It consists of conversations where a user provides both **text** and a **single image** as input. - -The model (referred to as the **"assistant"**) responds based on both the **visual and textual information** shared by the user. This dataset is particularly useful for training multimodal models to **understand and generate responses based on images and text**. - - - -### FanqingM/MMIU-Benchmark Dataset (Multi-Image + Text) - -The **FanqingM/MMIU-Benchmark** dataset consists of: - -- **Context:** Included in the system prompt. -- **Question:** Provided as part of the user's input. -- **Series of Images:** Multiple images related to the question. -- **Answer:** The model's expected response. - -This dataset is designed for tasks where the model must reason over multiple images to generate an informed response based on both visual and textual inputs. - - - -## Developing a Fine-Tuning Script for Multimodal SFT - -In this section, we build the script needed to fine-tune a multimodal model for both **Single Image + Text** and **Multi-Image + Text** use cases. - -### Setting Up the Environment - -Before fine-tuning, we need to install the required dependencies. Let's start by setting up the environment: - -```bash -# Install the required libraries. Further details: https://huggingface.co/docs/trl/installation -pip install -U -q trl bitsandbytes peft hf_xet tensorboard -``` - -Once all dependencies are installed, we need to log in to the **Hugging Face Hub**. Since **Gemma 3** is a gated model, access permissions are required. - -If you haven’t requested access yet, visit the [Model Card](https://huggingface.co/google/gemma-3-4b-it) and request it. - -To log in, you’ll need to generate an [access token](https://huggingface.co/settings/tokens) from your Hugging Face account. - -```bash -huggingface-cli login -``` - -### **Loading the Data** - -As mentioned earlier, we will cover two possible use cases. 
While the specific procedure may vary based on the dataset, the core principles remain consistent. - -This guide supports both use cases, so refer to the **Single Image + Text** or **Multi-Image + Text** sections depending on your specific scenario. - -#### **Single Image + Text** - -![Single Image + Text](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/training_vlm_sft_training_procedure_single_image.png) - -In this case, each sample in a batch consists of a **single image paired with text**. Since the dataset is already formatted for supervised fine-tuning (SFT), we can directly load it using `load_dataset`. - -```python -from datasets import load_dataset - -dataset_name = "HuggingFaceH4/llava-instruct-mix-vsft" - -# Load Dataset -dataset = load_dataset(dataset_name) -``` - -#### **Multi-Image + Text (or Interleaving)** - -![Multi-Image + Text](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/training_vlm_sft_training_procedure_multi_image.png) - -Gemma 3 also supports **Multi-Image + Text** scenarios, where: - -- The model receives a **list of images** alongside a user message. -- The model processes **interleaved images and text** within a conversation. - -For this dataset, some preprocessing is required before training. - -```python -from datasets import load_dataset - -dataset_name = "FanqingM/MMIU-Benchmark" - -# Load Dataset -dataset = load_dataset(dataset_name) -``` - -After loading the dataset, we need to preprocess and format it into a conversational structure. Here’s an example of how the data might look: - -```python -{"role": "system", "content": [{"type": "text", "text": "You are a judge in a photography competition, and now you are given the four images. 
Please examine the details and tell which one of them is most likely to be a real photograph.\nSelect from the following choices.\nA: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]}, -{"role": "user", "content": images_list + [{"type": "text", "text": "Which image is most likely to be a real photograph?"}]}, -{"role": "assistant", "content": [{"type": "text", "text": "A: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]}, -``` - -Here, `images_list` is a list of images: - -```python -images_list = [ - {"type": "image", "image": }, - {"type": "image", "image": }, - {"type": "image", "image": }, - {"type": "image", "image": }, - {"type": "image", "image": }, -] -``` - -This structure can be translated into code like this: - -```python -import os -import zipfile -import io -from datasets import DatasetDict -from huggingface_hub import hf_hub_download, list_repo_files -from PIL import Image - -dataset_train_split = "test" - -def format_data(samples: dict[str, any]) -> dict[str, list]: - formatted_samples = {"messages": []} - for cont in range(len(samples["question"])): - images = [] - for img_path in samples["input_image_path"][cont]: - try: - with open(img_path, "rb") as f: - img_bytes = f.read() - image = Image.open(io.BytesIO(img_bytes)).convert("RGB") - images.append({"type": "image", "image": image}) - except Exception as e: - print(f"Error processing image {img_path}: {e}") - continue - - formatted_samples["messages"].append( - [ - {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]}, - {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]}, - {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]}, - ] - ) - return formatted_samples - -# For multi-image example -def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict: - all_files = 
list_repo_files(dataset_name, repo_type="dataset") - zip_files = [f for f in all_files if f.endswith(".zip")] - - for zip_filename in zip_files: - zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset") - extract_folder = zip_filename.replace(".zip", "") - os.makedirs(extract_folder, exist_ok=True) - - with zipfile.ZipFile(zip_path, "r") as zip_ref: - zip_ref.extractall(extract_folder) - - dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16) - return dataset - -dataset = prepare_dataset(dataset, dataset_name, dataset_train_split) -``` - -With this, your **Multi-Image + Text** dataset is now prepared for training. - -### **Preparing for Training** - -We start by loading the model and processor. In this example, we use `google/gemma-3-4b-it`, but the same process applies to its other variants and similar models. - -To optimize memory usage, we configure `BitsAndBytes` to load the quantized version of the model. - -```python -import torch -from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig - -model_id = "google/gemma-3-4b-it" - -# BitsAndBytesConfig int-4 config -bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch.bfloat16, - bnb_4bit_quant_storage=torch.bfloat16, -) - -# Load model and tokenizer -model = AutoModelForImageTextToText.from_pretrained( - model_id, - device_map="auto", - torch_dtype=torch.bfloat16, - attn_implementation="eager", # Important (Ref: https://github.com/huggingface/transformers/blob/c15a7adb283fa984a40558c7fe7bed30ae975cdd/src/transformers/models/gemma3/modeling_gemma3.py#L934) - quantization_config=bnb_config -) -processor = AutoProcessor.from_pretrained(model_id) -processor.tokenizer.padding_side = "right" -``` - -Next, we set up [Quantized Low-Rank Adaptation (QLoRA)](https://huggingface.co/papers/2305.14314), an efficient fine-tuning technique for 
Large Language Models (LLMs) and Vision-Language Models (VLMs). - -```python -from peft import LoraConfig, get_peft_model - -# Configure QLoRA -peft_config = LoraConfig( - lora_alpha=16, - lora_dropout=0.05, - r=16, - bias="none", - target_modules="all-linear", - task_type="CAUSAL_LM", - modules_to_save=[ - "lm_head", - "embed_tokens", - ], -) -``` - -With QLoRA now set up, we need to define the training arguments for SFT. The [`SFTConfig`] class simplifies this process, providing an easy way to adjust parameters based on our specific needs. - -```python -from trl import SFTConfig - -training_args = SFTConfig( - output_dir="gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft", # Directory to save the model and push to the Hub. Use a specific repository id (e.g., gemma-3-4b-it-trl-sft-MMIU-Benchmark for multi-image datasets). - num_train_epochs=1, # Set the number of epochs to train the model. - per_device_train_batch_size=8, # Batch size for each device (e.g., GPU) during training. multi-image -> per_device_train_batch_size=1 - gradient_accumulation_steps=4, # Number of steps before performing a backward/update pass to accumulate gradients. multi-image -> gradient_accumulation_steps=1 - gradient_checkpointing=True, # Enable gradient checkpointing to reduce memory usage during training. - optim="adamw_torch_fused", # Use the fused AdamW optimizer for better performance. - save_strategy="epoch", # Save checkpoints at the end of each epoch. - learning_rate=2e-05, # Learning rate for training. - bf16=True, # Enable bfloat16 precision for training to save memory and speed up computations. - push_to_hub=True, # Automatically push the fine-tuned model to Hugging Face Hub after training. - report_to="tensorboard", # Automatically report metrics to tensorboard. - gradient_checkpointing_kwargs={"use_reentrant": False}, # Set gradient checkpointing to non-reentrant to avoid issues. 
- dataset_kwargs={"skip_prepare_dataset": True}, # Skip dataset preparation to handle preprocessing manually. - remove_unused_columns=False, # Ensure unused columns are not removed in the collator (important for batch processing). -) -``` - -The `collate_fn` is responsible for processing and preparing individual examples to form a batch. - -Each example in the batch undergoes the following steps: - -1. The **chat template** is applied to the text. -2. The **processor tokenizes** both `texts` and `images`, encoding them into tensors. -3. The **labels** for training are set as the `input_ids` of the example. -4. Certain **special tokens** are **masked (ignored)** during loss computation: - - `pad_token_id` - - `` - - `` (corresponding to ID `262144`) - -This process is similar across different dataset types, with a minor variation in how images are handled: - -- **Single Image + Text** → A **list of images** is directly processed. -- **Multi-Image + Text** → A **list of lists of images** is used, where each batch element contains multiple images. 
- -```python -from PIL import Image - -# For multi-image cases -def process_vision_info(messages: list[dict]) -> list[Image.Image]: - image_inputs = [] - for msg in messages: - content = msg.get("content", []) - if not isinstance(content, list): - content = [content] - - for element in content: - if isinstance(element, dict) and ("image" in element or element.get("type") == "image"): - if "image" in element: - image = element["image"] - else: - image = element - if image is not None: - image = Image.open(io.BytesIO(image["bytes"])) - image_inputs.append(image.convert("RGB")) - return image_inputs - -def collate_fn(examples): - texts = [processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip() for example in examples] - if "images" in examples[0]: # single-image - images = [ - [img.convert("RGB") for img in example["images"]] - for example in examples - ] - else: # multi-image - images = [process_vision_info(example["messages"]) for example in examples] - - # Tokenize the texts and process the images - batch = processor( - images=images, text=texts, return_tensors="pt", padding=True - ) # Encode texts and images into tensors - - # The labels are the input_ids, and we mask the padding tokens in the loss computation - labels = batch["input_ids"].clone() # Clone input IDs for labels - # Mask image tokens - image_token_id = [ - processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"]) - ] - # Mask tokens for not being used in the loss computation - labels[labels == processor.tokenizer.pad_token_id] = -100 - labels[labels == image_token_id] = -100 - labels[labels == 262144] = -100 - - batch["labels"] = labels - return batch # Return the prepared batch -``` - -### **Training the Model** - -With all the components set up, we now configure the `SFTTrainer` using the previously defined settings and start the training process. 
- -``` python -# Training -from trl import SFTTrainer - -trainer = SFTTrainer( - model=model, - args=training_args, - data_collator=collate_fn, - train_dataset=dataset["train"], # multi-image -> train_dataset=dataset["test"], - processing_class=processor, - peft_config=peft_config, -) - -trainer.train() - -# Save the final model -trainer.save_model() -``` - -We save the fine-tuned model to the Hub, making it easily accessible for future use. Additionally, TRL automatically logs the training results to **Weights & Biases (Wandb)** or **TensorBoard**, depending on the chosen configuration. - - -### Results - -During and after training, we can inspect the results using **Weights & Biases (Wandb)** or **TensorBoard**. For example: - -* [**gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft (Single Image+Text)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft) - -* [**gemma-3-4b-it-trl-sft-MMIU-Benchmark (Multi-Images+Text or Interleaving)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-MMIU-Benchmark) - -## Limitations - -Currently, fine-tuning Gemma has some [known limitations](https://github.com/huggingface/trl/issues/3121). We recommend following the procedure outlined in this guide to ensure the best results. - -## References - -For further reading and complementary resources, check out the following: - -- [Fine-Tuning Vision-Language Models with QLoRA](https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora) -- [Fine-Tuning a Vision Language Model (Qwen2-VL-7B) with the Hugging Face Ecosystem (TRL)](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl) - diff --git a/examples/datasets/llava_instruct_mix.py b/examples/datasets/llava_instruct_mix.py new file mode 100644 index 00000000000..a819ae38b5a --- /dev/null +++ b/examples/datasets/llava_instruct_mix.py @@ -0,0 +1,107 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +from dataclasses import dataclass, field +from typing import Optional + +from datasets import load_dataset +from huggingface_hub import ModelCard +from transformers import HfArgumentParser + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the dataset to the Hugging Face Hub. + repo_id (`str`, *optional*, defaults to `"trl-lib/llava-instruct-mix"`): + Hugging Face repository ID to push the dataset to. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of workers to use for dataset processing. 
+ """ + + push_to_hub: bool = field( + default=False, + metadata={"help": "Whether to push the dataset to the Hugging Face Hub."}, + ) + repo_id: str = field( + default="trl-lib/llava-instruct-mix", + metadata={"help": "Hugging Face repository ID to push the dataset to."}, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of workers to use for dataset processing."}, + ) + + +def process_example(example): + messages = [] + for message in ast.literal_eval(example["conversations"]): + content = message["value"] + content = content.replace("<image>", "").strip() + role = "user" if message["from"] == "human" else "assistant" + messages.append({"role": role, "content": content}) + return {"messages": messages, "images": [example["image"]]} + + +def filter_long_examples(example): + total_length = sum(len(msg["content"]) for msg in example["messages"]) + return total_length <= 1000 + + +model_card = ModelCard(""" +--- +tags: [trl] +--- + +# LLaVA Instruct Mix + +## Summary + +The LLaVA Instruct Mix dataset is a processed version of [LLaVA Instruct Mix](https://huggingface.co/datasets/theblackcat102/llava-instruct-mix). + +## Data Structure + +- **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational) +- **Type**: [Language-modeling](https://huggingface.co/docs/trl/main/dataset_formats#language-modeling) + +Columns: +- `"images"`: The image associated with the text. +- `"messages"`: A list of messages in the conversation. + +This structure allows models to learn from the context of the conversation, enhancing their understanding of how to generate descriptive text based on visual inputs. + +## Generation script + +The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/llava_instruct_mix.py). 
+""") + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + dataset = load_dataset("theblackcat102/llava-instruct-mix") + + dataset = dataset.map( + process_example, remove_columns=["conversations", "image"], num_proc=script_args.dataset_num_proc + ) + dataset = dataset.filter(filter_long_examples, num_proc=script_args.dataset_num_proc) + + if script_args.push_to_hub: + dataset.push_to_hub(script_args.repo_id, num_proc=script_args.dataset_num_proc) + model_card.push_to_hub(script_args.repo_id, repo_type="dataset") diff --git a/examples/scripts/sft_video_llm.py b/examples/scripts/sft_video_llm.py index 0cd926fe117..aca691d5a5d 100644 --- a/examples/scripts/sft_video_llm.py +++ b/examples/scripts/sft_video_llm.py @@ -178,7 +178,6 @@ class CustomScriptArguments(ScriptArguments): # Configure training args training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) training_args.remove_unused_columns = False - training_args.dataset_kwargs = {"skip_prepare_dataset": True} # Load dataset dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config, split="train") diff --git a/examples/scripts/sft_vlm.py b/examples/scripts/sft_vlm.py index 53618babc27..dd961e48b04 100644 --- a/examples/scripts/sft_vlm.py +++ b/examples/scripts/sft_vlm.py @@ -23,27 +23,37 @@ pip install pillow # Tested on 8x H100 GPUs -accelerate launch - --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \ +accelerate launch \ + --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ examples/scripts/sft_vlm.py \ --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \ --model_name_or_path llava-hf/llava-1.5-7b-hf \ - --per_device_train_batch_size 8 \ --gradient_accumulation_steps 8 \ - --output_dir sft-llava-1.5-7b-hf \ - --torch_dtype bfloat16 \ - --gradient_checkpointing + --output_dir LLaVA-1.5-7B-SFT \ + --torch_dtype bfloat16 For LLaVA-NeXT, use: (requires 
transformers>=4.45) --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf For meta-llama/Llama-3.2-11B-Vision-Instruct, use: (requires transformers>=4.45.1) --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct + +accelerate launch \ + --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ + examples/scripts/sft_vlm.py \ + --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \ + --model_name_or_path HuggingFaceTB/SmolVLM-Instruct \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --output_dir SmolVLM-SFT \ + --torch_dtype bfloat16 \ + --use_peft \ + --lora_target_modules down_proj o_proj k_proj q_proj gate_proj up_proj v_proj """ import torch from datasets import load_dataset -from transformers import AutoModelForImageTextToText, AutoProcessor, LlavaForConditionalGeneration +from transformers import AutoModelForImageTextToText from trl import ( ModelConfig, @@ -61,8 +71,7 @@ parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) - training_args.remove_unused_columns = False - training_args.dataset_kwargs = {"skip_prepare_dataset": True} + training_args.max_length = None ################ # Model, Tokenizer & Processor @@ -78,38 +87,11 @@ device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) - processor = AutoProcessor.from_pretrained( - model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code - ) model = AutoModelForImageTextToText.from_pretrained( model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) - ################ - # Create a data collator to encode text and image pairs - ################ - def collate_fn(examples): - # Get the texts and images, and apply the chat template - texts = [processor.apply_chat_template(example["messages"], 
tokenize=False) for example in examples] - images = [example["images"] for example in examples] - if isinstance(model, LlavaForConditionalGeneration): - # LLava1.5 does not support multiple images - images = [image[0] for image in images] - - # Tokenize the texts and process the images - batch = processor(images=images, text=texts, return_tensors="pt", padding=True) - - # The labels are the input_ids, and we mask the padding tokens in the loss computation - labels = batch["input_ids"].clone() - labels[labels == processor.tokenizer.pad_token_id] = -100 # - # Ignore the image token index in the loss computation (model specific) - image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token) - labels[labels == image_token_id] = -100 - batch["labels"] = labels - - return batch - ################ # Dataset ################ @@ -121,10 +103,8 @@ def collate_fn(examples): trainer = SFTTrainer( model=model, args=training_args, - data_collator=collate_fn, train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, - processing_class=processor, peft_config=get_peft_config(model_args), ) @@ -134,5 +114,3 @@ def collate_fn(examples): trainer.save_model(training_args.output_dir) if training_args.push_to_hub: trainer.push_to_hub(dataset_name=script_args.dataset_name) - if trainer.accelerator.is_main_process: - processor.push_to_hub(training_args.hub_model_id) diff --git a/examples/scripts/sft_vlm_gemma3.py b/examples/scripts/sft_vlm_gemma3.py index fcda0c73863..6cb94bee9fa 100644 --- a/examples/scripts/sft_vlm_gemma3.py +++ b/examples/scripts/sft_vlm_gemma3.py @@ -20,7 +20,7 @@ # /// """ -Train Gemma-3 on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image). +Train Gemma 3 on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image). 
accelerate launch \ --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ @@ -28,14 +28,13 @@ --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \ --model_name_or_path google/gemma-3-4b-it \ --per_device_train_batch_size 1 \ - --gradient_accumulation_steps 1 \ - --output_dir gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft \ + --output_dir Gemma-3-4B-SFT-MMIU \ --torch_dtype bfloat16 \ --use_peft \ --lora_target_modules all-linear \ --attn_implementation eager -Train Gemma-3 on the FanqingM/MMIU-Benchmark dataset (multi-image). +Train Gemma 3 on the FanqingM/MMIU-Benchmark dataset (multi-image). accelerate launch \ --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ @@ -44,8 +43,7 @@ --dataset_train_split test \ --model_name_or_path google/gemma-3-4b-it \ --per_device_train_batch_size 1 \ - --gradient_accumulation_steps 1 \ - --output_dir gemma-3-4b-it-trl-sft-MMIU-Benchmark \ + --output_dir Gemma-3-4B-SFT-MMIU \ --torch_dtype bfloat16 \ --use_peft \ --lora_target_modules all-linear @@ -60,7 +58,7 @@ from datasets import DatasetDict, load_dataset from huggingface_hub import hf_hub_download, list_repo_files from PIL import Image -from transformers import AutoModelForImageTextToText, AutoProcessor +from transformers import AutoModelForImageTextToText from trl import ( ModelConfig, @@ -119,7 +117,7 @@ def format_data(samples: dict[str, any]) -> dict[str, list]: # For multi-image example -def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict: +def prepare_dataset(dataset: DatasetDict, dataset_name: str) -> DatasetDict: all_files = list_repo_files(dataset_name, repo_type="dataset") zip_files = [f for f in all_files if f.endswith(".zip")] @@ -139,8 +137,7 @@ def main(): parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) - 
training_args.remove_unused_columns = False - training_args.dataset_kwargs = {"skip_prepare_dataset": True} + training_args.max_length = None ################ # Model, Tokenizer & Processor @@ -156,50 +153,16 @@ def main(): device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) - processor = AutoProcessor.from_pretrained( - model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code - ) - processor.tokenizer.padding_side = "right" - model = AutoModelForImageTextToText.from_pretrained( model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs ) - def collate_fn(examples): - texts = [ - processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip() - for example in examples - ] - if "images" in examples[0]: # single-image - images = [[img.convert("RGB") for img in example["images"]] for example in examples] - else: # multi-image - images = [process_vision_info(example["messages"]) for example in examples] - - # Tokenize the texts and process the images - batch = processor( - images=images, text=texts, return_tensors="pt", padding=True - ) # Encode texts and images into tensors - - # The labels are the input_ids, and we mask the padding tokens in the loss computation - labels = batch["input_ids"].clone() # Clone input IDs for labels - # Mask image tokens - image_token_id = [ - processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"]) - ] - # Mask tokens for not being used in the loss computation - labels[labels == processor.tokenizer.pad_token_id] = -100 - labels[labels == image_token_id] = -100 - labels[labels == 262144] = -100 - - batch["labels"] = labels - return batch # Return the prepared batch - ################ # Dataset ################ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) if script_args.dataset_name == 
"FanqingM/MMIU-Benchmark": - dataset = prepare_dataset(dataset, script_args.dataset_name, script_args.dataset_train_split) + dataset = prepare_dataset(dataset, script_args.dataset_name) ################ # Training @@ -207,10 +170,8 @@ def collate_fn(examples): trainer = SFTTrainer( model=model, args=training_args, - data_collator=collate_fn, train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, - processing_class=processor, peft_config=get_peft_config(model_args), ) @@ -220,8 +181,6 @@ def collate_fn(examples): trainer.save_model(training_args.output_dir) if training_args.push_to_hub: trainer.push_to_hub(dataset_name=script_args.dataset_name) - if trainer.accelerator.is_main_process: - processor.push_to_hub(training_args.hub_model_id) if __name__ == "__main__": diff --git a/examples/scripts/sft_vlm_smol_vlm.py b/examples/scripts/sft_vlm_smol_vlm.py deleted file mode 100644 index d6ff2da1974..00000000000 --- a/examples/scripts/sft_vlm_smol_vlm.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2020-2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# /// script -# dependencies = [ -# "trl @ git+https://github.com/huggingface/trl.git", -# "Pillow>=9.4.0", -# ] -# /// - -""" -pip install pillow - -# Tested on 8x H100 GPUs -accelerate launch - --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \ - sft_vlm_smol_vlm.py \ - --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \ - --model_name_or_path HuggingFaceTB/SmolVLM-Instruct \ - --per_device_train_batch_size 1 \ - --gradient_accumulation_steps 1 \ - --output_dir sft-smol-vlm-hf \ - --torch_dtype bfloat16 \ - --gradient_checkpointing \ - --use_peft \ - --lora_target_modules down_proj, o_proj, k_proj, q_proj, gate_proj, up_proj, v_proj - -For LLaVA-NeXT, use: (requires transformers>=4.45) - --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf - -For meta-llama/Llama-3.2-11B-Vision-Instruct, use: (requires transformers>=4.45.1) - --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct -""" - -import torch -from datasets import load_dataset -from transformers import ( - AutoModelForImageTextToText, - AutoProcessor, - Idefics3ForConditionalGeneration, - LlavaForConditionalGeneration, -) - -from trl import ( - ModelConfig, - ScriptArguments, - SFTConfig, - SFTTrainer, - TrlParser, - get_kbit_device_map, - get_peft_config, - get_quantization_config, -) - - -if __name__ == "__main__": - parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) - script_args, training_args, model_args = parser.parse_args_and_config() - training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) - training_args.remove_unused_columns = False - training_args.dataset_kwargs = {"skip_prepare_dataset": True} - - ################ - # Model, Tokenizer & Processor - ################ - torch_dtype = ( - model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) - ) - quantization_config = get_quantization_config(model_args) - model_kwargs = dict( - revision=model_args.model_revision, - 
attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, - device_map=get_kbit_device_map() if quantization_config is not None else None, - quantization_config=quantization_config, - ) - processor = AutoProcessor.from_pretrained( - model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code - ) - - model = AutoModelForImageTextToText.from_pretrained( - model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs - ) - - ################ - # Create a data collator to encode text and image pairs - ################ - def collate_fn(examples): - # Get the texts and images, and apply the chat template - texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples] - images = [example["images"] for example in examples] - if isinstance(model, LlavaForConditionalGeneration): - # LLava1.5 does not support multiple images - images = [image[0] for image in images] - - # Tokenize the texts and process the images - batch = processor(images=images, text=texts, return_tensors="pt", padding=True) - - # The labels are the input_ids, and we mask the padding tokens in the loss computation - labels = batch["input_ids"].clone() - labels[labels == processor.tokenizer.pad_token_id] = -100 # - # Ignore the image token index in the loss computation (model specific) - if isinstance(model, Idefics3ForConditionalGeneration): - image_token_id = processor.tokenizer.additional_special_tokens_ids[ - processor.tokenizer.additional_special_tokens.index("") - ] - else: - image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token) - labels[labels == image_token_id] = -100 - batch["labels"] = labels - - return batch - - ################ - # Dataset - ################ - dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) - - ################ - # Training - ################ - trainer = SFTTrainer( - model=model, - args=training_args, - 
data_collator=collate_fn, - train_dataset=dataset[script_args.dataset_train_split], - eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, - processing_class=processor, - peft_config=get_peft_config(model_args), - ) - - trainer.train() - - # Save and push to hub - trainer.save_model(training_args.output_dir) - if training_args.push_to_hub: - trainer.push_to_hub(dataset_name=script_args.dataset_name) - if trainer.accelerator.is_main_process: - processor.push_to_hub(training_args.hub_model_id) diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index 27f51d6293f..55b5e7105c9 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -16,9 +16,11 @@ # `trl-internal-testing` organization. # This script is meant to be run when adding new tiny model to the TRL library. +import torch from huggingface_hub import HfApi, ModelCard from torch import nn from transformers import ( + AutoConfig, AutoProcessor, AutoTokenizer, BartConfig, @@ -35,7 +37,6 @@ FalconMambaForCausalLM, Gemma2Config, Gemma2ForCausalLM, - Gemma3Config, Gemma3ForConditionalGeneration, GemmaConfig, GemmaForCausalLM, @@ -47,18 +48,17 @@ GptOssForCausalLM, Idefics2Config, Idefics2ForConditionalGeneration, + Idefics3ForConditionalGeneration, + InternVLForConditionalGeneration, LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, - LlavaConfig, LlavaForConditionalGeneration, - LlavaNextConfig, LlavaNextForConditionalGeneration, MistralConfig, MistralForCausalLM, OPTConfig, OPTForCausalLM, - PaliGemmaConfig, PaliGemmaForConditionalGeneration, Phi3Config, Phi3ForCausalLM, @@ -74,7 +74,6 @@ Qwen3ForSequenceClassification, Qwen3MoeConfig, Qwen3MoeForCausalLM, - SmolVLMConfig, SmolVLMForConditionalGeneration, T5Config, T5ForConditionalGeneration, @@ -98,7 +97,7 @@ api = HfApi() -def push_to_hub(model, tokenizer, prefix=None, suffix=None): +def push_to_hub(model, tokenizer, prefix=None, suffix=None, 
force=False): model_class_name = model.__class__.__name__ content = MODEL_CARD.format(model_class_name=model_class_name) model_card = ModelCard(content) @@ -108,7 +107,7 @@ def push_to_hub(model, tokenizer, prefix=None, suffix=None): if suffix is not None: repo_id += f"-{suffix}" - if api.repo_exists(repo_id): + if api.repo_exists(repo_id) and not force: print(f"Model {repo_id} already exists, skipping") else: model.push_to_hub(repo_id) @@ -179,7 +178,7 @@ def init_weights_tiny_model(model): revision = "refs/pr/14" if model_id == "Qwen/Qwen3-8B" else "main" # chat template with {% generation %} tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision) config = config_class( - vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()), + vocab_size=len(tokenizer.vocab), hidden_size=8, num_attention_heads=4, num_key_value_heads=2, @@ -197,7 +196,7 @@ def init_weights_tiny_model(model): ]: tokenizer = AutoTokenizer.from_pretrained(model_id) config = config_class( - vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()), + vocab_size=len(tokenizer.vocab), hidden_size=8, num_attention_heads=4, num_key_value_heads=2, @@ -214,7 +213,7 @@ def init_weights_tiny_model(model): # Two slightly bigger models, required for vLLM testing tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-32B-Instruct") config = Qwen2Config( - vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()), + vocab_size=len(tokenizer.vocab), hidden_size=128, # increase hidden size so that hidden_size // num_attention_heads = 32, required for vLLM num_attention_heads=4, num_key_value_heads=2, @@ -226,7 +225,7 @@ def init_weights_tiny_model(model): tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B") config = Qwen3Config( - vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()), + vocab_size=len(tokenizer.vocab), hidden_size=128, # increase hidden size so that hidden_size // num_attention_heads = 32, 
required for vLLM num_attention_heads=4, num_key_value_heads=2, @@ -244,7 +243,7 @@ def init_weights_tiny_model(model): ]: tokenizer = AutoTokenizer.from_pretrained(model_id) config = config_class( - vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()), + vocab_size=len(tokenizer.vocab), hidden_size=8, num_attention_heads=4, num_key_value_heads=2, @@ -263,7 +262,7 @@ def init_weights_tiny_model(model): ]: tokenizer = AutoTokenizer.from_pretrained(model_id) config = config_class( - vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()), + vocab_size=len(tokenizer.vocab), d_model=16, encoder_layers=2, decoder_layers=2, @@ -279,51 +278,43 @@ def init_weights_tiny_model(model): # Vision Language Models -for model_id, config_class, model_class in [ - ("google/gemma-3-4b-it", Gemma3Config, Gemma3ForConditionalGeneration), - ("google/paligemma-3b-pt-224", PaliGemmaConfig, PaliGemmaForConditionalGeneration), - ("HuggingFaceM4/idefics2-8b", Idefics2Config, Idefics2ForConditionalGeneration), - ("HuggingFaceTB/SmolVLM2-2.2B-Instruct", SmolVLMConfig, SmolVLMForConditionalGeneration), - ("llava-hf/llava-1.5-7b-hf", LlavaConfig, LlavaForConditionalGeneration), - ("llava-hf/llava-v1.6-mistral-7b-hf", LlavaNextConfig, LlavaNextForConditionalGeneration), - ("Qwen/Qwen2-VL-2B-Instruct", Qwen2VLConfig, Qwen2VLForConditionalGeneration), - ("Qwen/Qwen2.5-VL-3B-Instruct", Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration), +for model_id, model_class in [ + ("google/gemma-3-4b-it", Gemma3ForConditionalGeneration), + ("google/paligemma-3b-pt-224", PaliGemmaForConditionalGeneration), + ("HuggingFaceM4/idefics2-8b", Idefics2ForConditionalGeneration), + ("HuggingFaceM4/Idefics3-8B-Llama3", Idefics3ForConditionalGeneration), + ("HuggingFaceTB/SmolVLM2-2.2B-Instruct", SmolVLMForConditionalGeneration), + ("llava-hf/llava-1.5-7b-hf", LlavaForConditionalGeneration), + ("llava-hf/llava-v1.6-mistral-7b-hf", LlavaNextForConditionalGeneration), + 
("OpenGVLab/InternVL3-8B-hf", InternVLForConditionalGeneration), + ("Qwen/Qwen2-VL-2B-Instruct", Qwen2VLForConditionalGeneration), + ("Qwen/Qwen2.5-VL-3B-Instruct", Qwen2_5_VLForConditionalGeneration), ]: processor = AutoProcessor.from_pretrained(model_id) - kwargs = {} - text_kwargs = {} - vision_kwargs = {} - if config_class == PaliGemmaConfig: - kwargs["projection_dim"] = 8 - if config_class in [LlavaConfig, LlavaNextConfig, PaliGemmaConfig]: - vision_kwargs["projection_dim"] = 8 - if config_class in [LlavaConfig, LlavaNextConfig]: - vision_kwargs["image_size"] = 336 - vision_kwargs["patch_size"] = 14 - if config_class in [Qwen2VLConfig, Qwen2_5_VLConfig]: - kwargs["vision_start_token_id"] = 151652 - text_kwargs["rope_scaling"] = {"type": "mrope", "mrope_section": [1]} - vision_kwargs["depth"] = 4 - vision_kwargs["embed_dim"] = 64 + config = AutoConfig.from_pretrained(model_id) + + config.text_config.num_hidden_layers = 2 + config.text_config.hidden_size = 16 + config.text_config.num_attention_heads = 4 + config.text_config.num_key_value_heads = 2 + + config.vision_config.num_hidden_layers = 2 + config.vision_config.hidden_size = 16 + config.vision_config.num_attention_heads = 4 + config.vision_config.num_key_value_heads = 2 + + if isinstance(config, (Qwen2VLConfig)): + config.vision_config.depth = 2 + + if isinstance(config, (Qwen2VLConfig, Qwen2_5_VLConfig)): + config.text_config.rope_scaling["mrope_section"] = [2] + + if isinstance(config, (Qwen2_5_VLConfig)): + config.vision_config.out_hidden_size = 16 + + if isinstance(config, Idefics2Config): + config.perceiver_config.hidden_size = 16 + + model = model_class(config).to(dtype=torch.bfloat16) - config = config_class( - text_config=dict( - vocab_size=processor.tokenizer.vocab_size + len(processor.tokenizer.added_tokens_encoder), - hidden_size=8, - num_attention_heads=4, - num_key_value_heads=2, - num_hidden_layers=2, - intermediate_size=32, - **text_kwargs, - ), - vision_config=dict( - hidden_size=16, - 
num_attention_heads=4, - num_hidden_layers=2, - intermediate_size=32, - **vision_kwargs, - ), - **kwargs, - ) - model = model_class(config) push_to_hub(model, processor, "tiny") diff --git a/scripts/generate_zen_image_dataset.py b/scripts/generate_zen_image_dataset.py index 5795c8d8928..378ed16a5c3 100644 --- a/scripts/generate_zen_image_dataset.py +++ b/scripts/generate_zen_image_dataset.py @@ -15,7 +15,7 @@ from dataclasses import dataclass, field import numpy as np -from datasets import Dataset, Features, Image, Sequence, Value +from datasets import Dataset, Features, Image, List, Sequence, Value from transformers import HfArgumentParser @@ -75,9 +75,9 @@ def main(test_size, push_to_hub, repo_id): "If the implementation is easy to explain, it may be a good idea.", "Namespaces are one honking great idea -- let's do more of those!", ], - "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes], + "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes], } - standard_language_modeling_dataset = Dataset.from_dict(data, features=Features(text=Value("string"), image=Image())) + standard_language_modeling_dataset = Dataset.from_dict(data, features=Features(text=Value("string"), images=List(Image()))) standard_language_modeling_dataset = standard_language_modeling_dataset.train_test_split(test_size=test_size, shuffle=False) if push_to_hub: standard_language_modeling_dataset.push_to_hub(repo_id, config_name="standard_language_modeling") @@ -105,9 +105,9 @@ def main(test_size, push_to_hub, repo_id): "If the implementation is easy", "Namespaces are one honking great", ], - "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes], + "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes], } - standard_prompt_only_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), 
image=Image())) + standard_prompt_only_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), images=List(Image()))) standard_prompt_only_dataset = standard_prompt_only_dataset.train_test_split(test_size=test_size, shuffle=False) if push_to_hub: standard_prompt_only_dataset.push_to_hub(repo_id, config_name="standard_prompt_only") @@ -156,9 +156,9 @@ def main(test_size, push_to_hub, repo_id): " to explain, it may be a good idea.", " idea -- let's do more of those!", ], - "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes], + "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes], } - standard_prompt_completion_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), completion=Value("string"), image=Image())) + standard_prompt_completion_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), completion=Value("string"), images=List(Image()))) standard_prompt_completion_dataset = standard_prompt_completion_dataset.train_test_split(test_size=test_size, shuffle=False) if push_to_hub: standard_prompt_completion_dataset.push_to_hub(repo_id, config_name="standard_prompt_completion") @@ -228,9 +228,9 @@ def main(test_size, push_to_hub, repo_id): " it's probably magic.", " watermelon -- let's plant some!", ], - "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes], + "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes], } - standard_preference_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), chosen=Value("string"), rejected=Value("string"), image=Image())) + standard_preference_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), chosen=Value("string"), rejected=Value("string"), images=List(Image()))) standard_preference_dataset = 
standard_preference_dataset.train_test_split(test_size=test_size, shuffle=False) if push_to_hub: standard_preference_dataset.push_to_hub(repo_id, config_name="standard_preference") @@ -279,9 +279,9 @@ def main(test_size, push_to_hub, repo_id): "If the implementation is easy it's probably magic.", "Namespaces are one honking great watermelon -- let's plant some!", ], - "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes], + "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes], } - standard_implicit_prompt_preference_dataset = Dataset.from_dict(data, features=Features(chosen=Value("string"), rejected=Value("string"), image=Image())) + standard_implicit_prompt_preference_dataset = Dataset.from_dict(data, features=Features(chosen=Value("string"), rejected=Value("string"), images=List(Image()))) {'prompt': Value(dtype='string'), 'completions': Sequence(feature=Value(dtype='string'), length=-1), 'labels': Sequence(feature=Value(dtype='bool'), length=-1)} standard_implicit_prompt_preference_dataset = standard_implicit_prompt_preference_dataset.train_test_split(test_size=test_size, shuffle=False) if push_to_hub: @@ -332,9 +332,9 @@ def main(test_size, push_to_hub, repo_id): " watermelon -- let's plant some!", ], "label": [True, False, False, True, True, False, True, False, True, True, False, True, True, False, True, False, True, False, False], - "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes], + "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes], } - standard_unpaired_preference_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), completion=Value("string"), label=Value("bool"), image=Image())) + standard_unpaired_preference_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), completion=Value("string"), label=Value("bool"), 
images=List(Image()))) standard_unpaired_preference_dataset = standard_unpaired_preference_dataset.train_test_split(test_size=test_size, shuffle=False) if push_to_hub: standard_unpaired_preference_dataset.push_to_hub(repo_id, config_name="standard_unpaired_preference") @@ -403,9 +403,9 @@ def main(test_size, push_to_hub, repo_id): [True, True], [False] ], - "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes], + "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes], } - standard_stepwise_supervision_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), completions=Sequence(Value("string")), labels=Sequence(Value("bool")), image=Image())) + standard_stepwise_supervision_dataset = Dataset.from_dict(data, features=Features(prompt=Value("string"), completions=Sequence(Value("string")), labels=Sequence(Value("bool")), images=List(Image()))) standard_stepwise_supervision_dataset = standard_stepwise_supervision_dataset.train_test_split(test_size=test_size, shuffle=False) if push_to_hub: standard_stepwise_supervision_dataset.push_to_hub(repo_id, config_name="standard_stepwise_supervision") @@ -433,9 +433,9 @@ def main(test_size, push_to_hub, repo_id): [{"role": "user", "content": "What does it mean if the implementation is easy to explain?"}, {"role": "assistant", "content": "It means it may be a good idea."}], [{"role": "user", "content": "Any great ideas?"}, {"role": "assistant", "content": "Namespaces are one honking great idea."}], ], - "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes], + "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes], } - conversational_language_modeling_dataset = Dataset.from_dict(data, features=Features(messages=Message, image=Image())) + conversational_language_modeling_dataset = Dataset.from_dict(data, 
features=Features(messages=Message, images=List(Image()))) conversational_language_modeling_dataset = conversational_language_modeling_dataset.train_test_split(test_size=test_size, shuffle=False) if push_to_hub: conversational_language_modeling_dataset.push_to_hub(repo_id, config_name="conversational_language_modeling") @@ -463,9 +463,9 @@ def main(test_size, push_to_hub, repo_id): [{"role": "user", "content": "What does it mean if the implementation is easy to explain?"}], [{"role": "user", "content": "Any great ideas?"}], ], - "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes], + "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes], } - conversational_prompt_only_dataset = Dataset.from_dict(data, features=Features(prompt=Message, image=Image())) + conversational_prompt_only_dataset = Dataset.from_dict(data, features=Features(prompt=Message, images=List(Image()))) conversational_prompt_only_dataset = conversational_prompt_only_dataset.train_test_split(test_size=test_size, shuffle=False) if push_to_hub: conversational_prompt_only_dataset.push_to_hub(repo_id, config_name="conversational_prompt_only") @@ -514,9 +514,9 @@ def main(test_size, push_to_hub, repo_id): [{"role": "assistant", "content": "It means it may be a good idea."}], [{"role": "assistant", "content": "Namespaces are one honking great idea."}], ], - "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes], + "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes], } - conversational_prompt_completion_dataset = Dataset.from_dict(data, features=Features(prompt=Message, completion=Message, image=Image())) + conversational_prompt_completion_dataset = Dataset.from_dict(data, features=Features(prompt=Message, completion=Message, images=List(Image()))) conversational_prompt_completion_dataset = 
conversational_prompt_completion_dataset.train_test_split(test_size=test_size, shuffle=False) if push_to_hub: conversational_prompt_completion_dataset.push_to_hub(repo_id, config_name="conversational_prompt_completion") @@ -586,9 +586,9 @@ def main(test_size, push_to_hub, repo_id): [{"role": "assistant", "content": "It means it's a bad idea."}], [{"role": "assistant", "content": "Recursion."}], ], - "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes], + "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes], } - conversational_preference_dataset = Dataset.from_dict(data, features=Features(prompt=Message, chosen=Message, rejected=Message, image=Image())) + conversational_preference_dataset = Dataset.from_dict(data, features=Features(prompt=Message, chosen=Message, rejected=Message, images=List(Image()))) conversational_preference_dataset = conversational_preference_dataset.train_test_split(test_size=test_size, shuffle=False) if push_to_hub: conversational_preference_dataset.push_to_hub(repo_id, config_name="conversational_preference") @@ -637,9 +637,9 @@ def main(test_size, push_to_hub, repo_id): [{"role": "user", "content": "What does it mean if the implementation is easy to explain?"}, {"role": "assistant", "content": "It means it's a bad idea."}], [{"role": "user", "content": "Any great ideas?"}, {"role": "assistant", "content": "Recursion."}], ], - "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes], + "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes], } - conversational_implicit_prompt_preference_dataset = Dataset.from_dict(data, features=Features(chosen=Message, rejected=Message, image=Image())) + conversational_implicit_prompt_preference_dataset = Dataset.from_dict(data, features=Features(chosen=Message, rejected=Message, images=List(Image()))) 
conversational_implicit_prompt_preference_dataset = conversational_implicit_prompt_preference_dataset.train_test_split(test_size=test_size, shuffle=False) if push_to_hub: conversational_implicit_prompt_preference_dataset.push_to_hub(repo_id, config_name="conversational_implicit_prompt_preference") @@ -689,9 +689,9 @@ def main(test_size, push_to_hub, repo_id): [{'role': 'assistant', 'content': 'Namespaces are one honking great idea.'}], ], "label": [True, True, True, False, True, True, True, False, True, False, True, False, True, False, False, True, True, True, True], - "image": [np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8) for h, w in sizes], + "images": [[np.random.uniform(low=0.0, high=255.0, size=(h, w, 3)).astype(np.uint8)] for h, w in sizes], } - conversational_unpaired_preference_dataset = Dataset.from_dict(data, features=Features(prompt=Message, completion=Message, label=Value("bool"), image=Image())) + conversational_unpaired_preference_dataset = Dataset.from_dict(data, features=Features(prompt=Message, completion=Message, label=Value("bool"), images=List(Image()))) conversational_unpaired_preference_dataset = conversational_unpaired_preference_dataset.train_test_split(test_size=test_size, shuffle=False) if push_to_hub: conversational_unpaired_preference_dataset.push_to_hub(repo_id, config_name="conversational_unpaired_preference") diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 512d94ed6b1..bffb629840b 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1417,7 +1417,7 @@ def test_train_with_iterable_dataset(self): class DPOVisionTrainerTester(TrlTestCase): @parameterized.expand( [ - # ("trl-internal-testing/tiny-Idefics2ForConditionalGeneration",), # device issue from transformers + # ("trl-internal-testing/tiny-Idefics2ForConditionalGeneration",), device issue from transformers, see https://github.com/huggingface/transformers/pull/39975 # 
("trl-internal-testing/tiny-PaliGemmaForConditionalGeneration",), ("trl-internal-testing/tiny-LlavaForConditionalGeneration",), ("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",), diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index ead8203b463..1e34346fb1d 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -14,17 +14,11 @@ import pathlib -import numpy as np +import pytest import torch -from datasets import Dataset, Image, Sequence, load_dataset +from datasets import Dataset, load_dataset from parameterized import parameterized -from transformers import ( - AutoModelForCausalLM, - AutoProcessor, - AutoTokenizer, - LlavaForConditionalGeneration, - is_vision_available, -) +from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.testing_utils import require_flash_attn, require_peft, require_vision from transformers.utils import is_peft_available @@ -37,9 +31,6 @@ if is_peft_available(): from peft import LoraConfig, PeftModel, get_peft_model -if is_vision_available(): - from PIL import Image as PILImage - def formatting_prompts_func(example): text = f"### Question: {example['question']}\n ### Answer: {example['answer']}" @@ -277,50 +268,6 @@ def setUp(self): "trl-internal-testing/zen", "standard_prompt_completion" ) - if is_vision_available(): - self.dummy_vsft_instruction_dataset = Dataset.from_dict( - { - "messages": [ - [ - { - "role": "user", - "content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}], - }, - { - "role": "assistant", - "content": [{"type": "text", "text": "It is random noise."}], - }, - { - "role": "user", - "content": [{"type": "text", "text": "Oh ye, you are right, what is 1+1"}], - }, - { - "role": "assistant", - "content": [{"type": "text", "text": "2"}], - }, - ], - [ - { - "role": "user", - "content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}], - }, - { - "role": "assistant", - "content": [{"type": "text", 
"text": "It is random noise."}], - }, - ], - ], - "images": [ - [PILImage.fromarray((np.random.rand(40, 50, 3) * 255).astype("uint8")).convert("RGBA")], - [PILImage.fromarray((np.random.rand(50, 60, 3) * 255).astype("uint8")).convert("RGBA")], - ], - } - ) - self.dummy_vsft_instruction_dataset.cast_column("images", Sequence(Image())) - self.dummy_vsft_instruction_dataset = self.dummy_vsft_instruction_dataset.cast_column( - "images", Sequence(Image()) - ) - def test_uncorrect_data(self): # Shoud work as SFTTrainer natively supports conversational lm dataset training_args = SFTConfig( @@ -501,86 +448,6 @@ def test_no_packing(self): self.assertEqual(len(trainer.train_dataset["input_ids"]), len(self.conversational_lm_dataset["train"])) self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"])) - @require_vision - def test_skip_prepare_dataset(self): - training_args = SFTConfig( - output_dir=self.tmp_dir, - per_device_train_batch_size=2, - remove_unused_columns=False, - dataset_kwargs={"skip_prepare_dataset": True}, - report_to="none", - ) - - trainer = SFTTrainer( - model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", - args=training_args, - train_dataset=self.dummy_vsft_instruction_dataset, - ) - self.assertEqual(trainer.train_dataset.features, self.dummy_vsft_instruction_dataset.features) - - def test_skip_prepare_dataset_with_no_packing(self): - training_args = SFTConfig( - output_dir=self.tmp_dir, - per_device_train_batch_size=2, - remove_unused_columns=False, - packing=False, - dataset_kwargs={"skip_prepare_dataset": True}, - report_to="none", - ) - - trainer = SFTTrainer( - model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", - args=training_args, - train_dataset=self.dummy_dataset, - ) - self.assertEqual(trainer.train_dataset.features, self.dummy_dataset.features) - - @require_vision - def test_llava(self): - training_args = SFTConfig( - output_dir=self.tmp_dir, - remove_unused_columns=False, - 
dataset_kwargs={"skip_prepare_dataset": True}, - report_to="none", - ) - tiny_llava = LlavaForConditionalGeneration.from_pretrained( - "trl-internal-testing/tiny-LlavaForConditionalGeneration" - ) - processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlavaForConditionalGeneration") - - processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious -user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's -questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for -item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' -%}{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if -add_generation_prompt %}ASSISTANT: {% endif %}""" - - def collate_fn(examples): - # Get the texts and images, and apply the chat template - texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples] - images = [example["images"][0] for example in examples] - - # Tokenize the texts and process the images - batch = processor(images=images, text=texts, return_tensors="pt", padding=True) - - # The labels are the input_ids, and we mask the padding tokens in the loss computation - labels = batch["input_ids"].clone() - labels[labels == processor.tokenizer.pad_token_id] = -100 - batch["labels"] = labels - - return batch - - trainer = SFTTrainer( - model=tiny_llava, - args=training_args, - data_collator=collate_fn, - train_dataset=self.dummy_vsft_instruction_dataset, - ) - - trainer.train() - - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - # This new tester aims to replace the first one at some point class SFTTrainerTester2(TrlTestCase): @@ -1404,3 +1271,92 @@ def test_tag_added_peft(self): for tag in ["sft", 
"trl"]: self.assertIn(tag, trainer.model.model_tags) + + @parameterized.expand( + [ + ("trl-internal-testing/tiny-Gemma3ForConditionalGeneration",), + # ("trl-internal-testing/tiny-Idefics2ForConditionalGeneration",), device issue from transformers, see https://github.com/huggingface/transformers/pull/39975 + # ("trl-internal-testing/tiny-Idefics3ForConditionalGeneration",), device issue from transformers, see https://github.com/huggingface/transformers/pull/39975 + ("trl-internal-testing/tiny-LlavaForConditionalGeneration",), + ("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",), + ("trl-internal-testing/tiny-Qwen2VLForConditionalGeneration",), + ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",), + # ("trl-internal-testing/tiny-SmolVLMForConditionalGeneration",), device issue from transformers, see https://github.com/huggingface/transformers/pull/39975 + ] + ) + @require_vision + def test_train_vlm(self, model_id): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen-image", "conversational_language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig( + output_dir=self.tmp_dir, + max_length=None, # For VLMs, truncating can remove image tokens, leading to errors + report_to="none", + ) + trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # For some reason, these params are not updated. This is probably not related to TRL, but to + # the model itself. We should investigate this further, but for now we just skip these params. 
+ # fmt: off + if ( + model_id == "trl-internal-testing/tiny-Gemma3ForConditionalGeneration" and "model.vision_tower.vision_model.head" in n or + model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or + model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or + model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or + model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n + ): + # fmt: on + continue + self.assertFalse( + torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" + ) + + # Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing. + # To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs. 
+ @pytest.mark.slow + @require_vision + def test_train_vlm_gemma_3n(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen-image", "conversational_language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig( + output_dir=self.tmp_dir, + max_length=None, + per_device_train_batch_size=1, + gradient_checkpointing=True, + model_init_kwargs={"torch_dtype": "bfloat16"}, + report_to="none", + ) + trainer = SFTTrainer(model="google/gemma-3n-E2B-it", args=training_args, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if "model.vision_tower.timm_model.conv_stem.bn.weight" in n: + # This parameter is not updated, not sure why at this point. 
+ continue + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated") diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 23b334ecad8..12cc5443757 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1549,7 +1549,7 @@ def concatenated_forward( loss_mask = loss_mask[:, -logits_to_keep:] if logits.shape[:2] != labels.shape[:2]: - # for llava, the returned logits include the image tokens (placed before the text tokens) + # for LLaVA, the returned logits include the image tokens (placed before the text tokens) seq_len = labels.shape[1] logits = logits[:, -seq_len:] diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index b237867f913..108547fcd79 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -462,7 +462,7 @@ def reward_func(completions, **kwargs): and content). eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. - processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`): + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`] or `None`, *optional*, defaults to `None`): Processing class used to process the data. The padding side must be set to "left". If `None`, the processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token, @@ -534,9 +534,9 @@ def __init__( else: model_id = model.config._name_or_path if args.model_init_kwargs is not None: - raise ValueError( + warnings.warn( "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. 
" - "This argument can only be used when the `model` argument is a string." + "The `model_init_kwargs` will be ignored." ) # Some models (SmolVLM/Idefics3) don't support `logits_to_keep` argument and error out if we pass it diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py index 9e47ccf52f3..b36256ac64e 100644 --- a/trl/trainer/sft_config.py +++ b/trl/trainer/sft_config.py @@ -49,7 +49,8 @@ class SFTConfig(TrainingArguments): Name of the column that contains text data in the dataset. dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): Dictionary of optional keyword arguments for the dataset preparation. The only supported key is - `skip_prepare_dataset`. + `skip_prepare_dataset`. When the model is a VLM, `skip_prepare_dataset` is automatically treated as `True` + regardless of the provided value, since preprocessing is done on the fly. dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): Number of processes to use for processing the dataset. eos_token (`str` or `None`, *optional*, defaults to `None`): @@ -159,7 +160,9 @@ class SFTConfig(TrainingArguments): default=None, metadata={ "help": "Dictionary of optional keyword arguments for the dataset preparation. The only supported key is " - "`skip_prepare_dataset`." + "`skip_prepare_dataset`. If the model is a VLM, `skip_prepare_dataset` value is ignored. When the model " + "is a VLM, `skip_prepare_dataset` is automatically treated as `True` regardless of the provided value, " + "since preprocessing is done on the fly." 
}, ) dataset_num_proc: Optional[int] = field( diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 41dcae3f1fc..e1d53f1d7e7 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -23,11 +23,12 @@ import torch import torch.nn as nn +import transformers from accelerate import PartialState from datasets import Dataset, IterableDataset from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, + AutoConfig, + AutoProcessor, BaseImageProcessor, DataCollator, FeatureExtractionMixin, @@ -259,6 +260,130 @@ def _convert_seq_lengths_to_position_ids(batch_seq_lengths: list[list[int]]) -> return list(position_ids.split(example_lengths)) +@dataclass +class DataCollatorForVisionLanguageModeling(DataCollatorMixin): + """ + Data collator for vision-language modeling tasks. + + Unlike text-only datasets—where the collator typically receives pre-tokenized inputs ready for batching, + vision-language data processing involves converting images into pixel values. This conversion is disk-intensive, + making upfront preprocessing of the entire dataset impractical. Therefore, this collator performs tokenization and + image processing on-the-fly to efficiently prepare batches. + + Each input example should be a dictionary containing at least: + - An `"images"` key holding the image data. + - Either a `"messages"` key for conversational inputs or a `"text"` key for standard text inputs. + + The collator outputs a dictionary including: + - `"input_ids"`: Tensor of token IDs. + - `"attention_mask"`: Tensor indicating attention mask. + - `"pixel_values"`: Tensor representing image pixel values. + - `"labels"`: Tensor for training labels. + + Additional keys may be present depending on the processor, such as `"image_grid_thw"`. + + Args: + processor (`ProcessorMixin`): + The processor used to tokenize text and process images. It must be a subclass of `ProcessorMixin` + and include a `tokenizer` with a defined `pad_token_id`. 
+ max_length (`int` or `None`, optional, defaults to `None`): + Maximum sequence length for input tokens. If `None`, no truncation is applied. + pad_to_multiple_of (`int` or `None`, optional, defaults to `None`): + If set, the sequences will be padded to a multiple of this value. + dataset_text_field (`str`, optional, defaults to `"text"`): + Name of the column that contains text data in the dataset. This parameter is only relevant for + [standard datasets format](dataset_formats#standard). + return_tensors (`str`, optional, defaults to `"pt"`): + The tensor type to return. Currently, only `"pt"` (PyTorch tensors) is supported. + + Example: + ```python + >>> from trl import DataCollatorForVisionLanguageModeling + >>> from transformers import AutoProcessor + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> collator = DataCollatorForVisionLanguageModeling(processor) + >>> examples = [ + ... {"images": [Image.open("image_0.png")], "messages": [{"role": "user", "content": "What is this?"}]}, + ... {"images": [Image.open("image_1.png")], "messages": [{"role": "user", "content": "Describe this image."}]} + ... 
] + >>> collator(examples) + {'input_ids': tensor([[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 3838, 374, + 419, 30, 151645, 198], + [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 74785, 419, + 2168, 13, 151645, 198]]), + 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), + 'pixel_values': tensor([[-0.9893, 0.1785, 1.5362, ..., -0.0582, 0.8661, -0.2431], + [-0.2302, 0.9522, -1.1061, ..., 0.0555, 1.3354, -0.6412], + [ 1.2150, 0.9084, 0.7041, ..., 0.2404, -0.8403, -0.5133], + ..., + [ 0.6895, 0.2807, 0.2515, ..., -0.2004, -1.2100, 0.0555], + [ 0.8209, -0.9748, 1.5654, ..., 1.6055, -0.4706, 0.5817], + [-1.0915, 0.4559, 0.9230, ..., 0.5106, 0.0982, -0.1720]]), + 'image_grid_thw': tensor([[1, 4, 4], + [1, 4, 4]]), + 'labels': tensor([[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 3838, 374, + 419, 30, 151645, 198], + [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 74785, 419, + 2168, 13, 151645, 198]])} + ``` + """ + + processor: ProcessorMixin + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + dataset_text_field: str = "text" + return_tensors: str = "pt" + + def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: + images = [example["images"] for example in examples] + + if "messages" in examples[0]: # conversational case + for example in examples: + image_included = False + for message in example["messages"]: + if message["role"] == "user": + if isinstance(message["content"], str) and not image_included: + 
message["content"] = [{"type": "image"}, {"type": "text", "text": message["content"]}] + image_included = True + elif isinstance(message["content"], str) and image_included: + message["content"] = [{"type": "text", "text": message["content"]}] + if message["role"] == "assistant": + if isinstance(message["content"], str): + message["content"] = [{"type": "text", "text": message["content"]}] + messages = [example["messages"] for example in examples] + texts = self.processor.apply_chat_template(messages, images=images) + elif self.dataset_text_field in examples[0]: # standard case + texts = [example[self.dataset_text_field] for example in examples] + else: + raise KeyError( + "The input examples must contain either 'messages' for conversational data or 'text' for standard " + "data." + ) + + output = self.processor( + images=images, + text=texts, + padding=True, + pad_to_multiple_of=self.pad_to_multiple_of, + truncation=self.max_length is not None, + max_length=self.max_length, + return_tensors=self.return_tensors, + add_special_tokens=False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens + ) + labels = output["input_ids"].clone() + labels[labels == self.processor.tokenizer.pad_token_id] = -100 + # We mask only padding tokens (-100) in the labels. Vision tokens are left unchanged because their handling in + # loss computation has to be done by the model, and masking them here would be infeasible in practice as vision + # token definitions vary across architectures. + output["labels"] = labels + return output + + class SFTTrainer(Trainer): """ Trainer for Supervised Fine-Tuning (SFT) method. @@ -303,9 +428,10 @@ class SFTTrainer(Trainer): The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field. 
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*, defaults to `None`): + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`] or `None`, *optional*, defaults to `None`): Processing class used to process the data. If `None`, the processing class is loaded from the model's name - with [`~transformers.AutoTokenizer.from_pretrained`]. + with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set. + If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default. callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`): List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback). 
@@ -343,9 +469,7 @@ def __init__( data_collator: Optional[DataCollator] = None, # type: ignore train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ] = None, + processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None, compute_loss_func: Optional[Callable] = None, compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, callbacks: Optional[list[TrainerCallback]] = None, @@ -356,9 +480,9 @@ def __init__( formatting_func: Optional[Callable[[dict], str]] = None, ): # Args - model_id = model if isinstance(model, str) else model.config._name_or_path if args is None: - model_name = model_id.split("/")[-1] + model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model_name.split("/")[-1] args = SFTConfig(f"{model_name}-SFT") elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig): dict_args = args.to_dict() @@ -366,29 +490,56 @@ def __init__( dict_args.pop("push_to_hub_token") args = SFTConfig(**dict_args) - # Handle the tokenizer + # Model + model_init_kwargs = args.model_init_kwargs or {} + if isinstance(model, str): + model_id = model + torch_dtype = model_init_kwargs.get("torch_dtype") + if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: + pass # torch_dtype is already a torch.dtype or "auto" or None + elif isinstance(torch_dtype, str) and torch_dtype in ["bfloat16", "float16", "float32"]: + torch_dtype = getattr(torch, torch_dtype) + model_init_kwargs["torch_dtype"] = torch_dtype + else: + raise ValueError( + "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing " + f"a valid `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." 
+ ) + config = AutoConfig.from_pretrained(model_id) + architecture = getattr(transformers, config.architectures[0]) + model = architecture.from_pretrained(model_id, **model_init_kwargs) + else: + model_id = model.config._name_or_path + if args.model_init_kwargs is not None: + warnings.warn( + "You passed `model_init_kwargs` to the `SFTConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + + # Processing class if processing_class is None: - processing_class = AutoTokenizer.from_pretrained(model_id) + processing_class = AutoProcessor.from_pretrained(model_id) + + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + self._is_vlm = True + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + self._is_vlm = False + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") if args.eos_token is not None: eos_token = args.eos_token - eos_token_id = processing_class.convert_tokens_to_ids(eos_token) + eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) if eos_token_id is None: raise ValueError( f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given " f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists " "in the vocabulary before using it as an EOS token." ) - processing_class.eos_token_id = eos_token_id - - # Model - if args.model_init_kwargs is not None and not isinstance(model, str): - warnings.warn( - "You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. " - "The `model_init_kwargs` will be ignored." 
- ) - if isinstance(model, str): - model = self._create_model_from_path(model, args) + tokenizer.eos_token_id = eos_token_id if args.chat_template_path is not None: if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")): @@ -402,6 +553,33 @@ def __init__( else: added_tokens = [] + # Catch some wrong configurations related to VLMs + if self._is_vlm and args.packing: + raise ValueError( + "Packing is not supported for vision-language models. Please set `packing=False` in the SFTConfig." + ) + if self._is_vlm and args.padding_free: + raise ValueError( + "Padding-free training is yet not supported for vision-language models. Please set " + "`padding_free=False` in the `SFTConfig`." + ) + if self._is_vlm and args.completion_only_loss: + raise ValueError( + "Completion-only loss is not yet supported for vision-language models. Please set " + "`completion_only_loss=False` in the `SFTConfig`." + ) + if self._is_vlm and args.assistant_only_loss: + raise ValueError( + "Assistant-only loss is not yet supported for vision-language models. Please set " + "`assistant_only_loss=False` in the `SFTConfig`." + ) + first_example = next(iter(train_dataset)) + if self._is_vlm and "prompt" in first_example and "completion" in first_example: + raise ValueError( + "Prompt-completion datasets are not yet supported for vision-language models in `SFTTrainer`. " + "Please use a language-modeling type dataset instead." + ) + # PEFT configuration and model wrapping if peft_config is not None: if added_tokens: @@ -462,17 +640,19 @@ def __init__( "to at least 2." ) + # Decide whether to use completion-only loss: if not specified, then it is set to True if the dataset format + # is prompt-completion, and False if the dataset format is language modeling. 
dataset_sample = next(iter(train_dataset)) if args.completion_only_loss is None: self.completion_only_loss = "prompt" in dataset_sample else: self.completion_only_loss = args.completion_only_loss - if data_collator is None: + if data_collator is None and not self._is_vlm: # Get the pad token: if not provided, use the one from the processing class or the eos token # if the processing class does not have a pad token. - pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token - pad_token_id = processing_class.convert_tokens_to_ids(pad_token) + pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token + pad_token_id = tokenizer.convert_tokens_to_ids(pad_token) if pad_token_id is None: raise ValueError( f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " @@ -487,6 +667,13 @@ def __init__( return_position_ids=use_flash_attention, pad_to_multiple_of=args.pad_to_multiple_of, ) + elif data_collator is None and self._is_vlm: + data_collator = DataCollatorForVisionLanguageModeling( + processor=processing_class, + max_length=args.max_length, + pad_to_multiple_of=args.pad_to_multiple_of, + dataset_text_field=args.dataset_text_field, + ) if args.packing and args.packing_strategy == "bfd" and not use_flash_attention: warnings.warn( @@ -504,8 +691,12 @@ def __init__( ) # Dataset - preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False) - if preprocess_dataset: + # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where + # preprocessing (e.g., image-to-pixel conversion) is too costly and done on the fly instead. 
+ skip_prepare_dataset = ( + args.dataset_kwargs is not None and args.dataset_kwargs.get("skip_prepare_dataset", False) or self._is_vlm + ) + if not skip_prepare_dataset: if self.completion_only_loss and formatting_func: raise ValueError( "A formatting function was provided while `completion_only_loss=True`, which is incompatible. " @@ -563,29 +754,6 @@ def __init__( if hasattr(self.model, "add_model_tags"): self.model.add_model_tags(self._tag_names) - def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel: - """Creates a model from a path or model identifier.""" - model_init_kwargs = args.model_init_kwargs or {} - # Handle torch dtype - torch_dtype = model_init_kwargs.get("torch_dtype") - if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: - pass # torch_dtype is already a torch.dtype or "auto" or None - elif isinstance(torch_dtype, str): # it's a str, but not "auto" - torch_dtype = getattr(torch, torch_dtype) - model_init_kwargs["torch_dtype"] = torch_dtype - else: - raise ValueError( - "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing " - f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." - ) - # Disable caching if gradient checkpointing is enabled (not supported) - # if args.gradient_checkpointing: - # model_init_kwargs["use_cache"] = False - - # Create model - model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs) - return model - def _prepare_dataset( self, dataset: Union[Dataset, IterableDataset], @@ -770,13 +938,10 @@ def _set_signature_columns_if_needed(self): # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the # dataset. So we need to override the default signature columns to include "completion_mask" as well. 
if self._signature_columns is None: - self._signature_columns = [ - "input_ids", - "labels", - "seq_lengths", - "completion_mask", - "assistant_masks", - ] + if self._is_vlm: + self._signature_columns = ["messages", "images"] + else: + self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"] def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """