Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
93 commits
Select commit Hold shift + click to select a range
4394a90
Created SmolVLM2 model in maestro
AshAnand34 May 10, 2025
0518c67
Fixing lint errors and crated trainer for training dataset in smolvlm2
AshAnand34 May 10, 2025
727e01b
SmolVLM2 documented
AshAnand34 May 10, 2025
d074602
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 10, 2025
6c35bef
fixing errors in smolvlm2 interpretation
AshAnand34 May 10, 2025
7b49423
Merge branches 'feature/add-smolvlm2' and 'feature/add-smolvlm2' of h…
AshAnand34 May 10, 2025
8fe68f2
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 10, 2025
4ff3d63
Fixing more errors with core.py
AshAnand34 May 10, 2025
0e1804e
fix(pre_commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] May 10, 2025
3ea5544
Fixed Ruff error with too long line
AshAnand34 May 10, 2025
4039ff4
first attempt of smolvlm
AlexBodner May 28, 2025
fd521f5
first attempt of smolvlm
AlexBodner May 28, 2025
d46e1f5
changed model type to AutoModelForImageTextToText
AlexBodner May 28, 2025
7a6747f
updated train_collate_fn format
AlexBodner May 28, 2025
1c7d6a6
trying casting model to bfloat 16
AlexBodner May 28, 2025
a28af8f
trying casting model to device
AlexBodner May 28, 2025
a767b8d
trying changing the freezing
AlexBodner May 28, 2025
f906134
removed max length
AlexBodner May 28, 2025
a814cb2
added attention masks
AlexBodner May 28, 2025
23fa5a1
trying different train collate fn
AlexBodner May 28, 2025
023d817
trying different train collate fn
AlexBodner May 28, 2025
4135537
added debuging prints
AlexBodner May 28, 2025
022c61f
removed debuging prints
AlexBodner May 28, 2025
5891867
trying train collate from SmolVLM2_video_FT
AlexBodner May 28, 2025
309894b
added debuging prints
AlexBodner May 28, 2025
3c4e2ce
try padding labels
AlexBodner May 28, 2025
b3109a3
trying charqa format
AlexBodner May 28, 2025
acaa5aa
trying charqa format
AlexBodner May 28, 2025
87c5a26
printing last ids
AlexBodner May 28, 2025
b254416
potential solution
AlexBodner May 28, 2025
beb1b4f
casting pixels to bfloat 16
AlexBodner May 28, 2025
7788dfb
trying autocast
AlexBodner May 28, 2025
a8a27ca
removing bfloat 16 from load model
AlexBodner May 29, 2025
7e34098
small optimization and added evalutation collate_fn
AlexBodner May 29, 2025
1750334
trying float 16
AlexBodner May 29, 2025
0b810f2
removing float type
AlexBodner May 29, 2025
5023150
changed evaluation collate
AlexBodner May 29, 2025
fae5ff0
evaluation collate as train collate
AlexBodner May 29, 2025
257dc78
evaluation collate updated
AlexBodner May 29, 2025
b6c1d18
casting to tensors
AlexBodner May 29, 2025
934e3de
added debugging print
AlexBodner May 29, 2025
877d799
changed processor
AlexBodner May 29, 2025
518c3ca
changes in collates
AlexBodner May 29, 2025
2769c96
trying again paligemma collates
AlexBodner May 29, 2025
62efcde
trying again paligemma collates
AlexBodner May 29, 2025
7ebb56d
trying again paligemma collates
AlexBodner May 29, 2025
434d725
trying again paligemma collates
AlexBodner May 29, 2025
5372867
trying again paligemma collates
AlexBodner May 29, 2025
1681dce
going back to custom collators
AlexBodner May 29, 2025
3698feb
going back to custom collators
AlexBodner May 29, 2025
3afba98
removing images from processor
AlexBodner May 29, 2025
974f157
removing images from processor
AlexBodner May 29, 2025
a259739
rollback
AlexBodner May 29, 2025
e35e1c4
trying to apply chat template to all messages at once
AlexBodner May 29, 2025
d8f58b9
added debug print
AlexBodner May 29, 2025
d598759
added debug print
AlexBodner May 29, 2025
09e7498
testing different things in valid step
AlexBodner May 29, 2025
b20455d
testing different things in valid step
AlexBodner May 29, 2025
5299461
testing different things in train step
AlexBodner May 29, 2025
476e120
testing different things in train step
AlexBodner May 29, 2025
b488446
validation using predict with inputs
AlexBodner May 29, 2025
a600a1e
added device
AlexBodner May 29, 2025
4667b3a
nicer code
AlexBodner May 29, 2025
508f397
added predict
AlexBodner May 29, 2025
ed27b5a
added predict
AlexBodner May 29, 2025
f9f6c2f
added predict
AlexBodner May 29, 2025
e4fed49
added predict
AlexBodner May 29, 2025
d7e40f7
updated predict with input to only return the suffix
AlexBodner May 30, 2025
94d905b
fix validation output
AlexBodner May 30, 2025
ba855a7
tryng collator from HF
AlexBodner May 30, 2025
2b13c69
tryng collator from HF
AlexBodner May 30, 2025
00f3860
rollback
AlexBodner May 30, 2025
6c81298
possible fix for masking
AlexBodner May 30, 2025
1e13d19
buug fix for masking
AlexBodner May 30, 2025
e92064e
debug print
AlexBodner May 30, 2025
57634e0
trying without masking the input
AlexBodner May 30, 2025
9af7a2b
removed comments
AlexBodner May 30, 2025
1f57cb1
Changed with piotr colab suggestions
AlexBodner Jun 2, 2025
bb04f69
removed flash attn
AlexBodner Jun 2, 2025
869048a
removed underscore
AlexBodner Jun 2, 2025
9ff898d
added max image size
AlexBodner Jun 2, 2025
c57e61c
added video sampling size
AlexBodner Jun 2, 2025
33e28d3
added video sampling size
AlexBodner Jun 2, 2025
749a1c9
rollback and added modules to toml
AlexBodner Jun 2, 2025
32e683e
added num2words version
AlexBodner Jun 2, 2025
03a5b33
fixed ruff and mypy issues
AlexBodner Jun 2, 2025
5c270a9
Update smolvlm2.md
AlexBodner Jun 6, 2025
045dcb3
Update smolvlm2.md
AlexBodner Jun 6, 2025
fe1292d
Update smolvlm2.md
AlexBodner Jun 6, 2025
2255e93
Update introspection.py
AlexBodner Jun 6, 2025
a405fc2
Update introspection.py
AlexBodner Jun 6, 2025
9d6180a
Update introspection.py
AlexBodner Jun 6, 2025
3787e41
Update and rename smolvlm2.md to smolvlm_2.md
AlexBodner Jun 6, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,16 @@ we recommend creating a dedicated Python environment for each model.
pip install "maestro[qwen_2_5_vl]"
```

=== "SmolVLM2"

```bash
pip install "maestro[smolvlm2]"
```

### CLI

Kick off fine-tuning with our command-line interface, which leverages the configuration
and training routines defined in each models core module. Simply specify key parameters such as
and training routines defined in each model's core module. Simply specify key parameters such as
the dataset location, number of epochs, batch size, optimization strategy, and metrics.

=== "Florence-2"
Expand Down Expand Up @@ -108,6 +114,17 @@ the dataset location, number of epochs, batch size, optimization strategy, and m
--metrics "edit_distance"
```

