diff --git a/docs/index.md b/docs/index.md
index f6266980..82dd2518 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -69,10 +69,16 @@ we recommend creating a dedicated Python environment for each model.
     pip install "maestro[qwen_2_5_vl]"
     ```
 
+=== "SmolVLM2"
+
+    ```bash
+    pip install "maestro[smolvlm_2]"
+    ```
+
 ### CLI
 
 Kick off fine-tuning with our command-line interface, which leverages the configuration
-and training routines defined in each model’s core module. Simply specify key parameters such as
+and training routines defined in each model's core module. Simply specify key parameters such as
 the dataset location, number of epochs, batch size, optimization strategy, and metrics.
 
 === "Florence-2"
@@ -108,6 +114,17 @@ the dataset location, number of epochs, batch size, optimization strategy, and m
         --metrics "edit_distance"
     ```
 
+=== "SmolVLM2"
+
+    ```bash
+    maestro smolvlm_2 train \
+        --dataset "dataset/location" \
+        --epochs 10 \
+        --batch_size 4 \
+        --optimization_strategy "lora" \
+        --metrics "edit_distance"
+    ```
+
 ### Python
 
 For greater control, use the Python API to fine-tune your models.
@@ -148,7 +165,6 @@ and training setup.
     ```
 
 === "Qwen2.5-VL"
-
     ```python
     from maestro.trainer.models.qwen_2_5_vl.core import train
 
@@ -162,3 +178,18 @@ and training setup.
 
     train(config)
     ```
+
+=== "SmolVLM2"
+    ```python
+    from maestro.trainer.models.smolvlm_2.core import train
+
+    config = {
+        "dataset": "dataset/location",
+        "epochs": 10,
+        "batch_size": 4,
+        "optimization_strategy": "lora",
+        "metrics": ["edit_distance"],
+    }
+
+    train(config)
+    ```
diff --git a/docs/models/smolvlm_2.md b/docs/models/smolvlm_2.md
new file mode 100644
index 00000000..aa62b7d9
--- /dev/null
+++ b/docs/models/smolvlm_2.md
@@ -0,0 +1,91 @@
+---
+comments: true
+---
+
+## Overview
+
+SmolVLM2 is a lightweight vision-language model developed by Hugging Face. It offers impressive capabilities for multimodal understanding while maintaining a compact size compared to larger VLMs. The model excels at tasks such as image captioning, visual question answering, and object detection, making it accessible for applications with limited computational resources.
+
+Built to balance performance and efficiency, SmolVLM2 provides a valuable option for developers seeking to implement vision-language capabilities without the overhead of larger models. The 500M parameter variant delivers practical results while being significantly more resource-friendly than multi-billion parameter alternatives.
+
+## Install
+
+```bash
+pip install "maestro[smolvlm_2]"
+```
+
+## Train
+
+The training routines support various optimization strategies such as LoRA, QLoRA, and freezing the vision encoder. Customize your fine-tuning process via CLI or Python to align with your dataset and task requirements.
+
+### CLI
+
+Kick off training from the command line by running the command below. Be sure to replace the dataset path and adjust the hyperparameters (such as epochs and batch size) to suit your needs.
+
+```bash
+maestro smolvlm_2 train \
+    --model_id "HuggingFaceTB/SmolVLM-500M-Instruct" \
+    --dataset "dataset/location" \
+    --epochs 10 \
+    --batch_size 4 \
+    --accumulate_grad_batches 4 \
+    --optimization_strategy "lora" \
+    --metrics "edit_distance"
+```
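+
+The CLI also exposes QLoRA and custom LoRA settings. Pass `--optimization_strategy "qlora"` together with a JSON object of LoRA overrides via `--peft_advanced_params`. A minimal sketch; the override values below are illustrative only:
+
+```bash
+maestro smolvlm_2 train \
+    --dataset "dataset/location" \
+    --optimization_strategy "qlora" \
+    --peft_advanced_params '{"r": 16, "lora_alpha": 32}'
+```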
+
+### Python
+```python
+from maestro.trainer.models.smolvlm_2.core import train
+
+config = {
+    "model_id": "HuggingFaceTB/SmolVLM-500M-Instruct",
+    "dataset": "dataset/location",
+    "lr": 2e-5,
+    "epochs": 10,
+    "batch_size": 4,
+    "accumulate_grad_batches": 4,
+    "num_workers": 0,
+    "optimization_strategy": "lora",
+    "metrics": ["edit_distance"],
+    "device": "cuda"
+}
+
+train(config)
+```
+
+## Load
+
+Load a pre-trained or fine-tuned SmolVLM2 model along with its processor using the `load_model` function. Specify your model's path and the desired optimization strategy.
+
+```python
+from maestro.trainer.models.smolvlm_2.checkpoints import (
+    OptimizationStrategy, load_model
+)
+
+processor, model = load_model(
+    model_id_or_path="model/location",
+    optimization_strategy=OptimizationStrategy.NONE
+)
+```
+
+## Predict
+
+Perform inference with SmolVLM2 using the `predict` function. Supply an image and a text prefix to obtain predictions, such as object detection outputs or captions.
+
+```python
+from maestro.trainer.common.datasets.jsonl import JSONLDataset
+from maestro.trainer.models.smolvlm_2.inference import predict
+
+ds = JSONLDataset(
+    jsonl_file_path="dataset/location/test/annotations.jsonl",
+    image_directory_path="dataset/location/test",
+)
+
+image, entry = ds[0]
+
+predict(model=model, processor=processor, image=image, prefix=entry["prefix"])
+```
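+
+To check a fine-tuned checkpoint on more than a single sample, the same pieces compose into a small evaluation loop. This is a minimal sketch; it assumes `JSONLDataset` also supports `len()` in addition to the integer indexing shown above:
+
+```python
+for idx in range(len(ds)):
+    image, entry = ds[idx]
+    generated = predict(model=model, processor=processor, image=image, prefix=entry["prefix"])
+    print(f"target: {entry['suffix']} | generated: {generated}")
+```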
["down_proj", "o_proj", "k_proj", "q_proj", "gate_proj", "up_proj", "v_proj"], + "init_lora_weights": "gaussian", + "use_dora": False, +} +logger = get_maestro_logger() + + +def save_checkpoint( + model: AutoModelForImageTextToText, processor: AutoProcessor, path: str, metadata: Optional[dict] = None +) -> None: + """ + Save model checkpoint. + + Args: + model: Model to save + processor: Processor to save + path: Path to save checkpoint + metadata: Optional metadata to save + """ + os.makedirs(path, exist_ok=True) + + # Save model + model.save_pretrained(path) + + # Save processor + processor.save_pretrained(path) + + # Save metadata if provided + if metadata is not None: + torch.save(metadata, os.path.join(path, "metadata.pt")) + + +def save_model( + target_dir: str, + processor: AutoProcessor, + model: AutoModelForImageTextToText, +) -> None: + """ + Save a SmolVLM 2 model and its processor to disk. + + Args: + target_dir: Directory path where the model and processor will be saved. + Will be created if it doesn't exist. + processor: The SmolVLM 2 processor to save. + model: The SmolVLM 2model to save. + """ + os.makedirs(target_dir, exist_ok=True) + processor.save_pretrained(target_dir) + model.save_pretrained(target_dir) + + +class OptimizationStrategy(Enum): + """Enumeration for optimization strategies.""" + + LORA = "lora" + QLORA = "qlora" + FREEZE = "freeze" + NONE = "none" + + +def load_model( + model_id_or_path: str = DEFAULT_SMOLVLM_2_MODEL_ID, + revision: str = DEFAULT_SMOLVLM_2_MODEL_REVISION, + device: str | torch.device = "auto", + optimization_strategy: OptimizationStrategy = OptimizationStrategy.NONE, + peft_advanced_params: Optional[dict] = None, + cache_dir: Optional[str] = None, + longest_edge: int = 512, +) -> tuple[AutoProcessor, AutoModelForImageTextToText]: + device = parse_device_spec(device) + processor = AutoProcessor.from_pretrained( + model_id_or_path, do_resize=True, size={"longest_edge": longest_edge}, trust_remote_code=True, revision=revision + ) + + if optimization_strategy in {OptimizationStrategy.LORA, OptimizationStrategy.QLORA}: + default_params = ( + DEFAULT_SMOLVLM_2_QLORA_PARAMS + if optimization_strategy == OptimizationStrategy.QLORA + else DEFAULT_SMOLVLM_2_LORA_PARAMS + ) + if peft_advanced_params is not None: + default_params.update(peft_advanced_params) + try: + lora_config = LoraConfig(**default_params) + logger.info("Successfully created LoraConfig") + except TypeError: + logger.exception("Invalid parameters for LoraConfig") + raise + else: + logger.info("No additiopnal LoRA parameters provided. 
Using default configuration.") + lora_config = LoraConfig(**default_params) + + bnb_config = ( + BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + ) + if optimization_strategy == OptimizationStrategy.QLORA + else None + ) + + model = AutoModelForImageTextToText.from_pretrained( + pretrained_model_name_or_path=model_id_or_path, + revision=revision, + trust_remote_code=True, + device_map="auto", + quantization_config=bnb_config, + torch_dtype=torch.bfloat16, + cache_dir=cache_dir, + # _attn_implementation="flash_attention_2", + ) + model = get_peft_model(model, lora_config) + model.print_trainable_parameters() + else: + model = AutoModelForImageTextToText.from_pretrained( + pretrained_model_name_or_path=model_id_or_path, + revision=revision, + trust_remote_code=True, + device_map="auto", + cache_dir=cache_dir, + torch_dtype=torch.bfloat16, + # _attn_implementation="flash_attention_2" + ).to(device) + + if optimization_strategy == OptimizationStrategy.FREEZE: + for param in model.model.vision_model.parameters(): + param.requires_grad = False + + return processor, model diff --git a/maestro/trainer/models/smolvlm_2/core.py b/maestro/trainer/models/smolvlm_2/core.py new file mode 100644 index 00000000..d08a0ec5 --- /dev/null +++ b/maestro/trainer/models/smolvlm_2/core.py @@ -0,0 +1,204 @@ +import os +from dataclasses import dataclass, field, replace +from functools import partial +from typing import Literal, Optional + +import dacite +import lightning +import torch +from torch.optim import AdamW +from torch.utils.data import DataLoader +from transformers import AutoModelForVision2Seq, AutoProcessor + +from maestro.trainer.common.callbacks import SaveCheckpoint +from maestro.trainer.common.datasets.core import create_data_loaders, resolve_dataset_path +from maestro.trainer.common.metrics import BaseMetric, MetricsTracker, parse_metrics, save_metric_plots +from maestro.trainer.common.training import MaestroTrainer +from maestro.trainer.common.utils.device import device_is_available, parse_device_spec +from maestro.trainer.common.utils.path import create_new_run_directory +from maestro.trainer.common.utils.seed import ensure_reproducibility +from maestro.trainer.logger import get_maestro_logger +from maestro.trainer.models.smolvlm_2.checkpoints import ( + DEFAULT_SMOLVLM_2_MODEL_ID, + DEFAULT_SMOLVLM_2_MODEL_REVISION, + OptimizationStrategy, + load_model, + save_model, +) +from maestro.trainer.models.smolvlm_2.inference import predict_with_inputs +from maestro.trainer.models.smolvlm_2.loaders import evaluation_collate_fn, train_collate_fn + +logger = get_maestro_logger() + + +@dataclass() +class SmolVLM2Configuration: + dataset: str + model_id: str = DEFAULT_SMOLVLM_2_MODEL_ID + revision: str = DEFAULT_SMOLVLM_2_MODEL_REVISION + device: str | torch.device = "auto" + optimization_strategy: Literal["lora", "qlora", "freeze", "none"] = "lora" + cache_dir: Optional[str] = None + epochs: int = 10 + lr: float = 1e-4 + batch_size: int = 4 + accumulate_grad_batches: int = 4 + val_batch_size: Optional[int] = None + num_workers: int = 0 + val_num_workers: Optional[int] = None + output_dir: str = "./training/smolvlm_2" + metrics: list[BaseMetric] | list[str] = field(default_factory=list) + system_message: Optional[str] = None + max_new_tokens: int = 64 + random_seed: Optional[int] = None + peft_advanced_params: Optional[dict] = None + + def __post_init__(self): + if self.val_batch_size is None: + 
self.val_batch_size = self.batch_size + + if self.val_num_workers is None: + self.val_num_workers = self.num_workers + + if isinstance(self.metrics, list) and all(isinstance(m, str) for m in self.metrics): + self.metrics = parse_metrics(self.metrics) + + self.device = parse_device_spec(self.device) + if not device_is_available(self.device): + raise ValueError(f"Requested device '{self.device}' is not available.") + + +class SmolVLM2Trainer(MaestroTrainer): + """ + Trainer for fine-tuning the SmolVLM-2 model. + + Attributes: + processor (AutoProcessor): Processor for model inputs. + model (AutoModelForImageTextToText): The SmolVLM-2 model. + train_loader (DataLoader): DataLoader for training data. + valid_loader (DataLoader): DataLoader for validation data. + config (SmolVLM2Configuration): Configuration object with training parameters. + """ + + def __init__( + self, + processor: AutoProcessor, + model: AutoModelForVision2Seq, + train_loader: DataLoader, + valid_loader: DataLoader, + config: SmolVLM2Configuration, + ): + super().__init__(processor, model, train_loader, valid_loader) + self.config = config + + # TODO: Redesign metric tracking system + self.train_metrics_tracker = MetricsTracker.init(metrics=["loss"]) + metrics = ["loss"] + for metric in config.metrics: + if isinstance(metric, BaseMetric): + metrics += metric.describe() + self.valid_metrics_tracker = MetricsTracker.init(metrics=metrics) + + def training_step(self, batch, batch_idx): + input_ids, attention_mask, pixel_values, pixel_attention_mask, labels = batch + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + labels=labels, + ) + loss = outputs.loss + self.log("train_loss", loss, prog_bar=True, logger=True, batch_size=self.config.batch_size) + self.train_metrics_tracker.register("loss", epoch=self.current_epoch, step=batch_idx, value=loss.item()) + return loss + + def validation_step(self, batch, batch_idx): + input_ids, attention_mask, pixel_values, pixel_attention_mask, images, prefixes, suffixes = batch + generated_suffixes = predict_with_inputs( + model=self.model, + processor=self.processor, + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + ) + + if batch_idx == 0: + logger.info(f"sample valid prefix: {prefixes[0]}") + logger.info(f"sample valid suffix: {suffixes[0]}") + logger.info(f"sample generated suffix: {generated_suffixes[0]}") + + for metric in self.config.metrics: + result = metric.compute(predictions=generated_suffixes, targets=suffixes) + for key, value in result.items(): + self.valid_metrics_tracker.register( + metric=key, + epoch=self.current_epoch, + step=batch_idx, + value=value, + ) + self.log(key, value, prog_bar=True, logger=True, batch_size=self.config.val_batch_size) + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.config.lr) + return optimizer + + def on_fit_end(self) -> None: + save_metrics_path = os.path.join(self.config.output_dir, "metrics") + save_metric_plots( + training_tracker=self.train_metrics_tracker, + validation_tracker=self.valid_metrics_tracker, + output_dir=save_metrics_path, + ) + + +def train(config: SmolVLM2Configuration | dict) -> None: + if isinstance(config, dict): + config = dacite.from_dict(data_class=SmolVLM2Configuration, data=config) + assert isinstance(config, SmolVLM2Configuration) # ensure mypy understands it's not a dict + + 
ensure_reproducibility(seed=config.random_seed, avoid_non_deterministic_algorithms=False) + run_dir = create_new_run_directory(base_output_dir=config.output_dir) + config = replace(config, output_dir=run_dir) + + processor, model = load_model( + model_id_or_path=config.model_id, + revision=config.revision, + device=config.device, + optimization_strategy=OptimizationStrategy(config.optimization_strategy), + peft_advanced_params=config.peft_advanced_params, + cache_dir=config.cache_dir, + ) + dataset_location = resolve_dataset_path(config.dataset) + if dataset_location is None: + return + + train_loader, valid_loader, test_loader = create_data_loaders( + dataset_location=dataset_location, + train_batch_size=config.batch_size, + train_collect_fn=partial(train_collate_fn, processor=processor), + train_num_workers=config.num_workers, + test_batch_size=config.val_batch_size, + test_collect_fn=partial(evaluation_collate_fn, processor=processor), + test_num_workers=config.val_num_workers, + ) + + _, train_entry = train_loader.dataset[0] + logger.info(f"sample train prefix: {train_entry['prefix']}") + logger.info(f"sample train suffix: {train_entry['suffix']}") + + pl_module = SmolVLM2Trainer( + processor=processor, model=model, train_loader=train_loader, valid_loader=valid_loader, config=config + ) + save_checkpoints_path = os.path.join(config.output_dir, "checkpoints") + save_checkpoint_callback = SaveCheckpoint(result_path=save_checkpoints_path, save_model_callback=save_model) + trainer = lightning.Trainer( + max_epochs=config.epochs, + accumulate_grad_batches=config.accumulate_grad_batches, + check_val_every_n_epoch=1, + limit_val_batches=1, + log_every_n_steps=10, + callbacks=[save_checkpoint_callback], + ) + trainer.fit(pl_module) diff --git a/maestro/trainer/models/smolvlm_2/entrypoint.py b/maestro/trainer/models/smolvlm_2/entrypoint.py new file mode 100644 index 00000000..7f84e6a4 --- /dev/null +++ b/maestro/trainer/models/smolvlm_2/entrypoint.py @@ -0,0 +1,110 @@ +import dataclasses +import json +from typing import Annotated, Any, Optional + +import rich +import typer + +from maestro.trainer.logger import get_maestro_logger +from maestro.trainer.models.smolvlm_2.checkpoints import DEFAULT_SMOLVLM_2_MODEL_ID, DEFAULT_SMOLVLM_2_MODEL_REVISION +from maestro.trainer.models.smolvlm_2.core import SmolVLM2Configuration +from maestro.trainer.models.smolvlm_2.core import train as smolvlm_2_train + +logger = get_maestro_logger() +smolvlm_2_app = typer.Typer(help="Fine-tune and evaluate SmolVLM_2 model") + + +@smolvlm_2_app.command( + help="Train SmolVLM_2 model", + context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, +) +def train( + dataset: Annotated[ + str, + typer.Option( + "--dataset", + help="Local path or Roboflow identifier. 
If not found locally, it will be resolved (and downloaded) "
+            "automatically",
+        ),
+    ],
+    model_id: Annotated[
+        str, typer.Option("--model_id", help="Identifier for the SmolVLM2 model")
+    ] = DEFAULT_SMOLVLM_2_MODEL_ID,
+    revision: Annotated[
+        str, typer.Option("--revision", help="Model revision to use")
+    ] = DEFAULT_SMOLVLM_2_MODEL_REVISION,
+    device: Annotated[str, typer.Option("--device", help="Device to use for training")] = "auto",
+    optimization_strategy: Annotated[
+        str, typer.Option("--optimization_strategy", help="Optimization strategy: lora, qlora, freeze, or none")
+    ] = "lora",
+    cache_dir: Annotated[
+        Optional[str], typer.Option("--cache_dir", help="Directory to cache the model weights locally")
+    ] = None,
+    epochs: Annotated[int, typer.Option("--epochs", help="Number of training epochs")] = 10,
+    lr: Annotated[float, typer.Option("--lr", help="Learning rate for training")] = 1e-5,
+    batch_size: Annotated[int, typer.Option("--batch_size", help="Training batch size")] = 4,
+    accumulate_grad_batches: Annotated[
+        int, typer.Option("--accumulate_grad_batches", help="Number of batches to accumulate for gradient updates")
+    ] = 8,
+    val_batch_size: Annotated[Optional[int], typer.Option("--val_batch_size", help="Validation batch size")] = None,
+    num_workers: Annotated[int, typer.Option("--num_workers", help="Number of workers for data loading")] = 0,
+    val_num_workers: Annotated[
+        Optional[int], typer.Option("--val_num_workers", help="Number of workers for validation data loading")
+    ] = None,
+    output_dir: Annotated[
+        str, typer.Option("--output_dir", help="Directory to store training outputs")
+    ] = "./training/smolvlm_2",
+    metrics: Annotated[list[str], typer.Option("--metrics", help="List of metrics to track during training")] = [],
+    max_new_tokens: Annotated[
+        int,
+        typer.Option("--max_new_tokens", help="Maximum number of new tokens generated during inference"),
+    ] = 1024,
+    random_seed: Annotated[
+        Optional[int],
+        typer.Option("--random_seed", help="Random seed for ensuring reproducibility. If None, no seed is set"),
+    ] = None,
+    peft_advanced_params: Annotated[
+        Optional[str],
+        typer.Option("--peft_advanced_params", help="Custom LoRA config as a JSON string. If None, the default LoRA config is used"),
+    ] = None,
+) -> None:
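+    # NOTE: --peft_advanced_params is expected to be a JSON object of LoraConfig keyword arguments,
+    # for example '{"r": 16, "lora_alpha": 32}'. The keys shown are illustrative; any field accepted
+    # by peft's LoraConfig can be supplied and will override the defaults defined in checkpoints.py.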
If None, default LoRA config is set"), + ] = None, +) -> None: + def parse_lora_params(param_str) -> dict[str, Any]: + parsed_params = json.loads(param_str) + if not isinstance(parsed_params, dict): + raise TypeError("Parsed JSON is not a dictionary") + return parsed_params + + if peft_advanced_params is not None: + try: + peft_advanced_params_dict = parse_lora_params(peft_advanced_params) + logger.info(f"Parsed LoRA parameters: {peft_advanced_params_dict}") + except json.JSONDecodeError: + logger.exception("Failed to parse JSON") + raise + except TypeError: + logger.exception("Invalid LoRA parameter format") + raise + + config = SmolVLM2Configuration( + dataset=dataset, + model_id=model_id, + revision=revision, + device=device, + optimization_strategy=optimization_strategy, # type: ignore + cache_dir=cache_dir, + epochs=epochs, + lr=lr, + batch_size=batch_size, + accumulate_grad_batches=accumulate_grad_batches, + val_batch_size=val_batch_size, + num_workers=num_workers, + val_num_workers=val_num_workers, + output_dir=output_dir, + metrics=metrics, + max_new_tokens=max_new_tokens, + random_seed=random_seed, + peft_advanced_params=peft_advanced_params_dict, + ) + typer.echo(typer.style("Training configuration", fg=typer.colors.BRIGHT_GREEN, bold=True)) + rich.print(dataclasses.asdict(config)) + smolvlm_2_train(config=config) diff --git a/maestro/trainer/models/smolvlm_2/inference.py b/maestro/trainer/models/smolvlm_2/inference.py new file mode 100644 index 00000000..a4c707c4 --- /dev/null +++ b/maestro/trainer/models/smolvlm_2/inference.py @@ -0,0 +1,56 @@ +import torch +from PIL import Image +from transformers import AutoModelForImageTextToText, AutoProcessor + +from maestro.trainer.common.utils.device import parse_device_spec + + +def predict_with_inputs( + model: AutoModelForImageTextToText, + processor: AutoProcessor, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + pixel_values: torch.Tensor, + pixel_attention_mask: torch.Tensor, + max_new_tokens: int = 64, +) -> list[str]: + with torch.no_grad(): + generated_ids = model.generate( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + pixel_attention_mask=pixel_attention_mask, + do_sample=False, + max_new_tokens=max_new_tokens, + ) + prefix_length = input_ids.shape[-1] + generated_ids = generated_ids[:, prefix_length:] + return processor.batch_decode(generated_ids, skip_special_tokens=True) + + +def predict( + model: AutoModelForImageTextToText, + processor: AutoProcessor, + image: str | bytes | Image.Image, + prefix: str, + device: str | torch.device = "auto", + max_new_tokens: int = 64, +) -> str: + device = parse_device_spec(device) + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": prefix}, + ], + }, + ] + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ).to(device, dtype=torch.bfloat16) + return predict_with_inputs(**inputs, model=model, processor=processor, max_new_tokens=max_new_tokens)[0] diff --git a/maestro/trainer/models/smolvlm_2/loaders.py b/maestro/trainer/models/smolvlm_2/loaders.py new file mode 100644 index 00000000..605db7e8 --- /dev/null +++ b/maestro/trainer/models/smolvlm_2/loaders.py @@ -0,0 +1,130 @@ +from typing import Any + +import torch +from PIL import Image +from transformers import AutoProcessor + +from maestro.trainer.common.utils.device import parse_device_spec + + +def format_conversation( + image: str 
+    for index, user_input_id in enumerate(user_input_ids):
+        user_input_length = user_input_id.shape[0]
+        labels[index, :user_input_length] = -100
+
+    return input_ids, attention_mask, pixel_values, pixel_attention_mask, labels
+
+
+def evaluation_collate_fn(
+    batch: list[tuple[Image.Image, dict[str, Any]]],
+    processor: AutoProcessor,
+    system_message: str | None = None,
+    device: str | torch.device = "auto",
+):
+    device = parse_device_spec(device)
+    images, data = zip(*batch)
+    prefixes = [entry["prefix"] for entry in data]
+    suffixes = [entry["suffix"] for entry in data]
+    user_conversations = [
+        format_conversation(image, entry["prefix"], system_message=system_message)
+        for image, entry in zip(images, data)
+    ]
+    user_texts = [
+        processor.apply_chat_template(conversation=user_conversation, add_generation_prompt=False).strip()
+        for user_conversation in user_conversations
+    ]
+    image_lists = [[image] for image in images]
+    user_model_inputs = processor(text=user_texts, images=image_lists, return_tensors="pt", padding=True).to(
+        device, dtype=torch.bfloat16
+    )
+
+    user_input_ids = user_model_inputs["input_ids"]
+    user_attention_mask = user_model_inputs["attention_mask"]
+    user_pixel_values = user_model_inputs["pixel_values"]
+    user_pixel_attention_mask = user_model_inputs["pixel_attention_mask"]
+
+    return (
+        user_input_ids,
+        user_attention_mask,
+        user_pixel_values,
+        user_pixel_attention_mask,
+        image_lists,
+        prefixes,
+        suffixes,
+    )
diff --git a/mkdocs.yaml b/mkdocs.yaml index 
3f476c12..f232889b 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -27,6 +27,7 @@ nav: - Florence-2: models/florence_2.md - PaliGemma 2: models/paligemma_2.md - Qwen2.5-VL: models/qwen_2_5_vl.md + - SmolVLM2: models/smolvlm2.md - Datasets: - JSONL: datasets/jsonl.md diff --git a/pyproject.toml b/pyproject.toml index da476b4f..4d91dc1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,15 @@ qwen_2_5_vl = [ "bitsandbytes>=0.45.0", "qwen-vl-utils>=0.0.8" ] +smolvlm_2 = [ + "accelerate>=1.2.1", + "peft>=0.12", + "torch>=2.4.0", + "torchvision>=0.20.0", + "transformers>=4.49.0", + "bitsandbytes>=0.45.0", + "num2words>=0.54.14" +] [project.scripts] maestro = "maestro.cli.main:app" @@ -147,12 +156,8 @@ line-length = 120 indent-width = 4 [tool.ruff.lint] - -# Enable pycodestyle (`E`) -select = ["E", "F", "I", "A", "Q", "W", "N", "T", "Q","TRY","UP","C90","RUF","NPY"] -ignore = ["T201","TRY003","NPY201"] - -# Allow autofix for all enabled rules (when `--fix`) is provided. +select = ["E", "F", "I", "A", "Q", "W", "N", "T", "TRY", "UP", "C90", "RUF", "NPY"] +ignore = ["T201", "TRY003", "NPY201"] fixable = [ "A", "B", @@ -197,12 +202,15 @@ fixable = [ "TID", "TRY", "UP", - "YTT", + "YTT" ] unfixable = [] + # Allow unused variables when underscore-prefixed. dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" -pylint.max-args = 20 + +[tool.ruff.lint.pylint] +max-args = 20 [tool.ruff.lint.flake8-quotes] inline-quotes = "double"
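
A quick way to sanity-check the wiring added above (the `smolvlm_2` optional dependency group, the Typer sub-app registered in `introspection.py`, and the docs page) is to install the extra and ask the new command for its help text. A minimal sketch; the exact help output depends on the installed version:

```bash
pip install "maestro[smolvlm_2]"
maestro smolvlm_2 train --help
```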