=== "SmolVLM2"

```bash
maestro smolvlm2 train \
--dataset "dataset/location" \
--epochs 10 \
--batch-size 4 \
--optimization_strategy "lora" \
--metrics "edit_distance"
```

### Python

For greater control, use the Python API to fine-tune your models.
Expand Down Expand Up @@ -148,7 +165,6 @@ and training setup.
```

=== "Qwen2.5-VL"

```python
from maestro.trainer.models.qwen_2_5_vl.core import train

Expand All @@ -162,3 +178,18 @@ and training setup.

train(config)
```

=== "SmolVLM2"
```python
from maestro.trainer.models.smolvlm2.core import train

config = {
"dataset": "dataset/location",
"epochs": 10,
"batch_size": 4,
"optimization_strategy": "lora",
"metrics": ["edit_distance"],
}

train(config)
```
91 changes: 91 additions & 0 deletions docs/models/smolvlm_2.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
---
comments: true
---

## Overview

SmolVLM2 is a lightweight vision-language model developed by Hugging Face. It offers impressive capabilities for multimodal understanding while maintaining a compact size compared to larger VLMs. The model excels at tasks such as image captioning, visual question answering, and object detection, making it accessible for applications with limited computational resources.

Built to balance performance and efficiency, SmolVLM2 provides a valuable option for developers seeking to implement vision-language capabilities without the overhead of larger models. The 500M parameter variant delivers practical results while being significantly more resource-friendly than multi-billion parameter alternatives.

## Install

```bash
pip install "maestro[smolvlm_2]"
```

## Train

The training routines support various optimization strategies such as LoRA, QLoRA, and freezing the vision encoder. Customize your fine-tuning process via CLI or Python to align with your dataset and task requirements.

### CLI

Kick off training from the command line by running the command below. Be sure to replace the dataset path and adjust the hyperparameters (such as epochs and batch size) to suit your needs.

```bash
maestro smolvlm_2 train \
--model_id "HuggingFaceTB/SmolVLM-500M-Instruct" \
--dataset "dataset/location" \
--epochs 10 \
--batch-size 4 \
--accumulate_grad_batches 4 \
--optimization_strategy "lora" \
--metrics "edit_distance"
```



### Python
```python
from maestro.trainer.models.smolvlm_2.core import train

config = {
"model_id": "HuggingFaceTB/SmolVLM-500M-Instruct",
"dataset": "dataset/location",
"lr": 2e-5,
"epochs": 10,
"batch_size": 4,
"accumulate_grad_batches": 4,
"num_workers": 0,
"optimization_strategy": "lora",
"metrics": ["edit_distance"],
"device": "cuda"
}


train(config)
```


## Load

Load a pre-trained or fine-tuned SmolVLM model along with its processor using the load_model function. Specify your model's path and the desired optimization strategy.

```python
from maestro.trainer.models.smolvlm_2.checkpoints import (
OptimizationStrategy, load_model
)

processor, model = load_model(
model_id_or_path="model/location",
optimization_strategy=OptimizationStrategy.NONE
)
```
## Predict

Perform inference with SmolVLM using the predict function. Supply an image and a text prefix to obtain predictions, such as object detection outputs or captions.

```python
from maestro.trainer.common.datasets.jsonl import JSONLDataset
from maestro.trainer.models.smolvlm_2.inference import predict

ds = JSONLDataset(
jsonl_file_path="dataset/location/test/annotations.jsonl",
image_directory_path="dataset/location/test",
)

image, entry = ds[0]

predict(model=model, processor=processor, image=image, prefix=entry["prefix"])
```

7 changes: 7 additions & 0 deletions maestro/cli/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ def find_training_recipes(app: typer.Typer) -> None:
except Exception:
_warn_about_recipe_import_error(model_name="Qwen2.5-VL")

try:
from maestro.trainer.models.smolvlm_2.entrypoint import smolvlm_2_app

app.add_typer(smolvlm2_app, name="smolvlm_2")
except Exception:
_warn_about_recipe_import_error(model_name="SmolVLM2")


def _warn_about_recipe_import_error(model_name: str) -> None:
disable_warnings = str2bool(
Expand Down
Empty file.
158 changes: 158 additions & 0 deletions maestro/trainer/models/smolvlm_2/checkpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import os
from enum import Enum
from typing import Optional

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig

from maestro.trainer.common.utils.device import parse_device_spec
from maestro.trainer.logger import get_maestro_logger

DEFAULT_SMOLVLM_2_MODEL_ID = "HuggingFaceTB/SmolVLM-500M-Instruct" # "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
DEFAULT_SMOLVLM_2_MODEL_REVISION = "refs/heads/main"
DEFAULT_SMOLVLM_2_LORA_PARAMS = {
"r": 8,
"lora_alpha": 8,
"lora_dropout": 0.1,
"bias": "none",
"target_modules": ["down_proj", "o_proj", "k_proj", "q_proj", "gate_proj", "up_proj", "v_proj"],
"init_lora_weights": "gaussian",
"use_dora": True,
}
DEFAULT_SMOLVLM_2_QLORA_PARAMS = {
"r": 8,
"lora_alpha": 8,
"lora_dropout": 0.1,
"bias": "none",
"target_modules": ["down_proj", "o_proj", "k_proj", "q_proj", "gate_proj", "up_proj", "v_proj"],
"init_lora_weights": "gaussian",
"use_dora": False,
}
logger = get_maestro_logger()


def save_checkpoint(
model: AutoModelForImageTextToText, processor: AutoProcessor, path: str, metadata: Optional[dict] = None
) -> None:
"""
Save model checkpoint.

Args:
model: Model to save
processor: Processor to save
path: Path to save checkpoint
metadata: Optional metadata to save
"""
os.makedirs(path, exist_ok=True)

# Save model
model.save_pretrained(path)

# Save processor
processor.save_pretrained(path)

# Save metadata if provided
if metadata is not None:
torch.save(metadata, os.path.join(path, "metadata.pt"))


def save_model(
target_dir: str,
processor: AutoProcessor,
model: AutoModelForImageTextToText,
) -> None:
"""
Save a SmolVLM 2 model and its processor to disk.

Args:
target_dir: Directory path where the model and processor will be saved.
Will be created if it doesn't exist.
processor: The SmolVLM 2 processor to save.
model: The SmolVLM 2model to save.
"""
os.makedirs(target_dir, exist_ok=True)
processor.save_pretrained(target_dir)
model.save_pretrained(target_dir)


class OptimizationStrategy(Enum):
"""Enumeration for optimization strategies."""

LORA = "lora"
QLORA = "qlora"
FREEZE = "freeze"
NONE = "none"


def load_model(
model_id_or_path: str = DEFAULT_SMOLVLM_2_MODEL_ID,
revision: str = DEFAULT_SMOLVLM_2_MODEL_REVISION,
device: str | torch.device = "auto",
optimization_strategy: OptimizationStrategy = OptimizationStrategy.NONE,
peft_advanced_params: Optional[dict] = None,
cache_dir: Optional[str] = None,
longest_edge: int = 512,
) -> tuple[AutoProcessor, AutoModelForImageTextToText]:
device = parse_device_spec(device)
processor = AutoProcessor.from_pretrained(
model_id_or_path, do_resize=True, size={"longest_edge": longest_edge}, trust_remote_code=True, revision=revision
)

if optimization_strategy in {OptimizationStrategy.LORA, OptimizationStrategy.QLORA}:
default_params = (
DEFAULT_SMOLVLM_2_QLORA_PARAMS
if optimization_strategy == OptimizationStrategy.QLORA
else DEFAULT_SMOLVLM_2_LORA_PARAMS
)
if peft_advanced_params is not None:
default_params.update(peft_advanced_params)
try:
lora_config = LoraConfig(**default_params)
logger.info("Successfully created LoraConfig")
except TypeError:
logger.exception("Invalid parameters for LoraConfig")
raise
else:
logger.info("No additiopnal LoRA parameters provided. Using default configuration.")
lora_config = LoraConfig(**default_params)

bnb_config = (
BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
if optimization_strategy == OptimizationStrategy.QLORA
else None
)

model = AutoModelForImageTextToText.from_pretrained(
pretrained_model_name_or_path=model_id_or_path,
revision=revision,
trust_remote_code=True,
device_map="auto",
quantization_config=bnb_config,
torch_dtype=torch.bfloat16,
cache_dir=cache_dir,
# _attn_implementation="flash_attention_2",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
else:
model = AutoModelForImageTextToText.from_pretrained(
pretrained_model_name_or_path=model_id_or_path,
revision=revision,
trust_remote_code=True,
device_map="auto",
cache_dir=cache_dir,
torch_dtype=torch.bfloat16,
# _attn_implementation="flash_attention_2"
).to(device)

if optimization_strategy == OptimizationStrategy.FREEZE:
for param in model.model.vision_model.parameters():
param.requires_grad = False

return processor, model
Loading