From 8756d8f4d716424b84755140cce8e8bff7c7d3a5 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Sat, 25 Apr 2026 14:03:49 -0700 Subject: [PATCH 01/14] feat(mm-cpt): multimodal continued pre-training (raw image+text) Adds a streaming-first multimodal CPT path: raw `(text, images[])` rows are tokenized once with a placeholder-count guardrail, batched through a hardened collator, and fed to a VLM with image-family tokens masked out of labels. Gated by `type: multimodal_pretrain` (or `multimodal: true`) on a `pretraining_dataset` entry; works end-to-end for train and eval, including multi-entry eval and mixed image/text batches. Features -------- - Streaming MM CPT encoder (`encode_streaming_multimodal`): counts placeholders by token id (not substring), enforces `placeholders == len(images)` per row, and rejects rows that exceed `sequence_len` instead of silently truncating mid-placeholder. - MM CPT collator (`MultiModalPretrainDataCollator`): security-hardened image loader (path traversal / NUL byte / remote URL / multi-frame bomb / pixel cap rejection), per-row image cap, processor-call retry that pinpoints the offending row, and label-side masking of every image-family token id. - Mixed image/text batches: text-only rows in a batch take a tokenizer-only fallback (no `pixel_values`); rows with images go through the processor as usual. - Eval support: `test_datasets` accepts MM entries via a dedicated `MultiModalEvalDataset` model so per-entry `text_column` / `image_column` / `image_base_dir` / `image_token` survive validation. Multi-entry MM eval streams are concatenated. - `dispatch_batches: true` support: non-main ranks get a placeholder dataset that mirrors the configured text + image columns. - Config validation gates: `processor_type` required, `sample_packing: false` enforced, `chat_template` rejected, single `pretraining_dataset` entry required, MM eval entries must share `image_base_dir` / `image_token`, mixed MM/non-MM eval rejected, incompatible processor classes (Mllama, Pixtral, InternVL) rejected at startup. `remove_unused_columns` is auto-set to `false` with an INFO log. - Docs: new section in `docs/multimodal.qmd` covering the YAML shape, placeholder-token table, eval contract, and supported/rejected model families. YAML example ------------ base_model: HuggingFaceTB/SmolVLM-500M-Instruct processor_type: AutoProcessor pretraining_dataset: - path: /path/to/shards/*.jsonl ds_type: json type: multimodal_pretrain text_column: text image_column: images image_base_dir: /path/to/images streaming: true sequence_len: 2048 sample_packing: false Tests ----- 59 tests across four suites covering the encoder, collator (including mixed/all-text batches and security gates), prompt strategy, schema preservation, multi-entry eval merge, eval homogeneity validation, eval-aware collator, dispatch-batches placeholder shape, and the auto-set log record. --- docs/multimodal.qmd | 114 +++++ src/axolotl/core/builders/causal.py | 66 +++ .../prompt_strategies/multimodal_pretrain.py | 297 ++++++++++++ src/axolotl/utils/collators/mm_pretrain.py | 306 ++++++++++++ src/axolotl/utils/data/sft.py | 126 +++-- src/axolotl/utils/data/streaming.py | 149 +++++- src/axolotl/utils/schemas/config.py | 4 +- src/axolotl/utils/schemas/datasets.py | 82 ++++ src/axolotl/utils/schemas/validation.py | 87 ++++ tests/conftest.py | 18 + .../test_multimodal_pretrain.py | 242 ++++++++++ tests/test_multimodal_streaming.py | 448 ++++++++++++++++++ tests/utils/data/test_mm_cpt_eval.py | 186 ++++++++ .../schemas/validation/test_multimodal_cpt.py | 270 +++++++++++ 14 files changed, 2358 insertions(+), 37 deletions(-) create mode 100644 src/axolotl/prompt_strategies/multimodal_pretrain.py create mode 100644 src/axolotl/utils/collators/mm_pretrain.py create mode 100644 tests/prompt_strategies/test_multimodal_pretrain.py create mode 100644 tests/test_multimodal_streaming.py create mode 100644 tests/utils/data/test_mm_cpt_eval.py create mode 100644 tests/utils/schemas/validation/test_multimodal_cpt.py diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd index 5197f48b10..b4d7740dc3 100644 --- a/docs/multimodal.qmd +++ b/docs/multimodal.qmd @@ -364,6 +364,120 @@ Here is an example of a multi-modal dataset: ] ``` +## Continued Pre-training (CPT) with images {#sec-multimodal-cpt} + +Raw image+text continued pretraining — no chat template, no conversational +scaffolding. The model learns to emit raw text conditioned on visual patches. +Intended for use cases like OCR/transcription corpora where every row is a +tight `(image, target_text)` pair and any user/assistant framing would pollute +the learned signal. + +### Dataset format (JSONL) + +Two keys per row: `text` (the raw string) and `images` (list of local paths). +The `text` must contain the model's placeholder token **once per image**, +placed immediately before the text it describes, followed by a newline: + +```json +{"text": "\nInvoice number: 10427. Total due: 148.32 USD.", "images": ["/dataset/crops/doc_14_p2.png"]} +{"text": "\nChapter 3 begins with a description of the storm over the harbor.", "images": ["/dataset/crops/doc_14_p3.png"]} +``` + +Notes: + +- Never wrap the row in `User:` / `Assistant:` / `Transcribe this:` scaffolding. That is the whole point of the CPT path. +- Do not manually append an EOS token. Axolotl appends one during tokenization. +- The newline between the placeholder and the real text preserves the BPE + boundary. Without it, some tokenizers merge the visual-token boundary with + the first real character. + +### The placeholder token varies by model + +| Model family | Placeholder | Notes | +|---|---|---| +| LLaVA-1.5 / 1.6 | `` | | +| SmolVLM / SmolVLM2 / Idefics3 | `` | Processor expands to 1088 tokens (17 tiles × 64) | +| Qwen2-VL / Qwen2.5-VL / Qwen3-VL | `<\|image_pad\|>` | Processor autowraps with `<\|vision_start\|>` / `<\|vision_end\|>` | +| Gemma-3 | `` | Processor expands to 256 `` | +| Gemma-4 | `<\|image\|>` | Processor expands to 256 `<\|image\|>` | + +Axolotl autodetects the placeholder from the loaded processor. If autodetection +fails, supply `image_token: ` on the dataset entry. + +### YAML example + +```yaml +base_model: HuggingFaceTB/SmolVLM-500M-Instruct +processor_type: AutoProcessor + +pretraining_dataset: + - path: /path/to/shards/*.jsonl + ds_type: json + type: multimodal_pretrain + text_column: text + image_column: images + image_base_dir: /path/to/images # optional, for relative paths + # image_token: "" # optional override; autodetect by default + +streaming: true +sequence_len: 2048 +sample_packing: false # REQUIRED — see below +remove_unused_columns: false # auto-set by validator + +max_steps: 10000 +micro_batch_size: 1 +gradient_accumulation_steps: 8 +``` + +### Eval datasets + +`test_datasets` accepts multimodal entries (`type: multimodal_pretrain` or +`multimodal: true`). Per-entry `text_column` and `image_column` are honored +independently. When more than one multimodal entry is provided, +`image_base_dir` and `image_token` must be either unset on every entry or +identical across them, because the eval collator resolves both once for the +merged eval stream. + +### Gates and rejections + +The following combinations are rejected at config-load time with a clear error: + +- `sample_packing: true` — cross-row packing would break the 1-to-1 alignment + between text placeholders and `pixel_values`. +- `chat_template` set to anything — defeats the purpose of the CPT path. +- `processor_type` unset — no processor means no image tensors. +- Multiple MM `test_datasets` entries with mismatched `image_base_dir` or `image_token`. + +In addition, the following model families are **not supported** in v1 and will +be rejected when their processor is loaded: + +- **Llama-3.2-Vision (Mllama)** — uses cross-attention image injection, not + in-stream placeholders. Use chat-template SFT. +- **Pixtral** — requires `mistral_common` and a different API. +- **InternVL** — ships a custom processor that doesn't produce `pixel_values`. + +Per-row validation: at encode time the row's text is tokenized once and the +number of `image_token_id` occurrences in the resulting token-id list must +equal `len(images)`. Counting by token id (not by substring) avoids false +matches — e.g., `` would substring-match inside ``. +This is a critical guardrail — LLaVA and Qwen-VL processors silently +accept rows without placeholders and drop the image, which looks like +successful training but teaches nothing. If a row fails this check, +inspect the tokenized ids rather than the raw string. + +### Why masking image tokens in labels is automatic + +For this multimodal CPT path, the collator masks every image-family token id +(``, `<\|image_pad\|>`, `<\|vision_start\|>`, `<\|vision_end\|>`, +``, ``, ``, `<\|image\|>`, +etc.) to `-100` in the labels tensor. + +Supported processors expand image placeholders into vision-specific token ids. +If those ids contribute to loss, the model is trained to predict patch or +marker tokens rather than only the target text. On architectures like +Qwen-VL and SmolVLM, leaving them unmasked substantially increases loss and can +destabilize training. + ## FAQ 1. `PIL.UnidentifiedImageError: cannot identify image file ...` diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 985414053c..b9a6009ad7 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -44,12 +44,34 @@ V2BatchSamplerDataCollatorForSeq2Seq, ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator +from axolotl.utils.collators.mm_pretrain import MultiModalPretrainDataCollator from axolotl.utils.import_helper import get_cls_from_module_str from axolotl.utils.logging import get_logger LOG = get_logger(__name__) +def _is_multimodal_cpt(cfg) -> bool: + if not getattr(cfg, "pretraining_dataset", None): + return False + ds_first = cfg.pretraining_dataset[0] + ds_type = None + mm_flag = None + if hasattr(ds_first, "type"): + ds_type = getattr(ds_first, "type", None) + mm_flag = getattr(ds_first, "multimodal", None) + elif isinstance(ds_first, dict): + ds_type = ds_first.get("type") + mm_flag = ds_first.get("multimodal") + return (ds_type == "multimodal_pretrain") or bool(mm_flag) + + +def _mm_cpt_get(pt_cfg, key, default=None): + if isinstance(pt_cfg, dict): + return pt_cfg.get(key, default) + return getattr(pt_cfg, key, default) + + class HFCausalTrainerBuilder(TrainerBuilderBase): """ Build the HuggingFace training args/trainer for causal models and reward modeling @@ -449,6 +471,31 @@ def build(self, total_num_steps): return trainer + def _build_mm_pretrain_collator(self, pad_to_multiple_of=None, is_eval=False): + from axolotl.prompt_strategies.multimodal_pretrain import ( + build_image_token_spec, + ) + + if is_eval and self.cfg.test_datasets: + pt_cfg = self.cfg.test_datasets[0] + elif self.cfg.pretraining_dataset: + pt_cfg = self.cfg.pretraining_dataset[0] + else: + pt_cfg = {} + spec = build_image_token_spec( + self.processor, override=_mm_cpt_get(pt_cfg, "image_token") + ) + collator_kwargs = { + "tokenizer": self.tokenizer, + "processor": self.processor, + "image_token_spec": spec, + "image_base_dir": _mm_cpt_get(pt_cfg, "image_base_dir"), + "max_length": self.cfg.sequence_len, + } + if pad_to_multiple_of is not None: + collator_kwargs["pad_to_multiple_of"] = pad_to_multiple_of + return MultiModalPretrainDataCollator(**collator_kwargs) + def build_collator( self, training_args, # type: "AxolotlTrainingArguments" # type: ignore @@ -456,6 +503,16 @@ def build_collator( **kwargs, ): if training_args.pretraining: + # Intercept MM CPT before the text-only pretraining branches. + if ( + self.cfg.processor_type + and self.processor + and _is_multimodal_cpt(self.cfg) + ): + return self._build_mm_pretrain_collator( + pad_to_multiple_of=kwargs.get("pad_to_multiple_of"), + is_eval=is_eval, + ) if ( self.cfg.pretraining_sample_concatenation is False or self.cfg.micro_batch_size > 1 @@ -517,6 +574,15 @@ def build_collator( else: collator = BatchSamplerDataCollatorForSeq2Seq else: + if ( + self.cfg.processor_type + and self.processor + and _is_multimodal_cpt(self.cfg) + ): + return self._build_mm_pretrain_collator( + pad_to_multiple_of=kwargs.get("pad_to_multiple_of"), + is_eval=is_eval, + ) if self.cfg.processor_type and self.processor: collator = MultiModalChatDataCollator # Mirror ChatTemplateStrategy: per-dataset masking knobs from first MM dataset, else global cfg. diff --git a/src/axolotl/prompt_strategies/multimodal_pretrain.py b/src/axolotl/prompt_strategies/multimodal_pretrain.py new file mode 100644 index 0000000000..c716c63aa3 --- /dev/null +++ b/src/axolotl/prompt_strategies/multimodal_pretrain.py @@ -0,0 +1,297 @@ +"""Multimodal CPT tokenization strategy.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from transformers import BatchEncoding, PreTrainedTokenizerBase, ProcessorMixin + +from axolotl.prompt_strategies.pretrain import ( + PretrainTokenizationStrategy, + PretrainTokenizer, +) +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def _get_incompatible_processor_classes() -> tuple[type, ...]: + classes: list[type] = [] + for mod_path, name in ( + ("transformers.models.mllama", "MllamaProcessor"), + ("transformers.models.pixtral", "PixtralProcessor"), + ("transformers.models.internvl", "InternVLProcessor"), + ): + try: + import importlib + + mod = importlib.import_module(mod_path) + cls = getattr(mod, name, None) + if cls is not None: + classes.append(cls) + except ImportError: + continue + return tuple(classes) + + +_KNOWN_IMAGE_TOKEN_CANDIDATES: tuple[str, ...] = ( + "", + "<|image|>", + "<|image_pad|>", + "", + "", + "[IMG]", + "", +) + +# Without masking these in labels, loss blows up ~10× on Qwen/SmolVLM. +_IMAGE_FAMILY_TOKEN_CANDIDATES: tuple[str, ...] = ( + "", + "<|image|>", + "<|image_pad|>", + "", + "", + "", + "<|vision_start|>", + "<|vision_end|>", + "[IMG]", + "[IMG_END]", + "", +) + +_INCOMPATIBLE_PROCESSOR_REASONS: dict[str, str] = { + "MllamaProcessor": ( + "Llama-3.2-Vision (Mllama) uses cross-attention image injection, not " + "in-stream placeholder tokens. Multimodal CPT is incompatible with " + "this architecture; use chat-template SFT instead." + ), + "PixtralProcessor": ( + "Pixtral's tokenizer goes through mistral_common with a different " + "API surface than AutoProcessor. Multimodal CPT not supported in v1; " + "use chat-template SFT or Mistral-Small-3.1." + ), + "InternVLProcessor": ( + "InternVL ships a custom processing pipeline (AutoProcessor returns " + "text-only); no pixel_values are produced. Multimodal CPT not " + "supported in v1." + ), +} +_INCOMPATIBLE_PROCESSOR_CLASSES = _get_incompatible_processor_classes() + + +@dataclass +class ImageTokenSpec: + image_token: str + image_token_id: int + image_family_token_ids: set[int] + + +def build_image_token_spec( + processor: ProcessorMixin, override: str | None = None +) -> ImageTokenSpec: + tokenizer = getattr(processor, "tokenizer", None) + if tokenizer is None: + raise ValueError( + "Processor has no `tokenizer` attribute — multimodal CPT " + "requires a processor with a text tokenizer (e.g. one produced " + "by AutoProcessor.from_pretrained for a VLM)." + ) + + def resolve_id(tok: str) -> int | None: + tid = tokenizer.convert_tokens_to_ids(tok) + unk = getattr(tokenizer, "unk_token_id", None) + if tid is None or tid == unk: + return None + return tid + + known_special_tokens: set[str] = set() + try: + known_special_tokens |= set(tokenizer.get_added_vocab().keys()) + except Exception: + pass + known_special_tokens |= set(getattr(tokenizer, "all_special_tokens", None) or []) + known_special_tokens |= set( + getattr(tokenizer, "additional_special_tokens", None) or [] + ) + + image_token: str | None = None + image_token_id: int | None = None + if override is not None: + # Reject plain words that BPE-tokenize cleanly but aren't placeholders. + if override not in known_special_tokens: + raise ValueError( + f"image_token override {override!r} is not a registered " + f"special token on this tokenizer. Pick one of the model's " + f"actual image tokens (e.g. '', '<|image_pad|>', " + f"''), or leave unset to autodetect." + ) + image_token_id = resolve_id(override) + if image_token_id is None: + raise ValueError( + f"image_token override {override!r} did not resolve to a " + f"token id (unk). Remove the override to autodetect." + ) + image_token = override + else: + proc_token = getattr(processor, "image_token", None) + if proc_token is not None: + image_token_id = resolve_id(proc_token) + if image_token_id is not None: + image_token = proc_token + if image_token is None: + for cand in _KNOWN_IMAGE_TOKEN_CANDIDATES: + tid = resolve_id(cand) + if tid is not None: + image_token = cand + image_token_id = tid + break + if image_token is None: + raise ValueError( + "Could not autodetect the image placeholder token for this " + "processor. Set `image_token: ` in the dataset config " + "(e.g. '' for LLaVA, '<|image_pad|>' for Qwen-VL, " + "'' for Gemma-3)." + ) + + # Filter to registered tokens so BPE-fallback ids don't get masked. + family: set[int] = {image_token_id} # type: ignore[arg-type] + for cand in _IMAGE_FAMILY_TOKEN_CANDIDATES: + if cand != image_token and cand not in known_special_tokens: + continue + tid = resolve_id(cand) + if tid is not None: + family.add(tid) + return ImageTokenSpec( + image_token=image_token, + image_token_id=image_token_id, # type: ignore[arg-type] + image_family_token_ids=family, + ) + + +def check_processor_compatibility(processor: ProcessorMixin) -> None: + if _INCOMPATIBLE_PROCESSOR_CLASSES and isinstance( + processor, _INCOMPATIBLE_PROCESSOR_CLASSES + ): + for cls in _INCOMPATIBLE_PROCESSOR_CLASSES: + if isinstance(processor, cls): + raise ValueError( + f"Multimodal CPT is not supported for {cls.__name__}: " + f"{_INCOMPATIBLE_PROCESSOR_REASONS.get(cls.__name__, '')}" + ) + # MRO-name fallback for test fakes and unimportable concrete classes. + for base_cls in type(processor).__mro__: + reason = _INCOMPATIBLE_PROCESSOR_REASONS.get(base_cls.__name__) + if reason is not None: + raise ValueError( + f"Multimodal CPT is not supported for {base_cls.__name__}: {reason}" + ) + + +class MultimodalPretrainTokenizationStrategy(PretrainTokenizationStrategy): + def __init__( + self, + *args: Any, + image_token: str, + image_token_id: int, + image_column: str = "images", + image_base_dir: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self.image_token = image_token + self.image_token_id = image_token_id + self.image_column = image_column + self.image_base_dir = image_base_dir + + def _tokenize( + self, + prompt: str, + add_eos_token: bool = True, + strip_bos_token: bool = False, + ) -> BatchEncoding: + # No truncation: collator re-tokenizes the full text without truncation; + # truncating here decouples the stored ids from what the model receives. + res = self.tokenizer(prompt, add_special_tokens=True) + res["input_ids"] = [res["input_ids"] + [self.tokenizer.eos_token_id]] + res["attention_mask"] = [res["attention_mask"] + [1]] + return res + + def tokenize_prompt(self, prompt: dict[str, Any]) -> dict[str, list]: + text = prompt[self.text_column] + raw_images = prompt.get(self.image_column) + if raw_images is None: + images: list = [] + elif isinstance(raw_images, (list, tuple)): + images = list(raw_images) + else: + raise ValueError( + f"Row's `{self.image_column}` must be a list of image paths, " + f"got {type(raw_images).__name__}." + ) + + res = self._tokenize(text) + ids = res["input_ids"][0] + # Count by token id — `text.count` substring-matches `` in ``. + n_placeholders = sum(1 for t in ids if t == self.image_token_id) + if n_placeholders != len(images): + raise ValueError( + f"Multimodal CPT row has {n_placeholders} occurrence(s) of " + f"{self.image_token!r} in text but {len(images)} image path(s) " + f"in `{self.image_column}`. They must match — the text column " + f"must contain exactly one placeholder per image." + ) + if len(ids) > self.max_length: + raise ValueError( + f"Multimodal CPT row tokenizes to {len(ids)} tokens which " + f"exceeds sequence_len={self.max_length}. Pre-chunk your text " + f"or raise sequence_len." + ) + + n_chunks = len(res["input_ids"]) + res["images"] = [list(images)] * n_chunks + res["_mm_text"] = [text] * n_chunks + return res + + +def load( + tokenizer: PreTrainedTokenizerBase, + cfg: Any, + ds_cfg: dict | None = None, + processor: ProcessorMixin | None = None, +) -> MultimodalPretrainTokenizationStrategy: + if processor is None: + raise ValueError( + "multimodal_pretrain requires a processor. Set `processor_type: " + "AutoProcessor` (or the concrete processor class) in your config " + "so axolotl loads it at startup." + ) + check_processor_compatibility(processor) + + ds_cfg = dict(ds_cfg or {}) + text_column = ds_cfg.get("text_column") or ds_cfg.get("field") or "text" + image_column = ds_cfg.get("image_column") or "images" + image_base_dir = ds_cfg.get("image_base_dir") + image_token_override = ds_cfg.get("image_token") + + spec = build_image_token_spec(processor, override=image_token_override) + LOG.info( + f"multimodal_pretrain: placeholder={spec.image_token!r} " + f"(id={spec.image_token_id}), masking {len(spec.image_family_token_ids)} " + f"image-family token ids in labels" + ) + + strat = MultimodalPretrainTokenizationStrategy( + PretrainTokenizer(), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + text_column=text_column, + image_column=image_column, + image_base_dir=image_base_dir, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + max_length=cfg.sequence_len, + ) + strat.image_token_spec = spec # type: ignore[attr-defined] + return strat diff --git a/src/axolotl/utils/collators/mm_pretrain.py b/src/axolotl/utils/collators/mm_pretrain.py new file mode 100644 index 0000000000..f6870434f4 --- /dev/null +++ b/src/axolotl/utils/collators/mm_pretrain.py @@ -0,0 +1,306 @@ +"""Collator for multimodal CPT.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass, field +from typing import Any, Literal, Optional, Union + +from PIL import Image +from torch import Tensor +from transformers import PreTrainedTokenizerBase, ProcessorMixin +from transformers.data.data_collator import DataCollatorMixin +from transformers.utils import PaddingStrategy + +from axolotl.prompt_strategies.multimodal_pretrain import ( + ImageTokenSpec, + check_processor_compatibility, +) +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +# Decompression-bomb cap (~7070×7070). +_DEFAULT_MAX_IMAGE_PIXELS = 50_000_000 +_DEFAULT_MAX_IMAGES_PER_ROW = 32 + + +@dataclass +class MultiModalPretrainDataCollator(DataCollatorMixin): + tokenizer: PreTrainedTokenizerBase + processor: ProcessorMixin + image_token_spec: ImageTokenSpec + image_base_dir: Optional[str] = None + return_tensors: Literal["pt"] = "pt" + padding: Union[bool, str, PaddingStrategy] = True + pad_to_multiple_of: Optional[int] = None + max_length: Optional[int] = None + skip_bad_images: bool = False + max_image_pixels: int = _DEFAULT_MAX_IMAGE_PIXELS + max_images_per_row: int = _DEFAULT_MAX_IMAGES_PER_ROW + + _image_family_token_ids: set[int] = field(init=False, default_factory=set) + _base_dir_real: Optional[str] = field(init=False, default=None) + + def __post_init__(self) -> None: + if self.return_tensors != "pt": + raise ValueError( + "MultiModalPretrainDataCollator only supports " + "return_tensors='pt' (in-place torch ops are used downstream)." + ) + check_processor_compatibility(self.processor) + self._image_family_token_ids = set(self.image_token_spec.image_family_token_ids) + if self.image_base_dir is not None: + self._base_dir_real = os.path.realpath(self.image_base_dir) + + def _resolve_image_path(self, p: str) -> str: + if not isinstance(p, str): + raise ValueError(f"Image path must be str, got {type(p).__name__}.") + if "\x00" in p: + raise ValueError("Image path contains embedded NUL byte.") + p_lower = p.lower() + if p_lower.startswith( + ("http://", "https://", "ftp://", "ftps://", "file://", "data:") + ) or p.startswith(("\\\\", "//")): + raise ValueError( + f"Non-local image path scheme is not supported in v1 " + f"multimodal CPT (got {p!r})." + ) + if self._base_dir_real is not None: + if os.path.isabs(p): + raise ValueError( + f"Absolute image path {p!r} is rejected when " + f"`image_base_dir` is configured. All image paths must be " + f"relative to the configured base directory." + ) + resolved = os.path.realpath(os.path.join(self._base_dir_real, p)) + # commonpath (not startswith) so root-dir bases like "/" work. + try: + within_base = ( + os.path.commonpath([self._base_dir_real, resolved]) + == self._base_dir_real + ) + except ValueError: + within_base = False + if not within_base: + raise ValueError( + f"Image path {p!r} resolves outside `image_base_dir` " + f"after symlink resolution. Refusing to load." + ) + return resolved + return os.path.realpath(p) if os.path.isabs(p) else p + + def _open_image_hardened(self, resolved: str) -> Image.Image: + # O_NOFOLLOW closes the realpath→open TOCTOU window for the final component. + nofollow = getattr(os, "O_NOFOLLOW", 0) + try: + fd = os.open(resolved, os.O_RDONLY | nofollow) + except OSError as exc: + raise ValueError( + f"Cannot open image (os.open failed: {type(exc).__name__})." + ) from exc + file_obj = os.fdopen(fd, "rb") + try: + with Image.open(file_obj) as src: + w, h = src.size + if w * h > self.max_image_pixels: + raise ValueError( + f"Image pixels ({w}×{h}) exceed " + f"max_image_pixels ({self.max_image_pixels})." + ) + # Multi-frame bomb guard (GIF/TIFF/WebP). + n_frames = getattr(src, "n_frames", 1) + if n_frames > 1: + raise ValueError( + f"Multi-frame images are not supported (got {n_frames} frames)." + ) + img = src.convert("RGB") + img.load() + return img + finally: + if not file_obj.closed: + file_obj.close() + + def _load_images_for_row( + self, paths: list[str], row_index: int + ) -> list[Image.Image]: + if len(paths) > self.max_images_per_row: + raise ValueError( + f"Row {row_index}: {len(paths)} images exceeds " + f"`max_images_per_row={self.max_images_per_row}`. Split the " + f"row or raise the cap if this is expected." + ) + out: list[Image.Image] = [] + for raw in paths: + try: + resolved = self._resolve_image_path(raw) + img = self._open_image_hardened(resolved) + except Exception as exc: + # Top-level log gets basename only; full path stays on DEBUG. + basename = os.path.basename(str(raw)) + msg = ( + f"Row {row_index}: failed to load image {basename!r} " + f"({type(exc).__name__})" + ) + LOG.debug("failed image full path: %r; error: %s", raw, exc) + if self.skip_bad_images: + LOG.warning("%s — skipping", msg) + continue + raise RuntimeError(msg) from exc + out.append(img) + return out + + def torch_call(self, examples: list[dict]) -> dict[str, Any]: + if not examples: + raise ValueError("Empty batch passed to MultiModalPretrainDataCollator.") + + texts: list[str] = [] + images: list[list[Image.Image]] = [] + for i, ex in enumerate(examples): + if "_mm_text" not in ex or "images" not in ex: + raise KeyError( + f"MultiModalPretrainDataCollator: row {i} is missing " + f"'_mm_text' or 'images'. Did you wire the multimodal CPT " + f"encoder (encode_streaming_multimodal or " + f"MultimodalPretrainTokenizationStrategy)?" + ) + mm_text = ex["_mm_text"] + if not isinstance(mm_text, str): + raise TypeError( + f"Row {i}: `_mm_text` must be str, got " + f"{type(mm_text).__name__}. Check dataset encoding " + f"(Parquet BINARY columns may surface as bytes)." + ) + raw = ex["images"] + if raw is None: + raw_paths: list[str] = [] + elif isinstance(raw, (list, tuple)): + raw_paths = list(raw) + else: + raise TypeError( + f"Row {i}: `images` must be a list (or None), got " + f"{type(raw).__name__}." + ) + for j, rp in enumerate(raw_paths): + if not isinstance(rp, str): + raise TypeError( + f"Row {i}, image {j}: path must be str, got " + f"{type(rp).__name__}." + ) + texts.append(mm_text) + loaded = self._load_images_for_row(raw_paths, row_index=i) + if self.skip_bad_images and len(loaded) != len(raw_paths): + # Drop the row to avoid silent placeholder/image count mismatch. + LOG.warning( + "Row %d: %d/%d images failed to load; dropping row.", + i, + len(raw_paths) - len(loaded), + len(raw_paths), + ) + texts.pop() + continue + images.append(loaded) + + if not texts: + raise RuntimeError( + "All rows in the batch were dropped due to image load " + "failures. Check dataset integrity." + ) + + # All-text batch: bypass the processor and tokenize directly. + if all(len(im) == 0 for im in images): + LOG.debug( + "MultiModalPretrainDataCollator: all-text batch (%d rows); " + "using tokenizer-only fallback (no pixel_values).", + len(texts), + ) + tok_kwargs: dict[str, Any] = { + "text": texts, + "return_tensors": self.return_tensors, + "padding": self.padding, + } + if self.pad_to_multiple_of is not None: + tok_kwargs["pad_to_multiple_of"] = self.pad_to_multiple_of + batch = self.tokenizer(**tok_kwargs) + tok_input_ids: Tensor = batch["input_ids"] + tok_labels = tok_input_ids.clone() + pad_id = getattr(self.tokenizer, "pad_token_id", None) + if pad_id is not None: + tok_labels[tok_labels == pad_id] = -100 + for tid in self._image_family_token_ids: + tok_labels[tok_labels == tid] = -100 + batch["labels"] = tok_labels + return dict(batch) + + # No truncation: it chops input_ids mid-placeholder while pixel_values + # keep every image — silent text/pixel mismatch. We warn post-hoc instead. + proc_kwargs: dict[str, Any] = { + "text": texts, + "images": images, + "return_tensors": self.return_tensors, + "padding": self.padding, + } + if self.pad_to_multiple_of is not None: + proc_kwargs["pad_to_multiple_of"] = self.pad_to_multiple_of + try: + batch = self.processor(**proc_kwargs) + except Exception as exc: + # Pinpoint the bad row; bail to "inconclusive" if retry raises a different class. + offender_idx: Optional[int] = None + retry_ok = True + retry_kwargs: dict[str, Any] = { + "return_tensors": self.return_tensors, + "padding": self.padding, + } + if self.pad_to_multiple_of is not None: + retry_kwargs["pad_to_multiple_of"] = self.pad_to_multiple_of + for i, (t, imgs) in enumerate(zip(texts, images, strict=True)): + try: + self.processor(text=[t], images=[imgs], **retry_kwargs) + except Exception as retry_exc: + if isinstance(retry_exc, type(exc)) or isinstance( + exc, type(retry_exc) + ): + offender_idx = i + else: + retry_ok = False + break + if offender_idx is not None: + location = f"row {offender_idx}" + elif retry_ok: + location = ( + f"batch of {len(texts)} rows " + f"(individual rows all succeed; see __cause__ for details)" + ) + else: + location = f"batch of {len(texts)} rows (retry inconclusive)" + raise RuntimeError( + f"MultiModalPretrainDataCollator: processor call failed on " + f"{location} ({type(exc).__name__}: {exc}). Common causes: " + f"placeholder token absent from the row's text, image count " + f"mismatch, or an unsupported processor class." + ) from exc + + input_ids_len = batch["input_ids"].shape[-1] + if self.max_length is not None and input_ids_len > self.max_length: + LOG.warning( + "Batch input_ids length %d exceeds configured sequence_len %d " + "(image placeholder expansion). Reduce max_images_per_row or " + "raise sequence_len if this fires repeatedly.", + input_ids_len, + self.max_length, + ) + + input_ids: Tensor = batch["input_ids"] + labels = input_ids.clone() + + pad_id = getattr(self.tokenizer, "pad_token_id", None) + if pad_id is not None: + labels[labels == pad_id] = -100 + + # Without this, image-family ids dominate loss and blow it up ~10×. + for tid in self._image_family_token_ids: + labels[labels == tid] = -100 + + batch["labels"] = labels + return batch diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 0b2ec2b5fb..809d08ac5f 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -10,6 +10,7 @@ DatasetDict, IterableDataset, IterableDatasetDict, + concatenate_datasets, load_dataset, ) from transformers import PreTrainedTokenizer, ProcessorMixin @@ -134,7 +135,9 @@ def _prepare_streaming_dataset( """ if cfg.pretraining_dataset: dataset_config = _extract_pretraining_config(cfg) - train_dataset = _load_streaming_dataset(dataset_config, cfg, tokenizer) + train_dataset = _load_streaming_dataset( + dataset_config, cfg, tokenizer, processor=processor + ) elif cfg.sample_packing: # TODO(djsaunde): Implement for multiple datasets dataset_config = DictDefault(cfg.datasets[0]) @@ -142,7 +145,9 @@ def _prepare_streaming_dataset( # Ensure we have a split set - default to 'train' if not specified if not hasattr(dataset_config, "split") or not dataset_config.split: dataset_config.split = "train" - train_dataset = _load_streaming_dataset(dataset_config, cfg, tokenizer) + train_dataset = _load_streaming_dataset( + dataset_config, cfg, tokenizer, processor=processor + ) else: # Use legacy loading function for non-packed streaming datasets train_dataset, eval_dataset, prompters = _load_and_prepare_datasets( @@ -160,35 +165,73 @@ def _prepare_streaming_dataset( # Load evaluation dataset if specified eval_dataset = None if cfg.test_datasets: - _, eval_dataset, _ = _load_and_prepare_datasets( - tokenizer, - cfg, - split="test", - processor=processor, - streaming=False, + test_dicts = [t if isinstance(t, dict) else dict(t) for t in cfg.test_datasets] + is_mm_cpt_eval = any( + t.get("type") == "multimodal_pretrain" or bool(t.get("multimodal")) + for t in test_dicts ) + if is_mm_cpt_eval: + eval_streams = [] + for entry in test_dicts: + if not ( + entry.get("type") == "multimodal_pretrain" + or bool(entry.get("multimodal")) + ): + raise ValueError( + "Mixing multimodal and non-multimodal entries in " + "`test_datasets` is not supported. All eval entries " + "must be MM (type: multimodal_pretrain or " + "multimodal: true) when training is MM CPT." + ) + eval_config = _pretraining_config_from_entry(entry) + eval_streams.append( + _load_streaming_dataset( + eval_config, cfg, tokenizer, processor=processor + ) + ) + eval_dataset = ( + eval_streams[0] + if len(eval_streams) == 1 + else concatenate_datasets(eval_streams) + ) + else: + _, eval_dataset, _ = _load_and_prepare_datasets( + tokenizer, + cfg, + split="test", + processor=processor, + streaming=False, + ) # For streaming, we return max_steps directly from config or -1 if not set total_num_steps = cfg.max_steps if cfg.max_steps else -1 return train_dataset, eval_dataset, total_num_steps, [] +def _pretraining_config_from_entry(entry: dict) -> DictDefault: + return DictDefault( + { + "path": entry["path"], + "name": entry.get("name"), + "skip": entry.get("skip"), + "split": entry.get("split", "train"), + "data_files": entry.get("data_files"), + "type": entry.get("type", "pretrain"), + "text_column": entry.get("text_column", "text"), + "multimodal": entry.get("multimodal"), + "image_column": entry.get("image_column", "images"), + "image_base_dir": entry.get("image_base_dir"), + "image_token": entry.get("image_token"), + } + ) + + def _extract_pretraining_config(cfg: DictDefault) -> DictDefault: """Extract pretraining configuration from the main config.""" if isinstance(cfg.pretraining_dataset, list) and isinstance( cfg.pretraining_dataset[0], dict ): - config = cfg.pretraining_dataset[0] - return DictDefault( - { - "path": config["path"], - "name": config["name"], - "skip": config["skip"], - "split": config.get("split", "train"), - "data_files": config.get("data_files"), - "type": config.get("type", "pretrain"), - } - ) + return _pretraining_config_from_entry(cfg.pretraining_dataset[0]) # Simple string path case return DictDefault( { @@ -198,12 +241,20 @@ def _extract_pretraining_config(cfg: DictDefault) -> DictDefault: "split": "train", "data_files": None, "type": "pretrain", + "text_column": "text", + "multimodal": None, + "image_column": "images", + "image_base_dir": None, + "image_token": None, # nosec } ) def _load_streaming_dataset( - pretraining_config: DictDefault, cfg: DictDefault, tokenizer: PreTrainedTokenizer + pretraining_config: DictDefault, + cfg: DictDefault, + tokenizer: PreTrainedTokenizer, + processor: ProcessorMixin | None = None, ) -> IterableDataset: """Load and prepare a streaming dataset for pretraining.""" # Create dataset wrapper partial function @@ -213,6 +264,7 @@ def _load_streaming_dataset( tokenizer=tokenizer, cfg=cfg, dataset_base_type=pretraining_config["type"], + processor=processor, ) # Load the actual dataset @@ -221,7 +273,7 @@ def _load_streaming_dataset( and cfg.accelerator_config.dispatch_batches and not is_local_main_process() ): - iter_dataset = _create_placeholder_dataset() + iter_dataset = _create_placeholder_dataset(pretraining_config) else: iter_dataset = load_dataset( pretraining_config["path"], @@ -242,19 +294,39 @@ def _load_streaming_dataset( tokenizer, cfg, dataset_wrapper_partial, + processor=processor, + pretraining_config=pretraining_config, ) # Format for PyTorch return train_dataset.with_format("torch") -def _create_placeholder_dataset() -> IterableDataset: +def _create_placeholder_dataset( + pretraining_config: DictDefault | None = None, +) -> IterableDataset: """Create a minimal placeholder dataset for non-main processes.""" - with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f: - f.write("text\n") - f.write("lorem ipsum dolor sit amet\n") - f.seek(0) - return load_dataset("csv", data_files=f.name, split="train", streaming=True) + text_column = "text" + image_column: str | None = None + if pretraining_config is not None: + text_column = pretraining_config.get("text_column") or "text" + is_mm = pretraining_config.get("type") == "multimodal_pretrain" or bool( + pretraining_config.get("multimodal") + ) + if is_mm: + image_column = pretraining_config.get("image_column") or "images" + + if image_column is None: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f: + f.write(f"{text_column}\n") + f.write("lorem ipsum dolor sit amet\n") + f.seek(0) + return load_dataset("csv", data_files=f.name, split="train", streaming=True) + + def _gen(): + yield {text_column: "lorem ipsum dolor sit amet", image_column: []} + + return IterableDataset.from_generator(_gen) def _load_tokenized_prepared_datasets( diff --git a/src/axolotl/utils/data/streaming.py b/src/axolotl/utils/data/streaming.py index 8b6b8a439b..966fa65719 100644 --- a/src/axolotl/utils/data/streaming.py +++ b/src/axolotl/utils/data/streaming.py @@ -7,7 +7,7 @@ import torch from datasets import Dataset from torch.utils.data import RandomSampler -from transformers import PreTrainedTokenizerBase +from transformers import PreTrainedTokenizerBase, ProcessorMixin from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq from axolotl.utils.logging import get_logger @@ -176,11 +176,93 @@ def encode_streaming( return ret +def encode_streaming_multimodal( + examples: Dict[str, List], + tokenizer: PreTrainedTokenizerBase, + max_tokens: int, + image_token: str, + image_token_id: int, + text_column: str = "text", + image_column: str = "images", +) -> Dict[str, List]: + texts: List[str] = examples[text_column] + imgs_list: List[List[str]] = examples[image_column] + + if len(texts) != len(imgs_list): + raise ValueError( + f"encode_streaming_multimodal: text column has {len(texts)} rows " + f"but image column has {len(imgs_list)}" + ) + + input_ids: List[List[int]] = [] + labels: List[List[int]] = [] + attention_mask: List[List[int]] = [] + keep_images: List[List[str]] = [] + keep_text: List[str] = [] + + for text, imgs in zip(texts, imgs_list, strict=True): + if not isinstance(text, str): + raise TypeError( + f"encode_streaming_multimodal: `{text_column}` must be str, " + f"got {type(text).__name__}." + ) + if imgs is None: + imgs = [] + if not isinstance(imgs, (list, tuple)): + raise ValueError( + f"encode_streaming_multimodal: row's `{image_column}` must be " + f"a list; got {type(imgs).__name__}" + ) + for j, ip in enumerate(imgs): + if not isinstance(ip, str): + raise TypeError( + f"encode_streaming_multimodal: image {j} in row must be " + f"str, got {type(ip).__name__}." + ) + # No truncation: counting on truncated ids and storing untruncated text + # (which the collator re-tokenizes without truncation) silently produces + # oversize batches and confusing placeholder/image-count mismatches. + enc = tokenizer(text, add_special_tokens=True) + ids = list(enc["input_ids"]) + [tokenizer.eos_token_id] + mask = list(enc["attention_mask"]) + [1] + # Count by id — `text.count` substring-matches `` in ``. + n_placeholders = sum(1 for t in ids if t == image_token_id) + if n_placeholders != len(imgs): + raise ValueError( + f"Multimodal CPT row has {n_placeholders} occurrence(s) of " + f"{image_token!r} in text but {len(imgs)} image path(s). " + f"Text and image count must match (one placeholder per image)." + ) + if len(ids) > max_tokens: + raise ValueError( + f"Multimodal CPT row tokenizes to {len(ids)} tokens which " + f"exceeds sequence_len={max_tokens}. Pre-chunk your text or " + f"raise sequence_len (image patch expansion at the processor " + f"may push the final length even higher)." + ) + # Labels = ids; collator masks image-family ids after re-tokenization. + input_ids.append(ids) + labels.append(list(ids)) + attention_mask.append(mask) + keep_images.append(list(imgs)) + keep_text.append(text) + + return { + "input_ids": input_ids, + "labels": labels, + "attention_mask": attention_mask, + "images": keep_images, + "_mm_text": keep_text, + } + + def wrap_streaming_dataset( dataset, tokenizer, cfg, ds_wrapper_fn, + processor: Optional[ProcessorMixin] = None, + pretraining_config=None, ): if cfg.sample_packing: # For SFT (non-pretraining) datasets, always use multipack_attn=True to ensure @@ -213,17 +295,66 @@ def wrap_streaming_dataset( # NOTE: This is not reachable for SFT datasets since we use the pre-existing # loading function for non-packed streaming datasets. Refer to # _prepare_streaming_datasets in sft.py for that code path. - text_column = ( - getattr(cfg.pretraining_dataset[0], "text_column", "text") or "text" + # Prefer the resolved per-entry config so eval (test_datasets) doesn't + # silently inherit the training entry's columns/image_token. + if pretraining_config is not None: + ds_first = pretraining_config + elif cfg.pretraining_dataset: + ds_first = cfg.pretraining_dataset[0] + else: + ds_first = {} + # Plain dicts need `.get`; pydantic/DictDefault need `getattr`. + get_ds_value = ( + ds_first.get + if isinstance(ds_first, dict) + else lambda key, default=None: getattr(ds_first, key, default) ) - encode = functools.partial( - encode_streaming, - tokenizer=tokenizer, - max_tokens=cfg.sequence_len, - text_column=text_column, - concatenate=cfg.pretraining_sample_concatenation is True, + text_column = get_ds_value("text_column", "text") or "text" + ds_type = (get_ds_value("type", None) or "").strip() + is_mm_cpt = ds_type == "multimodal_pretrain" or bool( + get_ds_value("multimodal", False) ) + if is_mm_cpt: + if processor is None: + raise ValueError( + "Multimodal CPT (type: multimodal_pretrain) requires a " + "processor. Set `processor_type: AutoProcessor` (or the " + "concrete processor class) in your config." + ) + from axolotl.prompt_strategies.multimodal_pretrain import ( + build_image_token_spec, + check_processor_compatibility, + ) + + check_processor_compatibility(processor) + spec = build_image_token_spec( + processor, + override=get_ds_value("image_token", None), + ) + image_column = get_ds_value("image_column", None) or "images" + LOG.info( + f"multimodal streaming CPT: placeholder={spec.image_token!r} " + f"(id={spec.image_token_id})" + ) + encode = functools.partial( + encode_streaming_multimodal, + tokenizer=tokenizer, + max_tokens=cfg.sequence_len, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + text_column=text_column, + image_column=image_column, + ) + else: + encode = functools.partial( + encode_streaming, + tokenizer=tokenizer, + max_tokens=cfg.sequence_len, + text_column=text_column, + concatenate=cfg.pretraining_sample_concatenation is True, + ) + if cfg.shuffle_merged_datasets: dataset = dataset.shuffle( seed=cfg.seed, buffer_size=cfg.streaming_multipack_buffer_size diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 73e9f88230..24b42994a5 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -22,6 +22,7 @@ DatasetConfig, DPODataset, KTODataset, + MultiModalEvalDataset, PretrainingDataset, SFTDataset, StepwiseSupervisedDataset, @@ -353,7 +354,8 @@ class AxolotlInputConfig( test_datasets: ( Annotated[ list[ - SFTDataset + MultiModalEvalDataset + | SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset diff --git a/src/axolotl/utils/schemas/datasets.py b/src/axolotl/utils/schemas/datasets.py index 97ed71631d..6c55a6a9d9 100644 --- a/src/axolotl/utils/schemas/datasets.py +++ b/src/axolotl/utils/schemas/datasets.py @@ -238,6 +238,88 @@ class PretrainingDataset(BaseModel): data_files: str | None = None skip: int | None = None + # Multimodal CPT fields. Opt-in via `type: multimodal_pretrain` or `multimodal: true`. + multimodal: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Opt in to multimodal CPT. Auto-enabled when type='multimodal_pretrain'." + }, + ) + image_column: str | None = Field( + default="images", + json_schema_extra={ + "description": "Column holding a list of image paths per row." + }, + ) + image_base_dir: str | None = Field( + default=None, + json_schema_extra={"description": "Base directory for relative image paths."}, + ) + image_token: str | None = Field( + default=None, + json_schema_extra={ + "description": "Override the image placeholder token (autodetected from processor if unset)." + }, + ) + + +class MultiModalEvalDataset(BaseModel): + """Multimodal CPT eval dataset configuration (test_datasets entry). + + Use type='multimodal_pretrain' (or multimodal=True). The dataset must + expose a text column and a list[str] image-paths column; their names + default to 'text' and 'images' and can be overridden per-entry. + """ + + path: str | None = None + name: str | None = None + split: str | None = "train" + data_files: str | list[str] | None = None + skip: int | None = None + type: str | None = None + trust_remote_code: bool | None = False + + multimodal: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Opt in to multimodal eval. Auto-enabled when type='multimodal_pretrain'." + }, + ) + text_column: str | None = Field( + default="text", + json_schema_extra={"description": "Column holding the row's text."}, + ) + image_column: str | None = Field( + default="images", + json_schema_extra={ + "description": "Column holding a list of image paths per row." + }, + ) + image_base_dir: str | None = Field( + default=None, + json_schema_extra={"description": "Base directory for relative image paths."}, + ) + image_token: str | None = Field( + default=None, + json_schema_extra={ + "description": "Override the image placeholder token (autodetected from processor if unset)." + }, + ) + + @model_validator(mode="before") + @classmethod + def _require_mm_markers(cls, data): + if isinstance(data, BaseModel): + data = data.model_dump() + if not isinstance(data, dict): + return data + if data.get("type") != "multimodal_pretrain" and not data.get("multimodal"): + raise ValueError( + "MultiModalEvalDataset requires type='multimodal_pretrain' " + "or multimodal=True" + ) + return data + class UserDefinedDPOType(BaseModel): """User defined typing for DPO""" diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index a19320b9c0..6f78e31b63 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1351,6 +1351,93 @@ def check_streaming_w_multiple_datasets(cls, data): ) return data + @model_validator(mode="before") + @classmethod + def check_multimodal_cpt(cls, data): + pd = data.get("pretraining_dataset") + if not pd: + return data + + pd_list = pd if isinstance(pd, list) else [pd] + + def _entry_is_mm(entry) -> bool: + if isinstance(entry, dict): + ds_type_ = entry.get("type") + mm_flag_ = entry.get("multimodal") + else: + ds_type_ = getattr(entry, "type", None) + mm_flag_ = getattr(entry, "multimodal", None) + return ds_type_ == "multimodal_pretrain" or bool(mm_flag_) + + # MM config resolves from entry[0] only; multi-entry runs miscollate or silently demote. + if len(pd_list) > 1 and any(_entry_is_mm(e) for e in pd_list): + raise ValueError( + "Multimodal CPT supports exactly one `pretraining_dataset` " + f"entry (found {len(pd_list)}). Image settings " + "(`image_base_dir`, `image_token`) and MM-mode detection " + "both resolve from entry[0] only, so additional entries " + "would be silently miscollated or drop their MM config. " + "Split multimodal CPT into its own run." + ) + + first = pd_list[0] + if not isinstance(first, dict): + return data + + ds_type = first.get("type") + is_mm_cpt = ds_type == "multimodal_pretrain" or bool(first.get("multimodal")) + if not is_mm_cpt: + return data + + if not data.get("processor_type"): + raise ValueError( + "Multimodal CPT (type: multimodal_pretrain) requires " + "`processor_type` to be set — e.g. `processor_type: AutoProcessor`. " + "Without a processor, images in the dataset cannot be turned " + "into pixel tensors." + ) + if data.get("sample_packing"): + raise ValueError( + "Multimodal CPT is incompatible with `sample_packing: true`. " + "Each image's placeholder token expands to a variable number " + "of patch tokens at the processor, so cross-row packing would " + "break the 1-to-1 alignment between text placeholders and " + "pixel_values. Set `sample_packing: false`." + ) + if data.get("chat_template"): + raise ValueError( + "Multimodal CPT (raw image+text pretraining) is incompatible " + "with `chat_template`. The point of the CPT path is to avoid " + "conversational scaffolding entirely. Remove `chat_template` " + "or switch to chat-template SFT." + ) + # Keep `images` and `_mm_text` columns alive for the collator. + prev_remove_unused = data.get("remove_unused_columns") + if prev_remove_unused is not False: + LOG.info( + "Auto-set `remove_unused_columns: false` for multimodal CPT " + "to preserve `images` and `_mm_text` columns (previous value: %r)", + prev_remove_unused, + ) + data["remove_unused_columns"] = False + + test_datasets = data.get("test_datasets") or [] + mm_test = [t for t in test_datasets if isinstance(t, dict) and _entry_is_mm(t)] + if len(mm_test) > 1: + for key in ("image_base_dir", "image_token"): + values = {t.get(key) for t in mm_test} + if len(values) > 1: + raise ValueError( + f"Multimodal CPT eval requires `{key}` to be either " + f"unset on all `test_datasets` entries or identical " + f"across them. The eval collator resolves " + f"`image_base_dir` and `image_token` once from the " + f"first entry, so heterogeneous values would silently " + f"miscollate later entries. Got: {sorted(map(str, values))}." + ) + + return data + class ModelCompatibilityValidationMixin: """Validation methods for specific model compatibility.""" diff --git a/tests/conftest.py b/tests/conftest.py index 16a01f8aa3..25912ced4c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -112,6 +112,24 @@ def download_smollm2_135m_instruct_model(): snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M-Instruct", repo_type="model") +@pytest.fixture(scope="session", autouse=True) +def download_smolvlm_500m_instruct_model(): + # Processor/tokenizer only — skip ~1 GB of weight shards. + snapshot_download_w_retry( + "HuggingFaceTB/SmolVLM-500M-Instruct", + repo_type="model", + allow_patterns=[ + "*.json", + "*.txt", + "*.model", + "*.jinja", + "tokenizer*", + "vocab*", + "merges*", + ], + ) + + @pytest.fixture(scope="session", autouse=True) def download_smollm2_135m_gptq_model(): # download the model diff --git a/tests/prompt_strategies/test_multimodal_pretrain.py b/tests/prompt_strategies/test_multimodal_pretrain.py new file mode 100644 index 0000000000..e20dd9a4ae --- /dev/null +++ b/tests/prompt_strategies/test_multimodal_pretrain.py @@ -0,0 +1,242 @@ +"""Multimodal CPT prompt strategy + safety gate tests.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import numpy as np +import pytest +from PIL import Image +from transformers import AutoProcessor + +from axolotl.prompt_strategies.multimodal_pretrain import ( + _INCOMPATIBLE_PROCESSOR_REASONS, + ImageTokenSpec, + MultimodalPretrainTokenizationStrategy, + build_image_token_spec, + check_processor_compatibility, + load, +) +from axolotl.prompt_strategies.pretrain import PretrainTokenizer + +from tests.hf_offline_utils import enable_hf_offline + +_SMOLVLM = "HuggingFaceTB/SmolVLM-500M-Instruct" + + +@pytest.fixture(scope="module", name="smolvlm_processor") +@enable_hf_offline +def fixture_smolvlm_processor( + download_smolvlm_500m_instruct_model, # pylint: disable=unused-argument +): + return AutoProcessor.from_pretrained(_SMOLVLM) + + +@pytest.fixture(scope="module", name="tiny_image_path") +def fixture_tiny_image_path(tmp_path_factory) -> Path: + d = tmp_path_factory.mktemp("mm_pretrain_imgs") + p = d / "dummy.png" + arr = np.random.default_rng(0).integers(0, 255, (64, 64, 3)).astype("uint8") + Image.fromarray(arr).save(p) + return p + + +# ---- build_image_token_spec ------------------------------------------------ + + +def test_build_image_token_spec_autodetects_smolvlm(smolvlm_processor): + spec = build_image_token_spec(smolvlm_processor) + assert isinstance(spec, ImageTokenSpec) + assert spec.image_token == "" + assert spec.image_token_id > 0 + assert spec.image_token_id in spec.image_family_token_ids + + +def test_build_image_token_spec_honors_override(smolvlm_processor): + spec = build_image_token_spec(smolvlm_processor, override="") + assert spec.image_token == "" + + +def test_build_image_token_spec_rejects_bad_override(smolvlm_processor): + with pytest.raises(ValueError, match="not a registered special token"): + build_image_token_spec(smolvlm_processor, override="") + + +def test_build_image_token_spec_rejects_plain_word_override(smolvlm_processor): + # Plain words BPE-tokenize but aren't placeholders. + with pytest.raises(ValueError, match="not a registered special token"): + build_image_token_spec(smolvlm_processor, override="image") + + +# ---- check_processor_compatibility (startup-time gate) --------------------- + + +@pytest.mark.parametrize("cls_name", list(_INCOMPATIBLE_PROCESSOR_REASONS.keys())) +def test_check_processor_compatibility_rejects_incompatible(cls_name): + fake = type(cls_name, (), {})() + with pytest.raises(ValueError) as exc: + check_processor_compatibility(fake) + assert cls_name in str(exc.value) + assert _INCOMPATIBLE_PROCESSOR_REASONS[cls_name] in str(exc.value) + + +def test_check_processor_compatibility_rejects_subclass(): + # MRO-name fallback must catch user-defined subclasses. + class BaseMllama: + pass + + BaseMllama.__name__ = "MllamaProcessor" + + class CustomUserProcessor(BaseMllama): + pass + + CustomUserProcessor.__name__ = "CustomUserProcessor" + + with pytest.raises(ValueError, match="MllamaProcessor"): + check_processor_compatibility(CustomUserProcessor()) + + +def test_check_processor_compatibility_accepts_supported(smolvlm_processor): + check_processor_compatibility(smolvlm_processor) + + +# ---- MultimodalPretrainTokenizationStrategy -------------------------------- + + +def _make_strategy( + smolvlm_processor: Any, + text_column: str = "text", + image_column: str = "images", +) -> MultimodalPretrainTokenizationStrategy: + spec = build_image_token_spec(smolvlm_processor) + return MultimodalPretrainTokenizationStrategy( + PretrainTokenizer(), + smolvlm_processor.tokenizer, + False, # train_on_inputs + 2048, # sequence_len + text_column=text_column, + image_column=image_column, + image_base_dir=None, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + max_length=2048, + ) + + +def test_strategy_preserves_images_and_text(smolvlm_processor, tiny_image_path): + strat = _make_strategy(smolvlm_processor) + out = strat.tokenize_prompt( + { + "text": "\nsample transcription text", + "images": [str(tiny_image_path)], + } + ) + assert "input_ids" in out + assert "images" in out and "_mm_text" in out + assert len(out["input_ids"]) == 1 + assert len(out["images"]) == 1 + assert len(out["_mm_text"]) == 1 + assert out["images"][0] == [str(tiny_image_path)] + assert out["_mm_text"][0].startswith("") + + +def test_strategy_rejects_placeholder_count_mismatch( + smolvlm_processor, tiny_image_path +): + strat = _make_strategy(smolvlm_processor) + with pytest.raises(ValueError, match="occurrence"): + strat.tokenize_prompt( + { + "text": "\ntwo placeholders one image", + "images": [str(tiny_image_path)], + } + ) + + +def test_strategy_rejects_row_exceeding_max_length(smolvlm_processor, tiny_image_path): + spec = build_image_token_spec(smolvlm_processor) + strat = MultimodalPretrainTokenizationStrategy( + PretrainTokenizer(), + smolvlm_processor.tokenizer, + False, + 128, + text_column="text", + image_column="images", + image_base_dir=None, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + max_length=128, + ) + huge = "word " * 5000 + with pytest.raises(ValueError, match="exceeds sequence_len"): + strat.tokenize_prompt( + { + "text": f"{spec.image_token} {huge}", + "images": [str(tiny_image_path)], + } + ) + + +def test_strategy_rejects_non_list_image_column(smolvlm_processor, tiny_image_path): + strat = _make_strategy(smolvlm_processor) + with pytest.raises(ValueError, match="list"): + strat.tokenize_prompt( + { + "text": "\nbad image field", + "images": str(tiny_image_path), # should be a list + } + ) + + +@pytest.mark.parametrize("bad_value", ["", 0, False]) +def test_strategy_rejects_falsy_non_none_image_column(smolvlm_processor, bad_value): + """Falsy non-None image cells (e.g. "") are rejected, not coerced to [].""" + strat = _make_strategy(smolvlm_processor) + with pytest.raises(ValueError, match="list"): + strat.tokenize_prompt( + { + "text": "no placeholder, but bad images cell", + "images": bad_value, + } + ) + + +def test_strategy_treats_none_image_column_as_empty(smolvlm_processor): + """images=None is the only falsy value treated as a text-only row.""" + strat = _make_strategy(smolvlm_processor) + out = strat.tokenize_prompt( + { + "text": "plain text-only row, no placeholder", + "images": None, + } + ) + assert out["images"][0] == [] + + +# ---- load() factory -------------------------------------------------------- + + +def test_load_requires_processor(smolvlm_processor): + class _Cfg: + train_on_inputs = False + sequence_len = 2048 + + with pytest.raises(ValueError, match="processor"): + load(smolvlm_processor.tokenizer, _Cfg(), ds_cfg={}, processor=None) + + +def test_load_returns_strategy_with_spec(smolvlm_processor): + class _Cfg: + train_on_inputs = False + sequence_len = 2048 + + strat = load( + smolvlm_processor.tokenizer, + _Cfg(), + ds_cfg={"text_column": "text", "image_column": "images"}, + processor=smolvlm_processor, + ) + assert isinstance(strat, MultimodalPretrainTokenizationStrategy) + assert hasattr(strat, "image_token_spec") + assert strat.image_token_spec.image_token == "" diff --git a/tests/test_multimodal_streaming.py b/tests/test_multimodal_streaming.py new file mode 100644 index 0000000000..9fc287acdf --- /dev/null +++ b/tests/test_multimodal_streaming.py @@ -0,0 +1,448 @@ +"""Multimodal CPT streaming encoder + collator tests.""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest +import torch +from PIL import Image +from transformers import AutoProcessor + +from axolotl.prompt_strategies.multimodal_pretrain import build_image_token_spec +from axolotl.utils.collators.mm_pretrain import MultiModalPretrainDataCollator +from axolotl.utils.data.streaming import ( + encode_streaming_multimodal, + wrap_streaming_dataset, +) +from axolotl.utils.dict import DictDefault + +from tests.hf_offline_utils import enable_hf_offline + +_SMOLVLM = "HuggingFaceTB/SmolVLM-500M-Instruct" + + +@pytest.fixture(scope="module", name="smolvlm_processor") +@enable_hf_offline +def fixture_smolvlm_processor( + download_smolvlm_500m_instruct_model, # pylint: disable=unused-argument +): + return AutoProcessor.from_pretrained(_SMOLVLM) + + +@pytest.fixture(scope="module", name="two_tiny_images") +def fixture_two_tiny_images(tmp_path_factory) -> list[Path]: + d = tmp_path_factory.mktemp("mm_stream_imgs") + out = [] + for i in range(2): + p = d / f"dummy_{i}.png" + arr = np.random.default_rng(i).integers(0, 255, (64, 64, 3)).astype("uint8") + Image.fromarray(arr).save(p) + out.append(p) + return out + + +# ---- encode_streaming_multimodal ------------------------------------------ + + +def test_encode_preserves_images_and_text(smolvlm_processor, two_tiny_images): + spec = build_image_token_spec(smolvlm_processor) + examples = { + "text": [ + f"{spec.image_token}\nrow one", + f"{spec.image_token}\nrow two slightly longer", + ], + "images": [[str(two_tiny_images[0])], [str(two_tiny_images[1])]], + } + out = encode_streaming_multimodal( + examples, + tokenizer=smolvlm_processor.tokenizer, + max_tokens=2048, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + ) + assert set(out) >= {"input_ids", "labels", "attention_mask", "images", "_mm_text"} + assert len(out["input_ids"]) == 2 + assert out["images"] == [[str(two_tiny_images[0])], [str(two_tiny_images[1])]] + # EOS appended -> input_ids len equals attention_mask len and > text + for ids, mask in zip(out["input_ids"], out["attention_mask"], strict=True): + assert len(ids) == len(mask) and len(ids) > 0 + # CPT: labels == input_ids pre-masking. + for ids, lbls in zip(out["input_ids"], out["labels"], strict=True): + assert ids == lbls + + +def test_encode_rejects_mismatch(smolvlm_processor, two_tiny_images): + spec = build_image_token_spec(smolvlm_processor) + examples = { + "text": [f"{spec.image_token}{spec.image_token}\ntwo placeholders one image"], + "images": [[str(two_tiny_images[0])]], + } + with pytest.raises(ValueError, match="occurrence"): + encode_streaming_multimodal( + examples, + tokenizer=smolvlm_processor.tokenizer, + max_tokens=2048, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + ) + + +def test_encode_rejects_row_without_list(smolvlm_processor, two_tiny_images): + spec = build_image_token_spec(smolvlm_processor) + with pytest.raises(ValueError, match="list"): + encode_streaming_multimodal( + { + "text": [f"{spec.image_token}\nrow one"], + "images": [str(two_tiny_images[0])], # scalar, not a list + }, + tokenizer=smolvlm_processor.tokenizer, + max_tokens=2048, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + ) + + +def test_encode_counts_placeholders_on_full_text(smolvlm_processor, two_tiny_images): + # All 3 placeholders must be counted even when text would have been truncated. + spec = build_image_token_spec(smolvlm_processor) + long_filler = "lorem ipsum " * 20 + text = f"{spec.image_token} {long_filler} {spec.image_token} {long_filler} {spec.image_token}" + examples = { + "text": [text], + "images": [[str(two_tiny_images[0])] * 3], + } + out = encode_streaming_multimodal( + examples, + tokenizer=smolvlm_processor.tokenizer, + max_tokens=4096, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + ) + assert sum(1 for t in out["input_ids"][0] if t == spec.image_token_id) == 3 + + +def test_encode_rejects_row_exceeding_max_tokens(smolvlm_processor, two_tiny_images): + spec = build_image_token_spec(smolvlm_processor) + huge = "word " * 5000 + examples = { + "text": [f"{spec.image_token} {huge}"], + "images": [[str(two_tiny_images[0])]], + } + with pytest.raises(ValueError, match="exceeds sequence_len"): + encode_streaming_multimodal( + examples, + tokenizer=smolvlm_processor.tokenizer, + max_tokens=512, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + ) + + +# ---- wrap_streaming_dataset routing -------------------------------------- + + +def test_wrap_streaming_dataset_uses_pretraining_config_arg( + smolvlm_processor, monkeypatch +): + # Eval path passes a per-entry config that may differ from cfg.pretraining_dataset[0]. + # The MM-CPT branch must read from that arg, not re-resolve from cfg. + captured = {} + + def fake_partial(fn, **kwargs): + captured["encode_fn"] = fn + captured["kwargs"] = kwargs + return lambda batch: batch + + monkeypatch.setattr("axolotl.utils.data.streaming.functools.partial", fake_partial) + + class _Dataset: + features = {"text": None, "images": None} + + def shuffle(self, **_): + return self + + def map(self, *_args, **_kwargs): + return self + + cfg = DictDefault( + { + "sample_packing": False, + "pretraining_dataset": [ + { + "path": "train/ds", + "type": "multimodal_pretrain", + "text_column": "wrong_train_col", + "image_column": "wrong_train_imgs", + } + ], + "sequence_len": 256, + "shuffle_merged_datasets": False, + "streaming_multipack_buffer_size": 1000, + "seed": 42, + } + ) + eval_entry = DictDefault( + { + "path": "test/ds", + "type": "multimodal_pretrain", + "text_column": "eval_text", + "image_column": "eval_imgs", + } + ) + + wrap_streaming_dataset( + _Dataset(), + smolvlm_processor.tokenizer, + cfg, + ds_wrapper_fn=None, + processor=smolvlm_processor, + pretraining_config=eval_entry, + ) + + assert captured["encode_fn"] is encode_streaming_multimodal + assert captured["kwargs"]["text_column"] == "eval_text" + assert captured["kwargs"]["image_column"] == "eval_imgs" + + +# ---- MultiModalPretrainDataCollator --------------------------------------- + + +def test_collator_builds_batch_and_masks_labels(smolvlm_processor, two_tiny_images): + spec = build_image_token_spec(smolvlm_processor) + encoded = encode_streaming_multimodal( + { + "text": [ + f"{spec.image_token}\nrow one", + f"{spec.image_token}\nrow two slightly longer", + ], + "images": [[str(two_tiny_images[0])], [str(two_tiny_images[1])]], + }, + tokenizer=smolvlm_processor.tokenizer, + max_tokens=2048, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + ) + rows = [ + { + k: encoded[k][i] + for k in ("input_ids", "labels", "attention_mask", "images", "_mm_text") + } + for i in range(2) + ] + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + batch = collator.torch_call(rows) + # Expected keys + for k in ("input_ids", "attention_mask", "pixel_values", "labels"): + assert k in batch, f"missing batch key {k}" + assert isinstance(batch["input_ids"], torch.Tensor) + # Label masking check: no image-family ids remaining as valid labels. + for tid in spec.image_family_token_ids: + assert int((batch["labels"] == tid).sum().item()) == 0, ( + f"label masking left id={tid} in labels" + ) + # Pad is also masked. + pad_id = smolvlm_processor.tokenizer.pad_token_id + if pad_id is not None: + assert int((batch["labels"] == pad_id).sum().item()) == 0 + + +def test_collator_raises_on_missing_columns(smolvlm_processor): + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + with pytest.raises(KeyError, match="encode_streaming_multimodal"): + collator.torch_call([{"input_ids": [1, 2, 3]}]) # no _mm_text / images + + +# ---- security gates ------------------------------------------------------- + + +def test_collator_rejects_path_traversal_with_base_dir( + smolvlm_processor, two_tiny_images, tmp_path +): + spec = build_image_token_spec(smolvlm_processor) + base = tmp_path / "images" + base.mkdir() + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + image_base_dir=str(base), + ) + # Absolute path rejection + with pytest.raises(RuntimeError) as exc: + collator._load_images_for_row([str(two_tiny_images[0])], row_index=0) + assert isinstance(exc.value.__cause__, ValueError) + assert "Absolute image path" in str(exc.value.__cause__) + # Containment-escape rejection + with pytest.raises(RuntimeError) as exc: + collator._load_images_for_row(["../../../etc/passwd"], row_index=0) + assert isinstance(exc.value.__cause__, ValueError) + assert "outside" in str(exc.value.__cause__) + + +def test_collator_rejects_remote_urls(smolvlm_processor): + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + for url in ( + "http://example.com/a.png", + "https://x/y.jpg", + "file:///etc/passwd", + "ftp://x/y.png", + "data:image/png;base64,xxx", + # Case-variant bypass attempts. + "HTTP://evil.com/x.png", + "Https://x/y.jpg", + "FILE:///etc/passwd", + "DATA:image/png;base64,xxx", + ): + with pytest.raises(RuntimeError) as exc: + collator._load_images_for_row([url], row_index=0) + assert isinstance(exc.value.__cause__, ValueError) + assert "Non-local image path scheme" in str(exc.value.__cause__) + + +def test_collator_rejects_nul_byte_paths(smolvlm_processor): + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + with pytest.raises(RuntimeError) as exc: + collator._load_images_for_row(["bad\x00path.png"], row_index=0) + assert "NUL byte" in str(exc.value.__cause__) + + +def test_collator_rejects_non_string_image_entries(smolvlm_processor, two_tiny_images): + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + rows = [ + { + "_mm_text": f"{spec.image_token}\nrow", + "images": [None], # type: ignore[list-item] + } + ] + with pytest.raises(TypeError, match="path must be str"): + collator.torch_call(rows) + + +def test_collator_rejects_bytes_mm_text(smolvlm_processor, two_tiny_images): + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + rows = [ + { + "_mm_text": f"{spec.image_token}\nrow".encode(), + "images": [str(two_tiny_images[0])], + } + ] + with pytest.raises(TypeError, match="`_mm_text` must be str"): + collator.torch_call(rows) + + +def test_collator_sanitizes_error_message(smolvlm_processor, tmp_path): + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + missing = tmp_path / "subdir_with_secret_name" / "nope.png" + with pytest.raises(RuntimeError) as exc: + collator._load_images_for_row([str(missing)], row_index=3) + # basename appears, full directory path does NOT + assert "nope.png" in str(exc.value) + assert "subdir_with_secret_name" not in str(exc.value) + assert "Row 3" in str(exc.value) + + +def test_collator_rejects_too_many_images(smolvlm_processor, two_tiny_images): + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + max_images_per_row=2, + ) + paths = [str(two_tiny_images[0])] * 3 + with pytest.raises(ValueError, match="max_images_per_row"): + collator._load_images_for_row(paths, row_index=0) + + +# ---- mixed / all-text batches -------------------------------------------- + + +def test_collator_all_text_batch_uses_tokenizer_fallback(smolvlm_processor): + """A batch where every row has images=[] tokenizes via the tokenizer; no pixel_values.""" + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + rows = [ + {"_mm_text": "first text-only row", "images": []}, + {"_mm_text": "second text-only row, slightly longer", "images": []}, + ] + batch = collator.torch_call(rows) + for k in ("input_ids", "attention_mask", "labels"): + assert k in batch, f"missing batch key {k}" + assert "pixel_values" not in batch + assert isinstance(batch["input_ids"], torch.Tensor) + pad_id = smolvlm_processor.tokenizer.pad_token_id + if pad_id is not None: + assert int((batch["labels"] == pad_id).sum().item()) == 0 + + +def test_collator_mixed_batch_still_succeeds(smolvlm_processor, two_tiny_images): + """A batch with one imaged row and one text-only row still produces pixel_values.""" + spec = build_image_token_spec(smolvlm_processor) + encoded = encode_streaming_multimodal( + { + "text": [ + f"{spec.image_token}\nimaged row", + "text-only row", + ], + "images": [[str(two_tiny_images[0])], []], + }, + tokenizer=smolvlm_processor.tokenizer, + max_tokens=2048, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + ) + rows = [ + { + k: encoded[k][i] + for k in ("input_ids", "labels", "attention_mask", "images", "_mm_text") + } + for i in range(2) + ] + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + batch = collator.torch_call(rows) + for k in ("input_ids", "attention_mask", "pixel_values", "labels"): + assert k in batch, f"missing batch key {k}" diff --git a/tests/utils/data/test_mm_cpt_eval.py b/tests/utils/data/test_mm_cpt_eval.py new file mode 100644 index 0000000000..57ff10897b --- /dev/null +++ b/tests/utils/data/test_mm_cpt_eval.py @@ -0,0 +1,186 @@ +"""Multimodal CPT eval-path tests.""" + +from __future__ import annotations + +import pytest + +from axolotl.utils.data.sft import ( + _create_placeholder_dataset, + _prepare_streaming_dataset, +) +from axolotl.utils.dict import DictDefault + +# ---- placeholder dataset for dispatch_batches ---------------------------- + + +def test_placeholder_text_only_keeps_existing_shape(): + """Without an MM config, the placeholder is a single-column text dataset.""" + ds = _create_placeholder_dataset() + row = next(iter(ds)) + assert "text" in row + assert "images" not in row + + +def test_placeholder_mm_emits_image_column(): + """MM placeholder rows carry the configured image column as an empty list.""" + pt_cfg = DictDefault( + { + "type": "multimodal_pretrain", + "text_column": "text", + "image_column": "images", + "multimodal": True, + } + ) + ds = _create_placeholder_dataset(pt_cfg) + row = next(iter(ds)) + assert "text" in row + assert "images" in row + assert row["images"] == [] + + +def test_placeholder_mm_honors_custom_columns(): + """Custom text_column / image_column on the MM config are reflected in the placeholder row.""" + pt_cfg = DictDefault( + { + "type": "multimodal_pretrain", + "text_column": "doc", + "image_column": "imgs", + } + ) + ds = _create_placeholder_dataset(pt_cfg) + row = next(iter(ds)) + assert "doc" in row + assert "imgs" in row + assert row["imgs"] == [] + + +# ---- multiple MM eval datasets are loaded -------------------------------- + + +def test_mm_eval_iterates_all_test_datasets(monkeypatch): + """All MM entries in test_datasets are loaded and concatenated into the eval stream.""" + cfg = DictDefault( + { + "streaming": True, + "pretraining_dataset": [ + {"path": "train/ds", "type": "multimodal_pretrain"} + ], + "test_datasets": [ + {"path": "eval/a", "type": "multimodal_pretrain"}, + {"path": "eval/b", "type": "multimodal_pretrain"}, + {"path": "eval/c", "type": "multimodal_pretrain"}, + ], + "max_steps": 10, + } + ) + + seen_eval_paths: list[str] = [] + + def fake_load_streaming(pretraining_config, *_a, **_kw): + path = pretraining_config["path"] + if path.startswith("eval/"): + seen_eval_paths.append(path) + return f"" + + def fake_concat(streams): + return tuple(streams) + + monkeypatch.setattr( + "axolotl.utils.data.sft._load_streaming_dataset", fake_load_streaming + ) + monkeypatch.setattr("axolotl.utils.data.sft.concatenate_datasets", fake_concat) + + train, eval_ds, _, _ = _prepare_streaming_dataset( + cfg, tokenizer=None, processor=None + ) + + assert seen_eval_paths == ["eval/a", "eval/b", "eval/c"] + assert eval_ds == ("", "", "") + + +def test_mm_eval_rejects_mixed_mm_and_non_mm_test_datasets(monkeypatch): + """MM CPT runs require every test_datasets entry to be MM; mixed lists raise.""" + cfg = DictDefault( + { + "streaming": True, + "pretraining_dataset": [ + {"path": "train/ds", "type": "multimodal_pretrain"} + ], + "test_datasets": [ + {"path": "eval/a", "type": "multimodal_pretrain"}, + # Plain text eval entry — not allowed alongside MM eval. + {"path": "eval/b", "type": "pretrain"}, + ], + "max_steps": 10, + } + ) + monkeypatch.setattr( + "axolotl.utils.data.sft._load_streaming_dataset", + lambda *_a, **_kw: "", + ) + with pytest.raises(ValueError, match="multimodal and non-multimodal"): + _prepare_streaming_dataset(cfg, tokenizer=None, processor=None) + + +# ---- eval collator pulls image settings from test_datasets --------------- + + +def test_eval_collator_uses_eval_image_settings(monkeypatch): + """Eval collator pulls image_base_dir / image_token from test_datasets[0]; train collator from pretraining_dataset[0].""" + from axolotl.core.builders.causal import HFCausalTrainerBuilder + + captured = {} + + class _FakeSpec: + image_token = "" + image_token_id = 7 + image_family_token_ids = (7,) + + def fake_build_image_token_spec(processor, override=None): + captured["override"] = override + return _FakeSpec() + + monkeypatch.setattr( + "axolotl.prompt_strategies.multimodal_pretrain.build_image_token_spec", + fake_build_image_token_spec, + ) + + class _FakeCollator: + def __init__(self, **kw): + captured["kwargs"] = kw + + monkeypatch.setattr( + "axolotl.core.builders.causal.MultiModalPretrainDataCollator", _FakeCollator + ) + + builder = HFCausalTrainerBuilder.__new__(HFCausalTrainerBuilder) + builder.tokenizer = object() + builder.processor = object() + builder.cfg = DictDefault( + { + "pretraining_dataset": [ + { + "type": "multimodal_pretrain", + "image_base_dir": "/train_images", + "image_token": "", + } + ], + "test_datasets": [ + { + "type": "multimodal_pretrain", + "image_base_dir": "/eval_images", + "image_token": "", + } + ], + "sequence_len": 2048, + } + ) + + builder._build_mm_pretrain_collator(is_eval=True) + assert captured["override"] == "" + assert captured["kwargs"]["image_base_dir"] == "/eval_images" + + captured.clear() + builder._build_mm_pretrain_collator(is_eval=False) + assert captured["override"] == "" + assert captured["kwargs"]["image_base_dir"] == "/train_images" diff --git a/tests/utils/schemas/validation/test_multimodal_cpt.py b/tests/utils/schemas/validation/test_multimodal_cpt.py new file mode 100644 index 0000000000..216b78f47a --- /dev/null +++ b/tests/utils/schemas/validation/test_multimodal_cpt.py @@ -0,0 +1,270 @@ +"""Multimodal CPT config validation gates.""" + +from __future__ import annotations + +import logging + +import pytest + +from axolotl.utils.config import validate_config +from axolotl.utils.dict import DictDefault + + +def _mm_cpt_cfg(min_base_cfg, **overrides) -> DictDefault: + base = DictDefault( + **( + min_base_cfg + | { + "datasets": None, + "pretraining_dataset": [ + { + "path": "some/ds", + "type": "multimodal_pretrain", + "image_column": "images", + } + ], + "streaming": True, + "max_steps": 10, + "processor_type": "AutoProcessor", + "sequence_len": 2048, + } + ) + ) + return base | DictDefault(overrides) + + +class TestMultimodalCPTGates: + def test_missing_processor_type_raises(self, min_base_cfg): + cfg = _mm_cpt_cfg(min_base_cfg) + cfg.pop("processor_type", None) + with pytest.raises(ValueError, match="processor_type"): + validate_config(cfg) + + def test_sample_packing_rejected(self, min_base_cfg): + cfg = _mm_cpt_cfg(min_base_cfg, sample_packing=True) + with pytest.raises(ValueError, match="sample_packing"): + validate_config(cfg) + + def test_chat_template_rejected(self, min_base_cfg): + cfg = _mm_cpt_cfg(min_base_cfg, chat_template="tokenizer_default") + with pytest.raises(ValueError, match="chat_template"): + validate_config(cfg) + + def test_multiple_pretraining_dataset_entries_rejected(self, min_base_cfg): + cfg = _mm_cpt_cfg(min_base_cfg) + cfg.pretraining_dataset.append({"path": "other/ds", "type": "pretrain"}) + with pytest.raises(ValueError, match="exactly one `pretraining_dataset`"): + validate_config(cfg) + + def test_multimodal_entry_in_non_first_slot_rejected(self, min_base_cfg): + cfg = DictDefault( + **( + min_base_cfg + | { + "datasets": None, + "pretraining_dataset": [ + {"path": "text/ds", "type": "pretrain"}, + { + "path": "mm/ds", + "type": "multimodal_pretrain", + "image_column": "images", + }, + ], + "streaming": True, + "max_steps": 10, + "processor_type": "AutoProcessor", + "sequence_len": 2048, + } + ) + ) + with pytest.raises(ValueError, match="exactly one `pretraining_dataset`"): + validate_config(cfg) + + def test_valid_cfg_passes_and_disables_remove_unused_columns(self, min_base_cfg): + cfg = _mm_cpt_cfg(min_base_cfg) + validated = validate_config(cfg) + assert validated.remove_unused_columns is False + pd = validated.pretraining_dataset[0] + assert pd.type == "multimodal_pretrain" + assert pd.image_column == "images" + + def test_multimodal_flag_triggers_gates(self, min_base_cfg): + cfg = _mm_cpt_cfg(min_base_cfg) + cfg.pretraining_dataset[0]["type"] = "pretrain" + cfg.pretraining_dataset[0]["multimodal"] = True + cfg.pop("processor_type", None) + with pytest.raises(ValueError, match="processor_type"): + validate_config(cfg) + + def test_non_mm_pretraining_dataset_unaffected(self, min_base_cfg): + cfg = DictDefault( + **( + min_base_cfg + | { + "datasets": None, + "pretraining_dataset": [{"path": "some/ds", "type": "pretrain"}], + "streaming": True, + "max_steps": 10, + "sequence_len": 2048, + } + ) + ) + validate_config(cfg) + + def test_mm_eval_dataset_keys_preserved_through_validation(self, min_base_cfg): + """MM-specific keys on a test_datasets entry survive validate_config.""" + cfg = _mm_cpt_cfg( + min_base_cfg, + test_datasets=[ + { + "path": "eval/ds", + "type": "multimodal_pretrain", + "text_column": "eval_text", + "image_column": "eval_imgs", + "image_base_dir": "/eval/images", + "image_token": "", + } + ], + ) + validated = validate_config(cfg) + td = validated.test_datasets[0] + assert td["text_column"] == "eval_text" + assert td["image_column"] == "eval_imgs" + assert td["image_base_dir"] == "/eval/images" + assert td["image_token"] == "" + + def test_mm_eval_dataset_via_multimodal_flag(self, min_base_cfg): + """`multimodal: true` (without type='multimodal_pretrain') opts an eval entry into MM.""" + cfg = _mm_cpt_cfg( + min_base_cfg, + test_datasets=[ + { + "path": "eval/ds", + "multimodal": True, + "image_column": "imgs2", + } + ], + ) + validated = validate_config(cfg) + td = validated.test_datasets[0] + assert td["image_column"] == "imgs2" + assert td["multimodal"] is True + + def test_non_mm_eval_entry_does_not_match_mm_model(self, min_base_cfg): + """SFT eval entries (no MM markers) still validate as SFTDataset.""" + cfg = DictDefault( + **( + min_base_cfg + | { + "test_datasets": [ + {"path": "eval/ds", "type": "alpaca", "split": "test"} + ], + "sequence_len": 2048, + } + ) + ) + validated = validate_config(cfg) + td = validated.test_datasets[0] + assert "message_property_mappings" in td + assert td["type"] == "alpaca" + + def test_mm_eval_rejects_mismatched_image_base_dir(self, min_base_cfg): + """Multiple MM eval entries with different image_base_dir are rejected.""" + cfg = _mm_cpt_cfg( + min_base_cfg, + test_datasets=[ + { + "path": "eval/a", + "type": "multimodal_pretrain", + "image_base_dir": "/images/a", + }, + { + "path": "eval/b", + "type": "multimodal_pretrain", + "image_base_dir": "/images/b", + }, + ], + ) + with pytest.raises(ValueError, match="image_base_dir"): + validate_config(cfg) + + def test_mm_eval_rejects_mismatched_image_token(self, min_base_cfg): + """Multiple MM eval entries with different image_token overrides are rejected.""" + cfg = _mm_cpt_cfg( + min_base_cfg, + test_datasets=[ + { + "path": "eval/a", + "type": "multimodal_pretrain", + "image_token": "", + }, + { + "path": "eval/b", + "type": "multimodal_pretrain", + "image_token": "", + }, + ], + ) + with pytest.raises(ValueError, match="image_token"): + validate_config(cfg) + + def test_mm_eval_accepts_matching_image_base_dir(self, min_base_cfg): + """Multiple MM eval entries sharing image_base_dir validate cleanly.""" + cfg = _mm_cpt_cfg( + min_base_cfg, + test_datasets=[ + { + "path": "eval/a", + "type": "multimodal_pretrain", + "image_base_dir": "/images/shared", + }, + { + "path": "eval/b", + "type": "multimodal_pretrain", + "image_base_dir": "/images/shared", + }, + ], + ) + validated = validate_config(cfg) + assert len(validated.test_datasets) == 2 + + def test_mm_eval_accepts_all_unset_image_settings(self, min_base_cfg): + """Multiple MM eval entries with image_base_dir / image_token unset everywhere validate.""" + cfg = _mm_cpt_cfg( + min_base_cfg, + test_datasets=[ + {"path": "eval/a", "type": "multimodal_pretrain"}, + {"path": "eval/b", "type": "multimodal_pretrain"}, + ], + ) + validated = validate_config(cfg) + assert len(validated.test_datasets) == 2 + + def test_remove_unused_columns_auto_set_emits_info_log(self, min_base_cfg, caplog): + """Auto-setting `remove_unused_columns: false` for MM CPT logs an INFO record naming the previous value.""" + cfg = _mm_cpt_cfg(min_base_cfg) + cfg.pop("remove_unused_columns", None) + with caplog.at_level(logging.INFO, logger="axolotl.utils.schemas.validation"): + validated = validate_config(cfg) + assert validated.remove_unused_columns is False + matches = [ + r + for r in caplog.records + if r.levelno == logging.INFO and "Auto-set" in r.getMessage() + ] + assert matches, "expected an INFO record about auto-setting remove_unused_columns" + msg = matches[0].getMessage() + assert "remove_unused_columns" in msg + assert "previous value: None" in msg + + def test_remove_unused_columns_already_false_does_not_log( + self, min_base_cfg, caplog + ): + """When the user already set `remove_unused_columns: false`, no auto-set log fires.""" + cfg = _mm_cpt_cfg(min_base_cfg, remove_unused_columns=False) + with caplog.at_level(logging.INFO, logger="axolotl.utils.schemas.validation"): + validate_config(cfg) + assert not any( + "Auto-set" in r.getMessage() and "remove_unused_columns" in r.getMessage() + for r in caplog.records + ) From c8f0380e9ea18f8872c70843891ade1fa55ba9ff Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Sat, 25 Apr 2026 15:13:48 -0700 Subject: [PATCH 02/14] fix: address coderabbit feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses CodeRabbit review on PR #3629. No behavior change for the happy path; expands schema, hardens fallbacks, tightens validation. Bug fixes --------- - Gemma-3 autodetect: prefer `processor.boi_token` over `image_token` when they differ (Gemma-3's `image_token` is the post-expansion soft token, not the user-facing placeholder). Without this, MM CPT crashed on the first batch with "Prompt contained 0 image tokens". - `dispatch_batches: true` placeholder dataset now mirrors the configured `image_column` so worker ranks don't KeyError on empty rows. - `tokenize_prompt` rejects falsy non-None image cells (`""`, `0`, `False`) instead of coercing to `[]` — keeps malformed rows from silently turning into text-only samples. Schema completeness ------------------- - Add `ds_type` to `PretrainingDataset` and `MultiModalEvalDataset` (the documented `ds_type: json` shape now actually reaches `load_dataset`; previously dropped at validation). - Preserve `trust_remote_code` through `_pretraining_config_from_entry` and pass it to `load_dataset` (was silently dropped). - Honor `cfg.eval_sequence_len` in MM CPT eval streams (encoder + collator) with documented fallback to `cfg.sequence_len` when unset. Validation tightening (config-load time) ---------------------------------------- - Reject mixed multimodal/text entries in `test_datasets`. - Reject MM `test_datasets` paired with non-MM training. - Reject non-MM `test_datasets` paired with MM training. - The redundant runtime check in `sft.py` is removed; schema is the single source of truth. Hardening / observability ------------------------- - Mixed/all-text batch handling: collator routes all-text batches to the tokenizer (no `pixel_values`); mixed batches go through the processor as-is. Documented per-VLM compatibility (verified on SmolVLM/SmolVLM2, Gemma-3, Gemma-4, Qwen2.5-VL, Qwen3-VL). - Reject cloud/object-store URIs (`s3://`, `gs://`, `gcs://`, `az://`, `azure://`, `hf://`) in image paths so users see "Non-local scheme" instead of a confusing FileNotFoundError. - Warn when `MultiModalPretrainDataCollator.tokenizer is not processor.tokenizer` (all-text vs image batches could otherwise tokenize the same text differently). - Warn at retry kickoff when a processor call fails on a batch, so users see why processing stalls during per-row diagnosis. - INFO log when `remove_unused_columns` is auto-set to `false` for MM CPT. - DEBUG log when `tokenizer.get_added_vocab()` fails (was silent pass). - Clarify "exceeds sequence_len" error to note image-patch expansion may push the final length higher. Tests ----- +8 regression tests across the four MM CPT suites covering: Gemma-3 boi_token autodetection, eval_sequence_len (encoder + collator, including the fallback case), `trust_remote_code` and `ds_type` preservation through validation, three modality-mismatch validation cases, tokenizer-mismatch warning, cloud-URI rejection. 68 tests pass across `tests/test_multimodal_streaming.py`, `tests/prompt_strategies/test_multimodal_pretrain.py`, `tests/utils/schemas/validation/test_multimodal_cpt.py`, `tests/utils/data/test_mm_cpt_eval.py`. Lint clean against ruff v0.15.8 (upstream pre-commit pin). --- docs/multimodal.qmd | 17 ++ src/axolotl/core/builders/causal.py | 7 +- .../prompt_strategies/multimodal_pretrain.py | 25 ++- src/axolotl/utils/collators/mm_pretrain.py | 33 +++- src/axolotl/utils/data/sft.py | 63 ++++--- src/axolotl/utils/data/streaming.py | 11 +- src/axolotl/utils/schemas/datasets.py | 12 ++ src/axolotl/utils/schemas/validation.py | 42 ++++- .../test_multimodal_pretrain.py | 15 ++ tests/test_multimodal_streaming.py | 136 ++++++++++++++ tests/utils/data/test_mm_cpt_eval.py | 166 +++++++++++++++--- .../schemas/validation/test_multimodal_cpt.py | 58 +++++- 12 files changed, 517 insertions(+), 68 deletions(-) diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd index b4d7740dc3..8d56b6a4a5 100644 --- a/docs/multimodal.qmd +++ b/docs/multimodal.qmd @@ -438,6 +438,23 @@ independently. When more than one multimodal entry is provided, identical across them, because the eval collator resolves both once for the merged eval stream. +### Mixed image / text-only batches + +Rows with no images are allowed. All-text batches bypass the processor and +tokenize via the base tokenizer (no `pixel_values`). Mixed batches are +handed to the processor as-is. + +Verified VLM families: + +- HuggingFace `SmolVLM` / `SmolVLM2` +- Google `Gemma-3` (4B), `Gemma-4` (E2B, 31B) +- Alibaba `Qwen2.5-VL`, `Qwen3-VL` + +For other families (LLaVA, Qwen2-VL, Idefics3), the collator's per-row +retry isolates and names the offending row on processor failure — failures +are loud, not silent. Smoke-test before long runs, or pre-filter rows +without images. + ### Gates and rejections The following combinations are rejected at config-load time with a clear error: diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index b9a6009ad7..7472ac649d 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -485,12 +485,17 @@ def _build_mm_pretrain_collator(self, pad_to_multiple_of=None, is_eval=False): spec = build_image_token_spec( self.processor, override=_mm_cpt_get(pt_cfg, "image_token") ) + max_length = ( + self.cfg.eval_sequence_len + if is_eval and self.cfg.eval_sequence_len + else self.cfg.sequence_len + ) collator_kwargs = { "tokenizer": self.tokenizer, "processor": self.processor, "image_token_spec": spec, "image_base_dir": _mm_cpt_get(pt_cfg, "image_base_dir"), - "max_length": self.cfg.sequence_len, + "max_length": max_length, } if pad_to_multiple_of is not None: collator_kwargs["pad_to_multiple_of"] = pad_to_multiple_of diff --git a/src/axolotl/prompt_strategies/multimodal_pretrain.py b/src/axolotl/prompt_strategies/multimodal_pretrain.py index c716c63aa3..4e593962c4 100644 --- a/src/axolotl/prompt_strategies/multimodal_pretrain.py +++ b/src/axolotl/prompt_strategies/multimodal_pretrain.py @@ -17,6 +17,8 @@ def _get_incompatible_processor_classes() -> tuple[type, ...]: + import importlib + classes: list[type] = [] for mod_path, name in ( ("transformers.models.mllama", "MllamaProcessor"), @@ -24,8 +26,6 @@ def _get_incompatible_processor_classes() -> tuple[type, ...]: ("transformers.models.internvl", "InternVLProcessor"), ): try: - import importlib - mod = importlib.import_module(mod_path) cls = getattr(mod, name, None) if cls is not None: @@ -108,8 +108,12 @@ def resolve_id(tok: str) -> int | None: known_special_tokens: set[str] = set() try: known_special_tokens |= set(tokenizer.get_added_vocab().keys()) - except Exception: - pass + except Exception as exc: # noqa: BLE001 + LOG.debug( + "tokenizer.get_added_vocab() failed on %s: %s", + type(tokenizer).__name__, + exc, + ) known_special_tokens |= set(getattr(tokenizer, "all_special_tokens", None) or []) known_special_tokens |= set( getattr(tokenizer, "additional_special_tokens", None) or [] @@ -135,6 +139,16 @@ def resolve_id(tok: str) -> int | None: image_token = override else: proc_token = getattr(processor, "image_token", None) + # Gemma-3-style: `image_token` is the post-expansion soft token; the + # user-facing placeholder is `boi_token`. + boi_token = getattr(processor, "boi_token", None) + if ( + boi_token + and proc_token + and boi_token != proc_token + and boi_token in known_special_tokens + ): + proc_token = boi_token if proc_token is not None: image_token_id = resolve_id(proc_token) if image_token_id is not None: @@ -245,7 +259,8 @@ def tokenize_prompt(self, prompt: dict[str, Any]) -> dict[str, list]: raise ValueError( f"Multimodal CPT row tokenizes to {len(ids)} tokens which " f"exceeds sequence_len={self.max_length}. Pre-chunk your text " - f"or raise sequence_len." + f"or raise sequence_len (image patch expansion at the " + f"processor may push the final length even higher)." ) n_chunks = len(res["input_ids"]) diff --git a/src/axolotl/utils/collators/mm_pretrain.py b/src/axolotl/utils/collators/mm_pretrain.py index f6870434f4..47a9080b95 100644 --- a/src/axolotl/utils/collators/mm_pretrain.py +++ b/src/axolotl/utils/collators/mm_pretrain.py @@ -49,6 +49,16 @@ def __post_init__(self) -> None: "return_tensors='pt' (in-place torch ops are used downstream)." ) check_processor_compatibility(self.processor) + # All-text batches use self.tokenizer; image batches use self.processor. + # If they don't share the same tokenizer instance, the two paths can + # tokenize the same text differently. + proc_tokenizer = getattr(self.processor, "tokenizer", None) + if proc_tokenizer is not None and proc_tokenizer is not self.tokenizer: + LOG.warning( + "MultiModalPretrainDataCollator.tokenizer is not " + "processor.tokenizer; all-text and image batches may " + "tokenize inconsistently." + ) self._image_family_token_ids = set(self.image_token_spec.image_family_token_ids) if self.image_base_dir is not None: self._base_dir_real = os.path.realpath(self.image_base_dir) @@ -60,7 +70,20 @@ def _resolve_image_path(self, p: str) -> str: raise ValueError("Image path contains embedded NUL byte.") p_lower = p.lower() if p_lower.startswith( - ("http://", "https://", "ftp://", "ftps://", "file://", "data:") + ( + "http://", + "https://", + "ftp://", + "ftps://", + "file://", + "data:", + "s3://", + "gs://", + "gcs://", + "az://", + "azure://", + "hf://", + ) ) or p.startswith(("\\\\", "//")): raise ValueError( f"Non-local image path scheme is not supported in v1 " @@ -246,6 +269,14 @@ def torch_call(self, examples: list[dict]) -> dict[str, Any]: batch = self.processor(**proc_kwargs) except Exception as exc: # Pinpoint the bad row; bail to "inconclusive" if retry raises a different class. + LOG.warning( + "MultiModalPretrainDataCollator: processor failed on a batch " + "of %d rows (%s); retrying each row individually to locate " + "the offender. This adds up to %d extra processor calls.", + len(texts), + type(exc).__name__, + len(texts), + ) offender_idx: Optional[int] = None retry_ok = True retry_kwargs: dict[str, Any] = { diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 809d08ac5f..51d9db9d28 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -171,24 +171,18 @@ def _prepare_streaming_dataset( for t in test_dicts ) if is_mm_cpt_eval: - eval_streams = [] - for entry in test_dicts: - if not ( - entry.get("type") == "multimodal_pretrain" - or bool(entry.get("multimodal")) - ): - raise ValueError( - "Mixing multimodal and non-multimodal entries in " - "`test_datasets` is not supported. All eval entries " - "must be MM (type: multimodal_pretrain or " - "multimodal: true) when training is MM CPT." - ) - eval_config = _pretraining_config_from_entry(entry) - eval_streams.append( - _load_streaming_dataset( - eval_config, cfg, tokenizer, processor=processor - ) + # Modality homogeneity is enforced by check_multimodal_cpt at config + # parse time; every entry here is guaranteed to be MM. + eval_streams = [ + _load_streaming_dataset( + _pretraining_config_from_entry(entry), + cfg, + tokenizer, + processor=processor, + is_eval=True, ) + for entry in test_dicts + ] eval_dataset = ( eval_streams[0] if len(eval_streams) == 1 @@ -216,12 +210,14 @@ def _pretraining_config_from_entry(entry: dict) -> DictDefault: "skip": entry.get("skip"), "split": entry.get("split", "train"), "data_files": entry.get("data_files"), + "ds_type": entry.get("ds_type"), "type": entry.get("type", "pretrain"), "text_column": entry.get("text_column", "text"), "multimodal": entry.get("multimodal"), "image_column": entry.get("image_column", "images"), "image_base_dir": entry.get("image_base_dir"), "image_token": entry.get("image_token"), + "trust_remote_code": entry.get("trust_remote_code", False), } ) @@ -240,12 +236,14 @@ def _extract_pretraining_config(cfg: DictDefault) -> DictDefault: "skip": 0, "split": "train", "data_files": None, + "ds_type": None, "type": "pretrain", "text_column": "text", "multimodal": None, "image_column": "images", "image_base_dir": None, "image_token": None, # nosec + "trust_remote_code": False, } ) @@ -255,6 +253,7 @@ def _load_streaming_dataset( cfg: DictDefault, tokenizer: PreTrainedTokenizer, processor: ProcessorMixin | None = None, + is_eval: bool = False, ) -> IterableDataset: """Load and prepare a streaming dataset for pretraining.""" # Create dataset wrapper partial function @@ -275,13 +274,28 @@ def _load_streaming_dataset( ): iter_dataset = _create_placeholder_dataset(pretraining_config) else: - iter_dataset = load_dataset( - pretraining_config["path"], - streaming=True, - split=pretraining_config["split"], - name=pretraining_config["name"], - data_files=pretraining_config["data_files"], - ) + ds_type = pretraining_config.get("ds_type") + if ds_type: + # ds_type names the loader (e.g. 'json'); path is the data_files glob. + iter_dataset = load_dataset( + ds_type, + streaming=True, + split=pretraining_config["split"], + name=pretraining_config["name"], + data_files=( + pretraining_config["data_files"] or pretraining_config["path"] + ), + trust_remote_code=pretraining_config.get("trust_remote_code", False), + ) + else: + iter_dataset = load_dataset( + pretraining_config["path"], + streaming=True, + split=pretraining_config["split"], + name=pretraining_config["name"], + data_files=pretraining_config["data_files"], + trust_remote_code=pretraining_config.get("trust_remote_code", False), + ) # Apply skip if specified if pretraining_config["skip"]: @@ -296,6 +310,7 @@ def _load_streaming_dataset( dataset_wrapper_partial, processor=processor, pretraining_config=pretraining_config, + is_eval=is_eval, ) # Format for PyTorch diff --git a/src/axolotl/utils/data/streaming.py b/src/axolotl/utils/data/streaming.py index 966fa65719..c1e1d3be63 100644 --- a/src/axolotl/utils/data/streaming.py +++ b/src/axolotl/utils/data/streaming.py @@ -263,7 +263,14 @@ def wrap_streaming_dataset( ds_wrapper_fn, processor: Optional[ProcessorMixin] = None, pretraining_config=None, + is_eval: bool = False, ): + # Eval streams honor cfg.eval_sequence_len when set, else cfg.sequence_len. + effective_seq_len = ( + cfg.eval_sequence_len + if is_eval and getattr(cfg, "eval_sequence_len", None) + else cfg.sequence_len + ) if cfg.sample_packing: # For SFT (non-pretraining) datasets, always use multipack_attn=True to ensure # attention isolation between packed sequences @@ -340,7 +347,7 @@ def wrap_streaming_dataset( encode = functools.partial( encode_streaming_multimodal, tokenizer=tokenizer, - max_tokens=cfg.sequence_len, + max_tokens=effective_seq_len, image_token=spec.image_token, image_token_id=spec.image_token_id, text_column=text_column, @@ -350,7 +357,7 @@ def wrap_streaming_dataset( encode = functools.partial( encode_streaming, tokenizer=tokenizer, - max_tokens=cfg.sequence_len, + max_tokens=effective_seq_len, text_column=text_column, concatenate=cfg.pretraining_sample_concatenation is True, ) diff --git a/src/axolotl/utils/schemas/datasets.py b/src/axolotl/utils/schemas/datasets.py index 6c55a6a9d9..70dfda799c 100644 --- a/src/axolotl/utils/schemas/datasets.py +++ b/src/axolotl/utils/schemas/datasets.py @@ -236,6 +236,12 @@ class PretrainingDataset(BaseModel): type: str | None = "pretrain" trust_remote_code: bool | None = False data_files: str | None = None + ds_type: str | None = Field( + default=None, + json_schema_extra={ + "description": "Dataset loader type when `path` points to local files (e.g. 'json', 'csv', 'parquet')." + }, + ) skip: int | None = None # Multimodal CPT fields. Opt-in via `type: multimodal_pretrain` or `multimodal: true`. @@ -275,6 +281,12 @@ class MultiModalEvalDataset(BaseModel): name: str | None = None split: str | None = "train" data_files: str | list[str] | None = None + ds_type: str | None = Field( + default=None, + json_schema_extra={ + "description": "Dataset loader type when `path` points to local files (e.g. 'json', 'csv', 'parquet')." + }, + ) skip: int | None = None type: str | None = None trust_remote_code: bool | None = False diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 6f78e31b63..8b2c395314 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1354,12 +1354,6 @@ def check_streaming_w_multiple_datasets(cls, data): @model_validator(mode="before") @classmethod def check_multimodal_cpt(cls, data): - pd = data.get("pretraining_dataset") - if not pd: - return data - - pd_list = pd if isinstance(pd, list) else [pd] - def _entry_is_mm(entry) -> bool: if isinstance(entry, dict): ds_type_ = entry.get("type") @@ -1369,6 +1363,38 @@ def _entry_is_mm(entry) -> bool: mm_flag_ = getattr(entry, "multimodal", None) return ds_type_ == "multimodal_pretrain" or bool(mm_flag_) + pd = data.get("pretraining_dataset") + pd_list = pd if isinstance(pd, list) else ([pd] if pd else []) + train_is_mm = ( + bool(pd_list) and isinstance(pd_list[0], dict) and _entry_is_mm(pd_list[0]) + ) + + test_datasets = data.get("test_datasets") or [] + test_dicts = [t for t in test_datasets if isinstance(t, dict)] + mm_flags = [_entry_is_mm(t) for t in test_dicts] + if mm_flags: + if any(mm_flags) != all(mm_flags): + raise ValueError( + "Mixing multimodal and non-multimodal entries in " + "`test_datasets` is not supported. All eval entries " + "must share modality." + ) + if all(mm_flags) and not train_is_mm: + raise ValueError( + "Multimodal `test_datasets` require multimodal CPT " + "training (set `pretraining_dataset[0].type` to " + "'multimodal_pretrain' or `multimodal: true`)." + ) + if not any(mm_flags) and train_is_mm: + raise ValueError( + "Multimodal CPT training requires multimodal " + "`test_datasets` entries (type='multimodal_pretrain' " + "or multimodal: true)." + ) + + if not pd_list: + return data + # MM config resolves from entry[0] only; multi-entry runs miscollate or silently demote. if len(pd_list) > 1 and any(_entry_is_mm(e) for e in pd_list): raise ValueError( @@ -1384,9 +1410,7 @@ def _entry_is_mm(entry) -> bool: if not isinstance(first, dict): return data - ds_type = first.get("type") - is_mm_cpt = ds_type == "multimodal_pretrain" or bool(first.get("multimodal")) - if not is_mm_cpt: + if not train_is_mm: return data if not data.get("processor_type"): diff --git a/tests/prompt_strategies/test_multimodal_pretrain.py b/tests/prompt_strategies/test_multimodal_pretrain.py index e20dd9a4ae..8040472363 100644 --- a/tests/prompt_strategies/test_multimodal_pretrain.py +++ b/tests/prompt_strategies/test_multimodal_pretrain.py @@ -69,6 +69,21 @@ def test_build_image_token_spec_rejects_plain_word_override(smolvlm_processor): build_image_token_spec(smolvlm_processor, override="image") +def test_build_image_token_spec_prefers_boi_token_over_expansion_token( + smolvlm_processor, +): + """Gemma-3-style autodetect: `boi_token` is preferred over `image_token` + when they differ.""" + + class _FakeGemma3Like: + image_token = "" + boi_token = "" + tokenizer = smolvlm_processor.tokenizer + + spec = build_image_token_spec(_FakeGemma3Like()) + assert spec.image_token == "" + + # ---- check_processor_compatibility (startup-time gate) --------------------- diff --git a/tests/test_multimodal_streaming.py b/tests/test_multimodal_streaming.py index 9fc287acdf..69d54b4412 100644 --- a/tests/test_multimodal_streaming.py +++ b/tests/test_multimodal_streaming.py @@ -206,6 +206,97 @@ def map(self, *_args, **_kwargs): assert captured["kwargs"]["image_column"] == "eval_imgs" +def test_wrap_streaming_dataset_eval_honors_eval_sequence_len( + smolvlm_processor, monkeypatch +): + """is_eval=True with cfg.eval_sequence_len set caps encoder at eval_sequence_len.""" + captured = {} + + def fake_partial(fn, **kwargs): + captured["encode_fn"] = fn + captured["kwargs"] = kwargs + return lambda batch: batch + + monkeypatch.setattr("axolotl.utils.data.streaming.functools.partial", fake_partial) + + class _Dataset: + features = {"text": None, "images": None} + + def shuffle(self, **_): + return self + + def map(self, *_args, **_kwargs): + return self + + cfg = DictDefault( + { + "sample_packing": False, + "pretraining_dataset": [ + {"path": "train/ds", "type": "multimodal_pretrain"} + ], + "sequence_len": 4096, + "eval_sequence_len": 1024, + "shuffle_merged_datasets": False, + "streaming_multipack_buffer_size": 1000, + "seed": 42, + } + ) + + wrap_streaming_dataset( + _Dataset(), + smolvlm_processor.tokenizer, + cfg, + ds_wrapper_fn=None, + processor=smolvlm_processor, + pretraining_config=DictDefault( + {"path": "test/ds", "type": "multimodal_pretrain"} + ), + is_eval=True, + ) + assert captured["kwargs"]["max_tokens"] == 1024 + + captured.clear() + wrap_streaming_dataset( + _Dataset(), + smolvlm_processor.tokenizer, + cfg, + ds_wrapper_fn=None, + processor=smolvlm_processor, + pretraining_config=DictDefault( + {"path": "train/ds", "type": "multimodal_pretrain"} + ), + is_eval=False, + ) + assert captured["kwargs"]["max_tokens"] == 4096 + + # eval_sequence_len unset -> eval falls back to sequence_len. + captured.clear() + cfg_no_eval = DictDefault( + { + "sample_packing": False, + "pretraining_dataset": [ + {"path": "train/ds", "type": "multimodal_pretrain"} + ], + "sequence_len": 4096, + "shuffle_merged_datasets": False, + "streaming_multipack_buffer_size": 1000, + "seed": 42, + } + ) + wrap_streaming_dataset( + _Dataset(), + smolvlm_processor.tokenizer, + cfg_no_eval, + ds_wrapper_fn=None, + processor=smolvlm_processor, + pretraining_config=DictDefault( + {"path": "test/ds", "type": "multimodal_pretrain"} + ), + is_eval=True, + ) + assert captured["kwargs"]["max_tokens"] == 4096 + + # ---- MultiModalPretrainDataCollator --------------------------------------- @@ -303,11 +394,20 @@ def test_collator_rejects_remote_urls(smolvlm_processor): "file:///etc/passwd", "ftp://x/y.png", "data:image/png;base64,xxx", + # Cloud / object-store / hub URIs. + "s3://bucket/key.png", + "gs://bucket/key.png", + "gcs://bucket/key.png", + "az://container/key.png", + "azure://account/container/key.png", + "hf://datasets/foo/bar/img.png", # Case-variant bypass attempts. "HTTP://evil.com/x.png", "Https://x/y.jpg", "FILE:///etc/passwd", "DATA:image/png;base64,xxx", + "S3://bucket/key.png", + "GS://bucket/key.png", ): with pytest.raises(RuntimeError) as exc: collator._load_images_for_row([url], row_index=0) @@ -393,6 +493,42 @@ def test_collator_rejects_too_many_images(smolvlm_processor, two_tiny_images): # ---- mixed / all-text batches -------------------------------------------- +def test_collator_warns_when_tokenizer_diverges_from_processor_tokenizer( + smolvlm_processor, caplog +): + """Construct-time warning when self.tokenizer is not processor.tokenizer.""" + import logging as _logging + + spec = build_image_token_spec(smolvlm_processor) + + # Same tokenizer: no warning. + with caplog.at_level( + _logging.WARNING, logger="axolotl.utils.collators.mm_pretrain" + ): + MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + assert not any("tokenize inconsistently" in r.getMessage() for r in caplog.records) + + caplog.clear() + + # Different tokenizer instance (a stand-in object): warning fires. + class _OtherTokenizer: + pad_token_id = None + + with caplog.at_level( + _logging.WARNING, logger="axolotl.utils.collators.mm_pretrain" + ): + MultiModalPretrainDataCollator( + tokenizer=_OtherTokenizer(), + processor=smolvlm_processor, + image_token_spec=spec, + ) + assert any("tokenize inconsistently" in r.getMessage() for r in caplog.records) + + def test_collator_all_text_batch_uses_tokenizer_fallback(smolvlm_processor): """A batch where every row has images=[] tokenizes via the tokenizer; no pixel_values.""" spec = build_image_token_spec(smolvlm_processor) diff --git a/tests/utils/data/test_mm_cpt_eval.py b/tests/utils/data/test_mm_cpt_eval.py index 57ff10897b..73aa1015ec 100644 --- a/tests/utils/data/test_mm_cpt_eval.py +++ b/tests/utils/data/test_mm_cpt_eval.py @@ -2,8 +2,6 @@ from __future__ import annotations -import pytest - from axolotl.utils.data.sft import ( _create_placeholder_dataset, _prepare_streaming_dataset, @@ -54,6 +52,88 @@ def test_placeholder_mm_honors_custom_columns(): assert row["imgs"] == [] +def test_pretraining_config_from_entry_preserves_trust_remote_code(): + """trust_remote_code on the dataset entry survives normalization.""" + from axolotl.utils.data.sft import _pretraining_config_from_entry + + cfg = _pretraining_config_from_entry( + {"path": "ds", "type": "multimodal_pretrain", "trust_remote_code": True} + ) + assert cfg["trust_remote_code"] is True + + cfg = _pretraining_config_from_entry({"path": "ds", "type": "multimodal_pretrain"}) + assert cfg["trust_remote_code"] is False + + +def test_pretraining_config_from_entry_preserves_ds_type(): + """ds_type on the dataset entry survives normalization.""" + from axolotl.utils.data.sft import _pretraining_config_from_entry + + cfg = _pretraining_config_from_entry( + {"path": "/data/*.jsonl", "type": "multimodal_pretrain", "ds_type": "json"} + ) + assert cfg["ds_type"] == "json" + + cfg = _pretraining_config_from_entry({"path": "ds", "type": "multimodal_pretrain"}) + assert cfg["ds_type"] is None + + +def test_load_streaming_dataset_routes_ds_type_to_loader(monkeypatch): + """When ds_type is set, load_dataset is called with the loader name and + path becomes data_files.""" + from axolotl.utils.data.sft import _load_streaming_dataset + + captured = {} + + def fake_load_dataset(*args, **kwargs): + captured["args"] = args + captured["kwargs"] = kwargs + + class _Stub: + def skip(self, *_a, **_kw): + return self + + return _Stub() + + def fake_wrap(ds, *_a, **_kw): + return ds + + class _StubFormat: + def with_format(self, *_a, **_kw): + return self + + monkeypatch.setattr("axolotl.utils.data.sft.load_dataset", fake_load_dataset) + monkeypatch.setattr( + "axolotl.utils.data.sft.wrap_streaming_dataset", + lambda *a, **kw: _StubFormat(), + ) + + pretraining_config = DictDefault( + { + "path": "/data/shards/*.jsonl", + "name": None, + "skip": 0, + "split": "train", + "data_files": None, + "ds_type": "json", + "type": "multimodal_pretrain", + "text_column": "text", + "multimodal": True, + "image_column": "images", + "image_base_dir": None, + "image_token": None, + "trust_remote_code": False, + } + ) + cfg = DictDefault({"sequence_len": 2048, "accelerator_config": None}) + + _load_streaming_dataset(pretraining_config, cfg, tokenizer=None, processor=None) + + assert captured["args"] == ("json",) + assert captured["kwargs"]["data_files"] == "/data/shards/*.jsonl" + assert captured["kwargs"]["split"] == "train" + + # ---- multiple MM eval datasets are loaded -------------------------------- @@ -90,7 +170,7 @@ def fake_concat(streams): ) monkeypatch.setattr("axolotl.utils.data.sft.concatenate_datasets", fake_concat) - train, eval_ds, _, _ = _prepare_streaming_dataset( + _train, eval_ds, _, _ = _prepare_streaming_dataset( cfg, tokenizer=None, processor=None ) @@ -98,28 +178,8 @@ def fake_concat(streams): assert eval_ds == ("", "", "") -def test_mm_eval_rejects_mixed_mm_and_non_mm_test_datasets(monkeypatch): - """MM CPT runs require every test_datasets entry to be MM; mixed lists raise.""" - cfg = DictDefault( - { - "streaming": True, - "pretraining_dataset": [ - {"path": "train/ds", "type": "multimodal_pretrain"} - ], - "test_datasets": [ - {"path": "eval/a", "type": "multimodal_pretrain"}, - # Plain text eval entry — not allowed alongside MM eval. - {"path": "eval/b", "type": "pretrain"}, - ], - "max_steps": 10, - } - ) - monkeypatch.setattr( - "axolotl.utils.data.sft._load_streaming_dataset", - lambda *_a, **_kw: "", - ) - with pytest.raises(ValueError, match="multimodal and non-multimodal"): - _prepare_streaming_dataset(cfg, tokenizer=None, processor=None) +# Mixed MM / non-MM test_datasets is rejected at config-load time by +# check_multimodal_cpt (see tests/utils/schemas/validation/test_multimodal_cpt.py). # ---- eval collator pulls image settings from test_datasets --------------- @@ -184,3 +244,59 @@ def __init__(self, **kw): builder._build_mm_pretrain_collator(is_eval=False) assert captured["override"] == "" assert captured["kwargs"]["image_base_dir"] == "/train_images" + + +def test_eval_collator_honors_eval_sequence_len(monkeypatch): + """Eval collator uses cfg.eval_sequence_len when set; train collator uses cfg.sequence_len.""" + from axolotl.core.builders.causal import HFCausalTrainerBuilder + + captured = {} + + class _FakeSpec: + image_token = "" + image_token_id = 7 + image_family_token_ids = (7,) + + monkeypatch.setattr( + "axolotl.prompt_strategies.multimodal_pretrain.build_image_token_spec", + lambda processor, override=None: _FakeSpec(), + ) + + class _FakeCollator: + def __init__(self, **kw): + captured["kwargs"] = kw + + monkeypatch.setattr( + "axolotl.core.builders.causal.MultiModalPretrainDataCollator", _FakeCollator + ) + + builder = HFCausalTrainerBuilder.__new__(HFCausalTrainerBuilder) + builder.tokenizer = object() + builder.processor = object() + builder.cfg = DictDefault( + { + "pretraining_dataset": [{"type": "multimodal_pretrain"}], + "test_datasets": [{"type": "multimodal_pretrain"}], + "sequence_len": 4096, + "eval_sequence_len": 1024, + } + ) + + builder._build_mm_pretrain_collator(is_eval=True) + assert captured["kwargs"]["max_length"] == 1024 + + captured.clear() + builder._build_mm_pretrain_collator(is_eval=False) + assert captured["kwargs"]["max_length"] == 4096 + + # eval_sequence_len unset -> eval falls back to sequence_len + builder.cfg = DictDefault( + { + "pretraining_dataset": [{"type": "multimodal_pretrain"}], + "test_datasets": [{"type": "multimodal_pretrain"}], + "sequence_len": 4096, + } + ) + captured.clear() + builder._build_mm_pretrain_collator(is_eval=True) + assert captured["kwargs"]["max_length"] == 4096 diff --git a/tests/utils/schemas/validation/test_multimodal_cpt.py b/tests/utils/schemas/validation/test_multimodal_cpt.py index 216b78f47a..d6e8a88181 100644 --- a/tests/utils/schemas/validation/test_multimodal_cpt.py +++ b/tests/utils/schemas/validation/test_multimodal_cpt.py @@ -240,6 +240,60 @@ def test_mm_eval_accepts_all_unset_image_settings(self, min_base_cfg): validated = validate_config(cfg) assert len(validated.test_datasets) == 2 + def test_mixed_modality_test_datasets_rejected_at_validation(self, min_base_cfg): + """A test_datasets list mixing MM and non-MM entries fails at config-load.""" + cfg = _mm_cpt_cfg( + min_base_cfg, + test_datasets=[ + {"path": "eval/a", "type": "multimodal_pretrain"}, + {"path": "eval/b", "type": "alpaca", "split": "test"}, + ], + ) + with pytest.raises(ValueError) as exc: + validate_config(cfg) + msg = str(exc.value) + assert "Mixing multimodal and non-multimodal" in msg + assert "test_datasets" in msg + assert "share modality" in msg + + def test_mm_test_datasets_with_text_training_rejected(self, min_base_cfg): + """MM test_datasets paired with non-MM training fails at config-load.""" + cfg = DictDefault( + **( + min_base_cfg + | { + "datasets": None, + "pretraining_dataset": [{"path": "text/ds", "type": "pretrain"}], + "test_datasets": [ + {"path": "eval/a", "type": "multimodal_pretrain"} + ], + "streaming": True, + "max_steps": 10, + "sequence_len": 2048, + "processor_type": "AutoProcessor", + } + ) + ) + with pytest.raises(ValueError) as exc: + validate_config(cfg) + msg = str(exc.value) + assert "Multimodal `test_datasets`" in msg + assert "multimodal CPT training" in msg + assert "multimodal_pretrain" in msg + + def test_text_test_datasets_with_mm_training_rejected(self, min_base_cfg): + """Non-MM test_datasets paired with MM training fails at config-load.""" + cfg = _mm_cpt_cfg( + min_base_cfg, + test_datasets=[{"path": "eval/a", "type": "alpaca", "split": "test"}], + ) + with pytest.raises(ValueError) as exc: + validate_config(cfg) + msg = str(exc.value) + assert "Multimodal CPT training" in msg + assert "multimodal `test_datasets`" in msg + assert "multimodal_pretrain" in msg + def test_remove_unused_columns_auto_set_emits_info_log(self, min_base_cfg, caplog): """Auto-setting `remove_unused_columns: false` for MM CPT logs an INFO record naming the previous value.""" cfg = _mm_cpt_cfg(min_base_cfg) @@ -252,7 +306,9 @@ def test_remove_unused_columns_auto_set_emits_info_log(self, min_base_cfg, caplo for r in caplog.records if r.levelno == logging.INFO and "Auto-set" in r.getMessage() ] - assert matches, "expected an INFO record about auto-setting remove_unused_columns" + assert matches, ( + "expected an INFO record about auto-setting remove_unused_columns" + ) msg = matches[0].getMessage() assert "remove_unused_columns" in msg assert "previous value: None" in msg From 98d7e52dae9f25d34ba7725c444b8bab333f2c42 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Sat, 25 Apr 2026 15:31:32 -0700 Subject: [PATCH 03/14] fix: address coderabbit comments/nits MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses CodeRabbit review on PR #3629. No behavior change for the happy path; expands schema, hardens fallbacks, tightens validation. Bug fixes --------- - Gemma-3 autodetect: prefer `processor.boi_token` over `image_token` when they differ. Without this, MM CPT crashed on the first batch with "Prompt contained 0 image tokens". - `dispatch_batches: true` placeholder dataset mirrors the configured `image_column` so worker ranks don't KeyError on empty rows. - `tokenize_prompt` rejects falsy non-None image cells (`""`, `0`, `False`) instead of coercing to `[]`. - `_tokenize` now honors `add_eos_token` / `strip_bos_token` instead of silently ignoring them. Schema ------ - Add `ds_type` to `PretrainingDataset` and `MultiModalEvalDataset` (the documented `ds_type: json` shape now reaches `load_dataset`). - Preserve `trust_remote_code` through `_pretraining_config_from_entry` and pass to `load_dataset`. - Honor `cfg.eval_sequence_len` in MM CPT eval streams (encoder + collator) with documented fallback to `cfg.sequence_len`. Validation (config-load time) ----------------------------- - Reject mixed multimodal/text entries in `test_datasets`. - Reject MM `test_datasets` paired with non-MM training. - Reject non-MM `test_datasets` paired with MM training. - Removed the redundant runtime check in `sft.py`; the schema is now the single source of truth. Hardening / observability ------------------------- - Mixed/all-text batch handling: collator routes all-text batches to the tokenizer (no `pixel_values`); mixed batches go through the processor as-is. Documented per-VLM compatibility (verified on SmolVLM/SmolVLM2, Gemma-3, Gemma-4, Qwen2.5-VL, Qwen3-VL). - Reject cloud / object-store URIs (`s3://`, `gs://`, `gcs://`, `az://`, `azure://`, `hf://`) in image paths so users see the explicit "Non-local scheme" error instead of a confusing FileNotFoundError. - Warn at construction when `MultiModalPretrainDataCollator.tokenizer` is not `processor.tokenizer` (all-text vs image batches could otherwise tokenize the same text differently). - Warn at retry kickoff when a processor call fails on a batch, so users see why processing stalls during per-row diagnosis. - INFO log when `remove_unused_columns` is auto-set to `false` for MM CPT. - DEBUG log when `tokenizer.get_added_vocab()` fails (was silent pass). - Clarify "exceeds sequence_len" error in both encoder paths to note image-patch expansion may push the final length higher. Code quality ------------ - Lift `image_token_spec` into `MultimodalPretrainTokenizationStrategy. __init__` instead of post-construction monkey-patch + `type: ignore`. - Hoist `import importlib` out of the per-class loop. - Drop dead `n_chunks` multiplication; replace with explicit invariant assertion. - Replace ambiguous `×` (U+00D7) with ASCII `x` in code/comments and the user-facing pixel-cap error. Tests ----- +15 regression tests across the four MM CPT suites covering: Gemma-3 boi_token autodetect (with id-mapping assertion), `eval_sequence_len` on encoder + collator (set + unset-fallback), `trust_remote_code` and `ds_type` preservation, three modality-mismatch validation cases, tokenizer-mismatch warning, `remove_unused_columns` auto-set log, cloud-URI rejection. 68 tests pass across `test_multimodal_streaming`, `test_multimodal_pretrain`, `test_multimodal_cpt`, and `test_mm_cpt_eval`. Lint clean against ruff v0.15.8 (upstream pre-commit pin). --- .../prompt_strategies/multimodal_pretrain.py | 29 ++++++++++++++----- src/axolotl/utils/collators/mm_pretrain.py | 6 ++-- .../test_multimodal_pretrain.py | 13 +++++++-- tests/test_multimodal_streaming.py | 12 ++++++-- 4 files changed, 45 insertions(+), 15 deletions(-) diff --git a/src/axolotl/prompt_strategies/multimodal_pretrain.py b/src/axolotl/prompt_strategies/multimodal_pretrain.py index 4e593962c4..fcfc7ccdeb 100644 --- a/src/axolotl/prompt_strategies/multimodal_pretrain.py +++ b/src/axolotl/prompt_strategies/multimodal_pretrain.py @@ -45,7 +45,7 @@ def _get_incompatible_processor_classes() -> tuple[type, ...]: "", ) -# Without masking these in labels, loss blows up ~10× on Qwen/SmolVLM. +# Without masking these in labels, loss blows up ~10x on Qwen/SmolVLM. _IMAGE_FAMILY_TOKEN_CANDIDATES: tuple[str, ...] = ( "", "<|image|>", @@ -210,6 +210,7 @@ def __init__( image_token_id: int, image_column: str = "images", image_base_dir: str | None = None, + image_token_spec: ImageTokenSpec | None = None, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) @@ -217,6 +218,7 @@ def __init__( self.image_token_id = image_token_id self.image_column = image_column self.image_base_dir = image_base_dir + self.image_token_spec = image_token_spec def _tokenize( self, @@ -227,8 +229,18 @@ def _tokenize( # No truncation: collator re-tokenizes the full text without truncation; # truncating here decouples the stored ids from what the model receives. res = self.tokenizer(prompt, add_special_tokens=True) - res["input_ids"] = [res["input_ids"] + [self.tokenizer.eos_token_id]] - res["attention_mask"] = [res["attention_mask"] + [1]] + ids = list(res["input_ids"]) + mask = list(res["attention_mask"]) + bos_id = getattr(self.tokenizer, "bos_token_id", None) + if strip_bos_token and ids and bos_id is not None and ids[0] == bos_id: + ids = ids[1:] + mask = mask[1:] + eos_id = getattr(self.tokenizer, "eos_token_id", None) + if add_eos_token and eos_id is not None: + ids = ids + [eos_id] + mask = mask + [1] + res["input_ids"] = [ids] + res["attention_mask"] = [mask] return res def tokenize_prompt(self, prompt: dict[str, Any]) -> dict[str, list]: @@ -263,9 +275,12 @@ def tokenize_prompt(self, prompt: dict[str, Any]) -> dict[str, list]: f"processor may push the final length even higher)." ) - n_chunks = len(res["input_ids"]) - res["images"] = [list(images)] * n_chunks - res["_mm_text"] = [text] * n_chunks + # `_tokenize` produces exactly one chunk; the assert keeps that + # invariant explicit so a future change there can't silently + # mis-align `images` / `_mm_text` against `input_ids`. + assert len(res["input_ids"]) == 1 + res["images"] = [list(images)] + res["_mm_text"] = [text] return res @@ -307,6 +322,6 @@ def load( image_token=spec.image_token, image_token_id=spec.image_token_id, max_length=cfg.sequence_len, + image_token_spec=spec, ) - strat.image_token_spec = spec # type: ignore[attr-defined] return strat diff --git a/src/axolotl/utils/collators/mm_pretrain.py b/src/axolotl/utils/collators/mm_pretrain.py index 47a9080b95..845ddd86a6 100644 --- a/src/axolotl/utils/collators/mm_pretrain.py +++ b/src/axolotl/utils/collators/mm_pretrain.py @@ -20,7 +20,7 @@ LOG = get_logger(__name__) -# Decompression-bomb cap (~7070×7070). +# Decompression-bomb cap (~7070x7070). _DEFAULT_MAX_IMAGE_PIXELS = 50_000_000 _DEFAULT_MAX_IMAGES_PER_ROW = 32 @@ -128,7 +128,7 @@ def _open_image_hardened(self, resolved: str) -> Image.Image: w, h = src.size if w * h > self.max_image_pixels: raise ValueError( - f"Image pixels ({w}×{h}) exceed " + f"Image pixels ({w}x{h}) exceed " f"max_image_pixels ({self.max_image_pixels})." ) # Multi-frame bomb guard (GIF/TIFF/WebP). @@ -329,7 +329,7 @@ def torch_call(self, examples: list[dict]) -> dict[str, Any]: if pad_id is not None: labels[labels == pad_id] = -100 - # Without this, image-family ids dominate loss and blow it up ~10×. + # Without this, image-family ids dominate loss and blow it up ~10x. for tid in self._image_family_token_ids: labels[labels == tid] = -100 diff --git a/tests/prompt_strategies/test_multimodal_pretrain.py b/tests/prompt_strategies/test_multimodal_pretrain.py index 8040472363..f809872726 100644 --- a/tests/prompt_strategies/test_multimodal_pretrain.py +++ b/tests/prompt_strategies/test_multimodal_pretrain.py @@ -73,15 +73,24 @@ def test_build_image_token_spec_prefers_boi_token_over_expansion_token( smolvlm_processor, ): """Gemma-3-style autodetect: `boi_token` is preferred over `image_token` - when they differ.""" + when they differ. Resolved id must match the boi_token, not image_token.""" + tok = smolvlm_processor.tokenizer + image_id = tok.convert_tokens_to_ids("") + boi_id = tok.convert_tokens_to_ids("") + assert boi_id != image_id, ( + "fixture assumption broken: SmolVLM tokenizer should map these to distinct ids" + ) class _FakeGemma3Like: image_token = "" boi_token = "" - tokenizer = smolvlm_processor.tokenizer + tokenizer = tok spec = build_image_token_spec(_FakeGemma3Like()) assert spec.image_token == "" + assert spec.image_token_id == boi_id + assert spec.image_token_id != image_id + assert boi_id in spec.image_family_token_ids # ---- check_processor_compatibility (startup-time gate) --------------------- diff --git a/tests/test_multimodal_streaming.py b/tests/test_multimodal_streaming.py index 69d54b4412..74b4f6893c 100644 --- a/tests/test_multimodal_streaming.py +++ b/tests/test_multimodal_streaming.py @@ -105,9 +105,11 @@ def test_encode_rejects_row_without_list(smolvlm_processor, two_tiny_images): def test_encode_counts_placeholders_on_full_text(smolvlm_processor, two_tiny_images): - # All 3 placeholders must be counted even when text would have been truncated. + # The last placeholder must remain countable even when it's hundreds of + # tokens deep — guards against a regression that adds tokenizer + # truncation and silently drops trailing placeholders. spec = build_image_token_spec(smolvlm_processor) - long_filler = "lorem ipsum " * 20 + long_filler = "lorem ipsum " * 400 text = f"{spec.image_token} {long_filler} {spec.image_token} {long_filler} {spec.image_token}" examples = { "text": [text], @@ -120,7 +122,11 @@ def test_encode_counts_placeholders_on_full_text(smolvlm_processor, two_tiny_ima image_token=spec.image_token, image_token_id=spec.image_token_id, ) - assert sum(1 for t in out["input_ids"][0] if t == spec.image_token_id) == 3 + ids = out["input_ids"][0] + # Sanity: the input is genuinely long, so a truncating regression would + # have to cut into it to drop the last placeholder. + assert len(ids) > 2000 + assert sum(1 for t in ids if t == spec.image_token_id) == 3 def test_encode_rejects_row_exceeding_max_tokens(smolvlm_processor, two_tiny_images): From 6cd74f079feed3bb257d0cd2bf000931a190cf81 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Sat, 25 Apr 2026 15:53:49 -0700 Subject: [PATCH 04/14] fix: valid coderabbit nit and caplog test capture - mm_pretrain.py: return BatchEncoding (not dict) from all-text branch so it matches the imaged path. - test_multimodal_cpt.py, test_multimodal_streaming.py: monkeypatch axolotl logger propagate=True so caplog can capture records (axolotl's logging config sets propagate=False, blocking root capture in CI). --- src/axolotl/utils/collators/mm_pretrain.py | 2 +- tests/test_multimodal_streaming.py | 5 ++++- .../utils/schemas/validation/test_multimodal_cpt.py | 12 ++++++++++-- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/collators/mm_pretrain.py b/src/axolotl/utils/collators/mm_pretrain.py index 845ddd86a6..4470e4f25e 100644 --- a/src/axolotl/utils/collators/mm_pretrain.py +++ b/src/axolotl/utils/collators/mm_pretrain.py @@ -253,7 +253,7 @@ def torch_call(self, examples: list[dict]) -> dict[str, Any]: for tid in self._image_family_token_ids: tok_labels[tok_labels == tid] = -100 batch["labels"] = tok_labels - return dict(batch) + return batch # No truncation: it chops input_ids mid-placeholder while pixel_values # keep every image — silent text/pixel mismatch. We warn post-hoc instead. diff --git a/tests/test_multimodal_streaming.py b/tests/test_multimodal_streaming.py index 74b4f6893c..dfb500fd89 100644 --- a/tests/test_multimodal_streaming.py +++ b/tests/test_multimodal_streaming.py @@ -500,11 +500,14 @@ def test_collator_rejects_too_many_images(smolvlm_processor, two_tiny_images): def test_collator_warns_when_tokenizer_diverges_from_processor_tokenizer( - smolvlm_processor, caplog + smolvlm_processor, caplog, monkeypatch ): """Construct-time warning when self.tokenizer is not processor.tokenizer.""" import logging as _logging + # `axolotl` logger has propagate=False (logging_config.py); flip it so + # caplog's root handler receives the record. + monkeypatch.setattr(_logging.getLogger("axolotl"), "propagate", True) spec = build_image_token_spec(smolvlm_processor) # Same tokenizer: no warning. diff --git a/tests/utils/schemas/validation/test_multimodal_cpt.py b/tests/utils/schemas/validation/test_multimodal_cpt.py index d6e8a88181..2b9123bbff 100644 --- a/tests/utils/schemas/validation/test_multimodal_cpt.py +++ b/tests/utils/schemas/validation/test_multimodal_cpt.py @@ -294,8 +294,13 @@ def test_text_test_datasets_with_mm_training_rejected(self, min_base_cfg): assert "multimodal `test_datasets`" in msg assert "multimodal_pretrain" in msg - def test_remove_unused_columns_auto_set_emits_info_log(self, min_base_cfg, caplog): + def test_remove_unused_columns_auto_set_emits_info_log( + self, min_base_cfg, caplog, monkeypatch + ): """Auto-setting `remove_unused_columns: false` for MM CPT logs an INFO record naming the previous value.""" + # `axolotl` logger has propagate=False (logging_config.py); flip it so + # caplog's root handler receives the record. + monkeypatch.setattr(logging.getLogger("axolotl"), "propagate", True) cfg = _mm_cpt_cfg(min_base_cfg) cfg.pop("remove_unused_columns", None) with caplog.at_level(logging.INFO, logger="axolotl.utils.schemas.validation"): @@ -314,9 +319,12 @@ def test_remove_unused_columns_auto_set_emits_info_log(self, min_base_cfg, caplo assert "previous value: None" in msg def test_remove_unused_columns_already_false_does_not_log( - self, min_base_cfg, caplog + self, min_base_cfg, caplog, monkeypatch ): """When the user already set `remove_unused_columns: false`, no auto-set log fires.""" + monkeypatch.setattr( + logging.getLogger("axolotl.utils.schemas.validation"), "propagate", True + ) cfg = _mm_cpt_cfg(min_base_cfg, remove_unused_columns=False) with caplog.at_level(logging.INFO, logger="axolotl.utils.schemas.validation"): validate_config(cfg) From 40ea3b9539278847650ee2419faa4badaae70431 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Sat, 25 Apr 2026 16:16:26 -0700 Subject: [PATCH 05/14] fix: gemma 4 regression, expand test coverage per Codecov multimodal_pretrain.py: scope the boi_token swap in build_image_token_spec to processors whose `image_token` name contains "soft_token" (the Gemma-3 convention). Without this, Gemma-4 (`image_token=<|image|>`, `boi_token=<|image>`) gets the wrong placeholder autodetected and every row fails validation with a 0-vs-N placeholder/image mismatch. test_multimodal_streaming.py: 6 new tests - Two for the new autodetection behavior (Gemma-4 keeps image_token, Gemma-3 still swaps to boi_token), using stub processors. - Three branch-coverage tests for build_image_token_spec failure modes: override not registered as special token, override resolves to unk, nothing autodetectable. - Three collator-path tests: skip_bad_images drops a row and continues, all-rows-dropped surfaces a RuntimeError, multi-frame GIF triggers the animation-bomb guard via _open_image_hardened. --- .../prompt_strategies/multimodal_pretrain.py | 8 +- tests/test_multimodal_streaming.py | 139 ++++++++++++++++++ 2 files changed, 145 insertions(+), 2 deletions(-) diff --git a/src/axolotl/prompt_strategies/multimodal_pretrain.py b/src/axolotl/prompt_strategies/multimodal_pretrain.py index fcfc7ccdeb..13bb77030a 100644 --- a/src/axolotl/prompt_strategies/multimodal_pretrain.py +++ b/src/axolotl/prompt_strategies/multimodal_pretrain.py @@ -139,14 +139,18 @@ def resolve_id(tok: str) -> int | None: image_token = override else: proc_token = getattr(processor, "image_token", None) - # Gemma-3-style: `image_token` is the post-expansion soft token; the - # user-facing placeholder is `boi_token`. + # Gemma-3-style only: `image_token` is the post-expansion soft token + # (its name literally contains "soft_token"); the user-facing + # placeholder is `boi_token`. Gemma-4 reverses this — `image_token` + # IS the user-facing placeholder (`<|image|>`) and `boi_token` + # (`<|image>`) is just a bracket marker, so don't blindly swap. boi_token = getattr(processor, "boi_token", None) if ( boi_token and proc_token and boi_token != proc_token and boi_token in known_special_tokens + and "soft_token" in proc_token ): proc_token = boi_token if proc_token is not None: diff --git a/tests/test_multimodal_streaming.py b/tests/test_multimodal_streaming.py index dfb500fd89..f95c8eb07c 100644 --- a/tests/test_multimodal_streaming.py +++ b/tests/test_multimodal_streaming.py @@ -146,6 +146,80 @@ def test_encode_rejects_row_exceeding_max_tokens(smolvlm_processor, two_tiny_ima ) +# ---- build_image_token_spec autodetection -------------------------------- + + +class _StubTokenizer: + """Minimal tokenizer stub for autodetection tests.""" + + def __init__(self, vocab: dict[str, int], unk_id: int = 0): + self._vocab = vocab + self.unk_token_id = unk_id + self.all_special_tokens = list(vocab.keys()) + self.additional_special_tokens: list[str] = [] + + def get_added_vocab(self): + return dict(self._vocab) + + def convert_tokens_to_ids(self, tok): + return self._vocab.get(tok, self.unk_token_id) + + +class _StubProcessor: + def __init__(self, tokenizer, image_token=None, boi_token=None): + self.tokenizer = tokenizer + if image_token is not None: + self.image_token = image_token + if boi_token is not None: + self.boi_token = boi_token + + +def test_build_image_token_spec_gemma4_uses_image_token_not_boi(): + """Gemma-4: `image_token` is the user-facing placeholder; don't swap to boi_token.""" + tok = _StubTokenizer({"<|image|>": 258880, "<|image>": 255999}) + proc = _StubProcessor(tok, image_token="<|image|>", boi_token="<|image>") + spec = build_image_token_spec(proc) + assert spec.image_token == "<|image|>" + assert spec.image_token_id == 258880 + + +def test_build_image_token_spec_gemma3_swaps_to_boi_token(): + """Gemma-3: `image_token` is the post-expansion soft token; placeholder is `boi_token`.""" + tok = _StubTokenizer( + {"": 262144, "": 255999} + ) + proc = _StubProcessor( + tok, image_token="", boi_token="" + ) + spec = build_image_token_spec(proc) + assert spec.image_token == "" + assert spec.image_token_id == 255999 + + +def test_build_image_token_spec_override_not_special_rejected(): + """Override that isn't a registered special token is rejected (would BPE-tokenize).""" + tok = _StubTokenizer({"<|image|>": 258880}) + proc = _StubProcessor(tok, image_token="<|image|>") + with pytest.raises(ValueError, match="not a registered special token"): + build_image_token_spec(proc, override="not_a_real_token") + + +def test_build_image_token_spec_override_resolves_to_unk_rejected(): + """Override that resolves to unk is rejected with a clear error.""" + tok = _StubTokenizer({"<|image|>": 258880, "<|fake|>": 0}, unk_id=0) + proc = _StubProcessor(tok, image_token="<|image|>") + with pytest.raises(ValueError, match="did not resolve"): + build_image_token_spec(proc, override="<|fake|>") + + +def test_build_image_token_spec_no_candidates_raises(): + """If neither processor attrs nor any known candidate resolve, raise a clear error.""" + tok = _StubTokenizer({}) # nothing registered + proc = _StubProcessor(tok) # no image_token, no boi_token + with pytest.raises(ValueError, match="Could not autodetect"): + build_image_token_spec(proc) + + # ---- wrap_streaming_dataset routing -------------------------------------- @@ -496,6 +570,71 @@ def test_collator_rejects_too_many_images(smolvlm_processor, two_tiny_images): collator._load_images_for_row(paths, row_index=0) +def test_collator_skip_bad_images_drops_row_and_continues( + smolvlm_processor, two_tiny_images, tmp_path +): + """skip_bad_images=True: bad row drops, batch survives on remaining rows.""" + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + skip_bad_images=True, + ) + rows = [ + { + "_mm_text": f"{spec.image_token}\ngood row", + "images": [str(two_tiny_images[0])], + }, + { + "_mm_text": f"{spec.image_token}\nbad row", + "images": [str(tmp_path / "missing.png")], + }, + ] + batch = collator.torch_call(rows) + # Surviving row produced a batch with pixel_values from the good image. + assert "input_ids" in batch and "pixel_values" in batch + assert batch["input_ids"].shape[0] == 1 + + +def test_collator_all_rows_dropped_raises(smolvlm_processor, tmp_path): + """skip_bad_images=True with every row failing surfaces a RuntimeError.""" + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + skip_bad_images=True, + ) + rows = [ + { + "_mm_text": f"{spec.image_token}\nrow", + "images": [str(tmp_path / f"missing_{i}.png")], + } + for i in range(2) + ] + with pytest.raises(RuntimeError, match="All rows in the batch were dropped"): + collator.torch_call(rows) + + +def test_collator_rejects_multi_frame_image(smolvlm_processor, tmp_path): + """Multi-frame GIF is rejected by the in-process bomb guard.""" + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + # Build an in-memory 2-frame GIF; verify the inner guard fires (the + # outer `_load_images_for_row` wraps it in a basename-only RuntimeError). + gif_path = tmp_path / "anim.gif" + f0 = Image.new("RGB", (16, 16), color=(255, 0, 0)) + f1 = Image.new("RGB", (16, 16), color=(0, 255, 0)) + f0.save(gif_path, save_all=True, append_images=[f1], duration=100, loop=0) + with pytest.raises(ValueError, match="Multi-frame"): + collator._open_image_hardened(str(gif_path)) + + # ---- mixed / all-text batches -------------------------------------------- From c5b3a810e13333b2cd2273991be2f62f5942a271 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Sat, 25 Apr 2026 16:28:01 -0700 Subject: [PATCH 06/14] fix: tests + ruff/lint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix(test): patch parent `axolotl` logger so negative caplog assertion has teeth The previous monkeypatch targeted `axolotl.utils.schemas.validation`, which is already propagate=True by inheritance — the actual block sits one level up at the `axolotl` logger (propagate=False from logging_config.py). The result: caplog never received any records, and `assert not any("Auto-set" ... in caplog.records)` would have passed even if the regression fired. Mirror the positive test by flipping propagate on `logging.getLogger("axolotl")` and add a comment explaining why the leaf isn't the right target. --- tests/test_multimodal_streaming.py | 4 +--- tests/utils/schemas/validation/test_multimodal_cpt.py | 8 +++++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_multimodal_streaming.py b/tests/test_multimodal_streaming.py index f95c8eb07c..458f5ac02a 100644 --- a/tests/test_multimodal_streaming.py +++ b/tests/test_multimodal_streaming.py @@ -185,9 +185,7 @@ def test_build_image_token_spec_gemma4_uses_image_token_not_boi(): def test_build_image_token_spec_gemma3_swaps_to_boi_token(): """Gemma-3: `image_token` is the post-expansion soft token; placeholder is `boi_token`.""" - tok = _StubTokenizer( - {"": 262144, "": 255999} - ) + tok = _StubTokenizer({"": 262144, "": 255999}) proc = _StubProcessor( tok, image_token="", boi_token="" ) diff --git a/tests/utils/schemas/validation/test_multimodal_cpt.py b/tests/utils/schemas/validation/test_multimodal_cpt.py index 2b9123bbff..f110d6c7e1 100644 --- a/tests/utils/schemas/validation/test_multimodal_cpt.py +++ b/tests/utils/schemas/validation/test_multimodal_cpt.py @@ -322,9 +322,11 @@ def test_remove_unused_columns_already_false_does_not_log( self, min_base_cfg, caplog, monkeypatch ): """When the user already set `remove_unused_columns: false`, no auto-set log fires.""" - monkeypatch.setattr( - logging.getLogger("axolotl.utils.schemas.validation"), "propagate", True - ) + # Mirror the positive test: flip propagate on the parent `axolotl` + # logger (the one with propagate=False), not the leaf — otherwise + # caplog never sees axolotl.* records and this negative assertion + # is vacuous (it would pass even if the auto-set log fired). + monkeypatch.setattr(logging.getLogger("axolotl"), "propagate", True) cfg = _mm_cpt_cfg(min_base_cfg, remove_unused_columns=False) with caplog.at_level(logging.INFO, logger="axolotl.utils.schemas.validation"): validate_config(cfg) From 10ca6a25f5d38b76e8dabef83fcdc3cc4745b2e6 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Sat, 25 Apr 2026 16:35:31 -0700 Subject: [PATCH 07/14] fix coderabbit nit - use tokenizer for text-only rows in offender retry loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Some HF processors reject `images=[[]]`, which made the per-row retry flag innocent text-only rows as the offender. Mirror the all-text bypass — diagnostic-only path, mainline unchanged. --- src/axolotl/utils/collators/mm_pretrain.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/collators/mm_pretrain.py b/src/axolotl/utils/collators/mm_pretrain.py index 4470e4f25e..a8fa71a807 100644 --- a/src/axolotl/utils/collators/mm_pretrain.py +++ b/src/axolotl/utils/collators/mm_pretrain.py @@ -287,7 +287,12 @@ def torch_call(self, examples: list[dict]) -> dict[str, Any]: retry_kwargs["pad_to_multiple_of"] = self.pad_to_multiple_of for i, (t, imgs) in enumerate(zip(texts, images, strict=True)): try: - self.processor(text=[t], images=[imgs], **retry_kwargs) + if len(imgs) == 0: + # Some processors reject `images=[[]]` — would mislabel + # text-only rows as the offender. + self.tokenizer(text=[t], **retry_kwargs) + else: + self.processor(text=[t], images=[imgs], **retry_kwargs) except Exception as retry_exc: if isinstance(retry_exc, type(exc)) or isinstance( exc, type(retry_exc) From 0dcf37d8c51d90182c0947ececdf49ffab53f2fd Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Sat, 25 Apr 2026 16:53:05 -0700 Subject: [PATCH 08/14] fix: test (invert prefers_boi assertion to match new heuristic) Fix text following Gemma 4 regression fix --- .../test_multimodal_pretrain.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/prompt_strategies/test_multimodal_pretrain.py b/tests/prompt_strategies/test_multimodal_pretrain.py index f809872726..3bdec1b273 100644 --- a/tests/prompt_strategies/test_multimodal_pretrain.py +++ b/tests/prompt_strategies/test_multimodal_pretrain.py @@ -69,11 +69,13 @@ def test_build_image_token_spec_rejects_plain_word_override(smolvlm_processor): build_image_token_spec(smolvlm_processor, override="image") -def test_build_image_token_spec_prefers_boi_token_over_expansion_token( +def test_build_image_token_spec_keeps_image_token_when_no_soft_token_in_name( smolvlm_processor, ): - """Gemma-3-style autodetect: `boi_token` is preferred over `image_token` - when they differ. Resolved id must match the boi_token, not image_token.""" + """Non-Gemma-3 processors: the boi-swap heuristic only fires when + `image_token` name contains "soft_token" (Gemma-3 convention). Otherwise + `image_token` IS the user-facing placeholder (Gemma-4 convention) and + must not be silently replaced by `boi_token`.""" tok = smolvlm_processor.tokenizer image_id = tok.convert_tokens_to_ids("") boi_id = tok.convert_tokens_to_ids("") @@ -81,16 +83,15 @@ def test_build_image_token_spec_prefers_boi_token_over_expansion_token( "fixture assumption broken: SmolVLM tokenizer should map these to distinct ids" ) - class _FakeGemma3Like: - image_token = "" + class _FakeGemma4Like: + image_token = "" # no 'soft_token' in name → must not swap boi_token = "" tokenizer = tok - spec = build_image_token_spec(_FakeGemma3Like()) - assert spec.image_token == "" - assert spec.image_token_id == boi_id - assert spec.image_token_id != image_id - assert boi_id in spec.image_family_token_ids + spec = build_image_token_spec(_FakeGemma4Like()) + assert spec.image_token == "" + assert spec.image_token_id == image_id + assert spec.image_token_id != boi_id # ---- check_processor_compatibility (startup-time gate) --------------------- From 950785a8611a293c0cac46ee9cdfe566b3e49824 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 7 May 2026 07:44:33 -0700 Subject: [PATCH 09/14] fix(mm-cpt): re-append EOS in collator after processor re-tokenization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit MultiModalPretrainDataCollator.torch_call calls processor(text=...) which re-tokenizes _mm_text from scratch, discarding the EOS that encode_streaming_multimodal appended to input_ids. Without this, labels never contain EOS at end-of-document and the model never learns to emit a stop token — symptoms: non-terminating / repetitive generation. Match the text CPT contract (encode_streaming keeps EOS in both input_ids and labels) by appending EOS to _mm_text, idempotently, gated on a new add_eos_token field (default True). --- src/axolotl/utils/collators/mm_pretrain.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/axolotl/utils/collators/mm_pretrain.py b/src/axolotl/utils/collators/mm_pretrain.py index a8fa71a807..ec9d6e2c4f 100644 --- a/src/axolotl/utils/collators/mm_pretrain.py +++ b/src/axolotl/utils/collators/mm_pretrain.py @@ -38,6 +38,7 @@ class MultiModalPretrainDataCollator(DataCollatorMixin): skip_bad_images: bool = False max_image_pixels: int = _DEFAULT_MAX_IMAGE_PIXELS max_images_per_row: int = _DEFAULT_MAX_IMAGES_PER_ROW + add_eos_token: bool = True _image_family_token_ids: set[int] = field(init=False, default_factory=set) _base_dir_real: Optional[str] = field(init=False, default=None) @@ -210,6 +211,10 @@ def torch_call(self, examples: list[dict]) -> dict[str, Any]: f"Row {i}, image {j}: path must be str, got " f"{type(rp).__name__}." ) + # Processor re-tokenizes below, discarding the encoder's EOS — re-append. + if self.add_eos_token and self.tokenizer.eos_token: + if not mm_text.endswith(self.tokenizer.eos_token): + mm_text = mm_text + self.tokenizer.eos_token texts.append(mm_text) loaded = self._load_images_for_row(raw_paths, row_index=i) if self.skip_bad_images and len(loaded) != len(raw_paths): From 85e0c6dcf05b8b4061e30df17fe9fefb81c0eb04 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 14 May 2026 22:09:33 -0700 Subject: [PATCH 10/14] Align PretrainingDataset text_column/data_files with MultiModalEvalDataset --- src/axolotl/utils/schemas/datasets.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/schemas/datasets.py b/src/axolotl/utils/schemas/datasets.py index 70dfda799c..ad186d20db 100644 --- a/src/axolotl/utils/schemas/datasets.py +++ b/src/axolotl/utils/schemas/datasets.py @@ -232,10 +232,13 @@ class PretrainingDataset(BaseModel): name: str | None = None path: str | None = None split: str | None = "train" - text_column: str | None = "text" + text_column: str | None = Field( + default="text", + json_schema_extra={"description": "Column holding the row's text."}, + ) type: str | None = "pretrain" trust_remote_code: bool | None = False - data_files: str | None = None + data_files: str | list[str] | None = None ds_type: str | None = Field( default=None, json_schema_extra={ From f413bee577fbb74f7880ef2b5f3e0902afccc632 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Thu, 14 May 2026 22:39:24 -0700 Subject: [PATCH 11/14] Strip speculative image-loader hardening from MM CPT collator Aligns mm_pretrain.py with mm_chat.py's image-loading posture. Drops NUL/URL/path-traversal/pixel-cap/multi-frame/per-row-count guards that defended against threat models that don't apply to a CLI trainer loading its own dataset. Routes image loading through transformers.image_utils.load_image, matching the chat path. Keeps image_base_dir join, skip_bad_images, label masking, processor compatibility check, and the tokenizer/processor mismatch warning. --- src/axolotl/utils/collators/mm_pretrain.py | 145 ++++----------------- tests/test_multimodal_streaming.py | 120 +---------------- 2 files changed, 29 insertions(+), 236 deletions(-) diff --git a/src/axolotl/utils/collators/mm_pretrain.py b/src/axolotl/utils/collators/mm_pretrain.py index ec9d6e2c4f..1a1051224e 100644 --- a/src/axolotl/utils/collators/mm_pretrain.py +++ b/src/axolotl/utils/collators/mm_pretrain.py @@ -10,6 +10,7 @@ from torch import Tensor from transformers import PreTrainedTokenizerBase, ProcessorMixin from transformers.data.data_collator import DataCollatorMixin +from transformers.image_utils import load_image from transformers.utils import PaddingStrategy from axolotl.prompt_strategies.multimodal_pretrain import ( @@ -20,10 +21,6 @@ LOG = get_logger(__name__) -# Decompression-bomb cap (~7070x7070). -_DEFAULT_MAX_IMAGE_PIXELS = 50_000_000 -_DEFAULT_MAX_IMAGES_PER_ROW = 32 - @dataclass class MultiModalPretrainDataCollator(DataCollatorMixin): @@ -36,12 +33,9 @@ class MultiModalPretrainDataCollator(DataCollatorMixin): pad_to_multiple_of: Optional[int] = None max_length: Optional[int] = None skip_bad_images: bool = False - max_image_pixels: int = _DEFAULT_MAX_IMAGE_PIXELS - max_images_per_row: int = _DEFAULT_MAX_IMAGES_PER_ROW add_eos_token: bool = True _image_family_token_ids: set[int] = field(init=False, default_factory=set) - _base_dir_real: Optional[str] = field(init=False, default=None) def __post_init__(self) -> None: if self.return_tensors != "pt": @@ -61,112 +55,35 @@ def __post_init__(self) -> None: "tokenize inconsistently." ) self._image_family_token_ids = set(self.image_token_spec.image_family_token_ids) - if self.image_base_dir is not None: - self._base_dir_real = os.path.realpath(self.image_base_dir) - - def _resolve_image_path(self, p: str) -> str: - if not isinstance(p, str): - raise ValueError(f"Image path must be str, got {type(p).__name__}.") - if "\x00" in p: - raise ValueError("Image path contains embedded NUL byte.") - p_lower = p.lower() - if p_lower.startswith( - ( - "http://", - "https://", - "ftp://", - "ftps://", - "file://", - "data:", - "s3://", - "gs://", - "gcs://", - "az://", - "azure://", - "hf://", - ) - ) or p.startswith(("\\\\", "//")): - raise ValueError( - f"Non-local image path scheme is not supported in v1 " - f"multimodal CPT (got {p!r})." - ) - if self._base_dir_real is not None: - if os.path.isabs(p): - raise ValueError( - f"Absolute image path {p!r} is rejected when " - f"`image_base_dir` is configured. All image paths must be " - f"relative to the configured base directory." - ) - resolved = os.path.realpath(os.path.join(self._base_dir_real, p)) - # commonpath (not startswith) so root-dir bases like "/" work. - try: - within_base = ( - os.path.commonpath([self._base_dir_real, resolved]) - == self._base_dir_real - ) - except ValueError: - within_base = False - if not within_base: - raise ValueError( - f"Image path {p!r} resolves outside `image_base_dir` " - f"after symlink resolution. Refusing to load." - ) - return resolved - return os.path.realpath(p) if os.path.isabs(p) else p - def _open_image_hardened(self, resolved: str) -> Image.Image: - # O_NOFOLLOW closes the realpath→open TOCTOU window for the final component. - nofollow = getattr(os, "O_NOFOLLOW", 0) - try: - fd = os.open(resolved, os.O_RDONLY | nofollow) - except OSError as exc: - raise ValueError( - f"Cannot open image (os.open failed: {type(exc).__name__})." - ) from exc - file_obj = os.fdopen(fd, "rb") - try: - with Image.open(file_obj) as src: - w, h = src.size - if w * h > self.max_image_pixels: - raise ValueError( - f"Image pixels ({w}x{h}) exceed " - f"max_image_pixels ({self.max_image_pixels})." - ) - # Multi-frame bomb guard (GIF/TIFF/WebP). - n_frames = getattr(src, "n_frames", 1) - if n_frames > 1: - raise ValueError( - f"Multi-frame images are not supported (got {n_frames} frames)." - ) - img = src.convert("RGB") - img.load() - return img - finally: - if not file_obj.closed: - file_obj.close() + def _resolve_image_source(self, src: Any) -> Any: + # Only join base_dir for relative string paths; pass everything else + # (PIL images, URLs, base64, absolute paths) through to load_image. + if ( + self.image_base_dir + and isinstance(src, str) + and not os.path.isabs(src) + and "://" not in src + ): + return os.path.join(self.image_base_dir, src) + return src def _load_images_for_row( - self, paths: list[str], row_index: int + self, sources: list, row_index: int ) -> list[Image.Image]: - if len(paths) > self.max_images_per_row: - raise ValueError( - f"Row {row_index}: {len(paths)} images exceeds " - f"`max_images_per_row={self.max_images_per_row}`. Split the " - f"row or raise the cap if this is expected." - ) out: list[Image.Image] = [] - for raw in paths: + for raw in sources: try: - resolved = self._resolve_image_path(raw) - img = self._open_image_hardened(resolved) + img = load_image(self._resolve_image_source(raw)) except Exception as exc: - # Top-level log gets basename only; full path stays on DEBUG. - basename = os.path.basename(str(raw)) + label = ( + os.path.basename(raw) if isinstance(raw, str) else type(raw).__name__ + ) msg = ( - f"Row {row_index}: failed to load image {basename!r} " + f"Row {row_index}: failed to load image {label!r} " f"({type(exc).__name__})" ) - LOG.debug("failed image full path: %r; error: %s", raw, exc) + LOG.debug("failed image full source: %r; error: %s", raw, exc) if self.skip_bad_images: LOG.warning("%s — skipping", msg) continue @@ -197,33 +114,27 @@ def torch_call(self, examples: list[dict]) -> dict[str, Any]: ) raw = ex["images"] if raw is None: - raw_paths: list[str] = [] + raw_sources: list = [] elif isinstance(raw, (list, tuple)): - raw_paths = list(raw) + raw_sources = list(raw) else: raise TypeError( f"Row {i}: `images` must be a list (or None), got " f"{type(raw).__name__}." ) - for j, rp in enumerate(raw_paths): - if not isinstance(rp, str): - raise TypeError( - f"Row {i}, image {j}: path must be str, got " - f"{type(rp).__name__}." - ) # Processor re-tokenizes below, discarding the encoder's EOS — re-append. if self.add_eos_token and self.tokenizer.eos_token: if not mm_text.endswith(self.tokenizer.eos_token): mm_text = mm_text + self.tokenizer.eos_token texts.append(mm_text) - loaded = self._load_images_for_row(raw_paths, row_index=i) - if self.skip_bad_images and len(loaded) != len(raw_paths): + loaded = self._load_images_for_row(raw_sources, row_index=i) + if self.skip_bad_images and len(loaded) != len(raw_sources): # Drop the row to avoid silent placeholder/image count mismatch. LOG.warning( "Row %d: %d/%d images failed to load; dropping row.", i, - len(raw_paths) - len(loaded), - len(raw_paths), + len(raw_sources) - len(loaded), + len(raw_sources), ) texts.pop() continue @@ -326,8 +237,8 @@ def torch_call(self, examples: list[dict]) -> dict[str, Any]: if self.max_length is not None and input_ids_len > self.max_length: LOG.warning( "Batch input_ids length %d exceeds configured sequence_len %d " - "(image placeholder expansion). Reduce max_images_per_row or " - "raise sequence_len if this fires repeatedly.", + "(image placeholder expansion). Raise sequence_len if this " + "fires repeatedly.", input_ids_len, self.max_length, ) diff --git a/tests/test_multimodal_streaming.py b/tests/test_multimodal_streaming.py index 458f5ac02a..dca32d66b8 100644 --- a/tests/test_multimodal_streaming.py +++ b/tests/test_multimodal_streaming.py @@ -432,94 +432,7 @@ def test_collator_raises_on_missing_columns(smolvlm_processor): collator.torch_call([{"input_ids": [1, 2, 3]}]) # no _mm_text / images -# ---- security gates ------------------------------------------------------- - - -def test_collator_rejects_path_traversal_with_base_dir( - smolvlm_processor, two_tiny_images, tmp_path -): - spec = build_image_token_spec(smolvlm_processor) - base = tmp_path / "images" - base.mkdir() - collator = MultiModalPretrainDataCollator( - tokenizer=smolvlm_processor.tokenizer, - processor=smolvlm_processor, - image_token_spec=spec, - image_base_dir=str(base), - ) - # Absolute path rejection - with pytest.raises(RuntimeError) as exc: - collator._load_images_for_row([str(two_tiny_images[0])], row_index=0) - assert isinstance(exc.value.__cause__, ValueError) - assert "Absolute image path" in str(exc.value.__cause__) - # Containment-escape rejection - with pytest.raises(RuntimeError) as exc: - collator._load_images_for_row(["../../../etc/passwd"], row_index=0) - assert isinstance(exc.value.__cause__, ValueError) - assert "outside" in str(exc.value.__cause__) - - -def test_collator_rejects_remote_urls(smolvlm_processor): - spec = build_image_token_spec(smolvlm_processor) - collator = MultiModalPretrainDataCollator( - tokenizer=smolvlm_processor.tokenizer, - processor=smolvlm_processor, - image_token_spec=spec, - ) - for url in ( - "http://example.com/a.png", - "https://x/y.jpg", - "file:///etc/passwd", - "ftp://x/y.png", - "data:image/png;base64,xxx", - # Cloud / object-store / hub URIs. - "s3://bucket/key.png", - "gs://bucket/key.png", - "gcs://bucket/key.png", - "az://container/key.png", - "azure://account/container/key.png", - "hf://datasets/foo/bar/img.png", - # Case-variant bypass attempts. - "HTTP://evil.com/x.png", - "Https://x/y.jpg", - "FILE:///etc/passwd", - "DATA:image/png;base64,xxx", - "S3://bucket/key.png", - "GS://bucket/key.png", - ): - with pytest.raises(RuntimeError) as exc: - collator._load_images_for_row([url], row_index=0) - assert isinstance(exc.value.__cause__, ValueError) - assert "Non-local image path scheme" in str(exc.value.__cause__) - - -def test_collator_rejects_nul_byte_paths(smolvlm_processor): - spec = build_image_token_spec(smolvlm_processor) - collator = MultiModalPretrainDataCollator( - tokenizer=smolvlm_processor.tokenizer, - processor=smolvlm_processor, - image_token_spec=spec, - ) - with pytest.raises(RuntimeError) as exc: - collator._load_images_for_row(["bad\x00path.png"], row_index=0) - assert "NUL byte" in str(exc.value.__cause__) - - -def test_collator_rejects_non_string_image_entries(smolvlm_processor, two_tiny_images): - spec = build_image_token_spec(smolvlm_processor) - collator = MultiModalPretrainDataCollator( - tokenizer=smolvlm_processor.tokenizer, - processor=smolvlm_processor, - image_token_spec=spec, - ) - rows = [ - { - "_mm_text": f"{spec.image_token}\nrow", - "images": [None], # type: ignore[list-item] - } - ] - with pytest.raises(TypeError, match="path must be str"): - collator.torch_call(rows) +# ---- input validation ----------------------------------------------------- def test_collator_rejects_bytes_mm_text(smolvlm_processor, two_tiny_images): @@ -555,19 +468,6 @@ def test_collator_sanitizes_error_message(smolvlm_processor, tmp_path): assert "Row 3" in str(exc.value) -def test_collator_rejects_too_many_images(smolvlm_processor, two_tiny_images): - spec = build_image_token_spec(smolvlm_processor) - collator = MultiModalPretrainDataCollator( - tokenizer=smolvlm_processor.tokenizer, - processor=smolvlm_processor, - image_token_spec=spec, - max_images_per_row=2, - ) - paths = [str(two_tiny_images[0])] * 3 - with pytest.raises(ValueError, match="max_images_per_row"): - collator._load_images_for_row(paths, row_index=0) - - def test_collator_skip_bad_images_drops_row_and_continues( smolvlm_processor, two_tiny_images, tmp_path ): @@ -615,24 +515,6 @@ def test_collator_all_rows_dropped_raises(smolvlm_processor, tmp_path): collator.torch_call(rows) -def test_collator_rejects_multi_frame_image(smolvlm_processor, tmp_path): - """Multi-frame GIF is rejected by the in-process bomb guard.""" - spec = build_image_token_spec(smolvlm_processor) - collator = MultiModalPretrainDataCollator( - tokenizer=smolvlm_processor.tokenizer, - processor=smolvlm_processor, - image_token_spec=spec, - ) - # Build an in-memory 2-frame GIF; verify the inner guard fires (the - # outer `_load_images_for_row` wraps it in a basename-only RuntimeError). - gif_path = tmp_path / "anim.gif" - f0 = Image.new("RGB", (16, 16), color=(255, 0, 0)) - f1 = Image.new("RGB", (16, 16), color=(0, 255, 0)) - f0.save(gif_path, save_all=True, append_images=[f1], duration=100, loop=0) - with pytest.raises(ValueError, match="Multi-frame"): - collator._open_image_hardened(str(gif_path)) - - # ---- mixed / all-text batches -------------------------------------------- From c962b72dbaea8a8e11292762f240bcc9ef6363ae Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Wed, 20 May 2026 19:27:45 -0700 Subject: [PATCH 12/14] chore(lint): ruff-format mm_pretrain.py --- src/axolotl/utils/collators/mm_pretrain.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/collators/mm_pretrain.py b/src/axolotl/utils/collators/mm_pretrain.py index 1a1051224e..142dd5f41a 100644 --- a/src/axolotl/utils/collators/mm_pretrain.py +++ b/src/axolotl/utils/collators/mm_pretrain.py @@ -68,16 +68,16 @@ def _resolve_image_source(self, src: Any) -> Any: return os.path.join(self.image_base_dir, src) return src - def _load_images_for_row( - self, sources: list, row_index: int - ) -> list[Image.Image]: + def _load_images_for_row(self, sources: list, row_index: int) -> list[Image.Image]: out: list[Image.Image] = [] for raw in sources: try: img = load_image(self._resolve_image_source(raw)) except Exception as exc: label = ( - os.path.basename(raw) if isinstance(raw, str) else type(raw).__name__ + os.path.basename(raw) + if isinstance(raw, str) + else type(raw).__name__ ) msg = ( f"Row {row_index}: failed to load image {label!r} " From 389e0f28aed2c9ccf37eefb1bfe20a10dc015a27 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Fri, 22 May 2026 07:56:06 -0700 Subject: [PATCH 13/14] refactor(mm-cpt): scope PR to streaming path; dedupe eval schema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The non-streaming `datasets:` MM CPT route was never wired through `build_collator`, which only routes MM batches under the pretraining branch — `datasets:` entries would emit `images`/`_mm_text` rows into a text-only collator. Strip the strategy class + `load()` and their unit tests; keep `ImageTokenSpec`, `build_image_token_spec`, and `check_processor_compatibility` since the streaming collator imports them. Add a docs callout that only the streaming `pretraining_dataset` route is currently wired. Fold `MultiModalEvalDataset` into `PretrainingDataset` via inheritance; the only intentional divergence is the `type` default and the `_require_mm_markers` validator. Drops ~60 lines of duplicated `Field` declarations the reviewer flagged. Tighten the collator `KeyError` message to mention only `encode_streaming_multimodal` now that the strategy class is gone. --- docs/multimodal.qmd | 9 + .../prompt_strategies/multimodal_pretrain.py | 140 +-------------- src/axolotl/utils/collators/mm_pretrain.py | 3 +- src/axolotl/utils/schemas/datasets.py | 47 +---- .../test_multimodal_pretrain.py | 166 +----------------- 5 files changed, 29 insertions(+), 336 deletions(-) diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd index 8d56b6a4a5..a8e42d3ffc 100644 --- a/docs/multimodal.qmd +++ b/docs/multimodal.qmd @@ -372,6 +372,15 @@ Intended for use cases like OCR/transcription corpora where every row is a tight `(image, target_text)` pair and any user/assistant framing would pollute the learned signal. +::: {.callout-note} +**Currently supported via the streaming `pretraining_dataset` route only.** +A non-streaming `datasets:` route (`type: multimodal_pretrain` under +`datasets:`) is intentionally not wired — the collator selection inside +`build_collator` only routes MM CPT batches under the pretraining branch. +Configure MM CPT under `pretraining_dataset:` with `streaming: true` as +shown below. +::: + ### Dataset format (JSONL) Two keys per row: `text` (the raw string) and `images` (list of local paths). diff --git a/src/axolotl/prompt_strategies/multimodal_pretrain.py b/src/axolotl/prompt_strategies/multimodal_pretrain.py index 13bb77030a..f43eb02fec 100644 --- a/src/axolotl/prompt_strategies/multimodal_pretrain.py +++ b/src/axolotl/prompt_strategies/multimodal_pretrain.py @@ -1,16 +1,17 @@ -"""Multimodal CPT tokenization strategy.""" +"""Multimodal CPT helpers (image-token autodetection + processor compat). + +Only the streaming `pretraining_dataset` route is wired in v1; the +non-streaming `datasets:` route (strategy class + `load()`) is deferred to a +follow-on PR that also wires `build_collator` to route MM CPT batches outside +the `training_args.pretraining` branch. +""" from __future__ import annotations from dataclasses import dataclass -from typing import Any -from transformers import BatchEncoding, PreTrainedTokenizerBase, ProcessorMixin +from transformers import ProcessorMixin -from axolotl.prompt_strategies.pretrain import ( - PretrainTokenizationStrategy, - PretrainTokenizer, -) from axolotl.utils.logging import get_logger LOG = get_logger(__name__) @@ -204,128 +205,3 @@ def check_processor_compatibility(processor: ProcessorMixin) -> None: raise ValueError( f"Multimodal CPT is not supported for {base_cls.__name__}: {reason}" ) - - -class MultimodalPretrainTokenizationStrategy(PretrainTokenizationStrategy): - def __init__( - self, - *args: Any, - image_token: str, - image_token_id: int, - image_column: str = "images", - image_base_dir: str | None = None, - image_token_spec: ImageTokenSpec | None = None, - **kwargs: Any, - ) -> None: - super().__init__(*args, **kwargs) - self.image_token = image_token - self.image_token_id = image_token_id - self.image_column = image_column - self.image_base_dir = image_base_dir - self.image_token_spec = image_token_spec - - def _tokenize( - self, - prompt: str, - add_eos_token: bool = True, - strip_bos_token: bool = False, - ) -> BatchEncoding: - # No truncation: collator re-tokenizes the full text without truncation; - # truncating here decouples the stored ids from what the model receives. - res = self.tokenizer(prompt, add_special_tokens=True) - ids = list(res["input_ids"]) - mask = list(res["attention_mask"]) - bos_id = getattr(self.tokenizer, "bos_token_id", None) - if strip_bos_token and ids and bos_id is not None and ids[0] == bos_id: - ids = ids[1:] - mask = mask[1:] - eos_id = getattr(self.tokenizer, "eos_token_id", None) - if add_eos_token and eos_id is not None: - ids = ids + [eos_id] - mask = mask + [1] - res["input_ids"] = [ids] - res["attention_mask"] = [mask] - return res - - def tokenize_prompt(self, prompt: dict[str, Any]) -> dict[str, list]: - text = prompt[self.text_column] - raw_images = prompt.get(self.image_column) - if raw_images is None: - images: list = [] - elif isinstance(raw_images, (list, tuple)): - images = list(raw_images) - else: - raise ValueError( - f"Row's `{self.image_column}` must be a list of image paths, " - f"got {type(raw_images).__name__}." - ) - - res = self._tokenize(text) - ids = res["input_ids"][0] - # Count by token id — `text.count` substring-matches `` in ``. - n_placeholders = sum(1 for t in ids if t == self.image_token_id) - if n_placeholders != len(images): - raise ValueError( - f"Multimodal CPT row has {n_placeholders} occurrence(s) of " - f"{self.image_token!r} in text but {len(images)} image path(s) " - f"in `{self.image_column}`. They must match — the text column " - f"must contain exactly one placeholder per image." - ) - if len(ids) > self.max_length: - raise ValueError( - f"Multimodal CPT row tokenizes to {len(ids)} tokens which " - f"exceeds sequence_len={self.max_length}. Pre-chunk your text " - f"or raise sequence_len (image patch expansion at the " - f"processor may push the final length even higher)." - ) - - # `_tokenize` produces exactly one chunk; the assert keeps that - # invariant explicit so a future change there can't silently - # mis-align `images` / `_mm_text` against `input_ids`. - assert len(res["input_ids"]) == 1 - res["images"] = [list(images)] - res["_mm_text"] = [text] - return res - - -def load( - tokenizer: PreTrainedTokenizerBase, - cfg: Any, - ds_cfg: dict | None = None, - processor: ProcessorMixin | None = None, -) -> MultimodalPretrainTokenizationStrategy: - if processor is None: - raise ValueError( - "multimodal_pretrain requires a processor. Set `processor_type: " - "AutoProcessor` (or the concrete processor class) in your config " - "so axolotl loads it at startup." - ) - check_processor_compatibility(processor) - - ds_cfg = dict(ds_cfg or {}) - text_column = ds_cfg.get("text_column") or ds_cfg.get("field") or "text" - image_column = ds_cfg.get("image_column") or "images" - image_base_dir = ds_cfg.get("image_base_dir") - image_token_override = ds_cfg.get("image_token") - - spec = build_image_token_spec(processor, override=image_token_override) - LOG.info( - f"multimodal_pretrain: placeholder={spec.image_token!r} " - f"(id={spec.image_token_id}), masking {len(spec.image_family_token_ids)} " - f"image-family token ids in labels" - ) - - strat = MultimodalPretrainTokenizationStrategy( - PretrainTokenizer(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - text_column=text_column, - image_column=image_column, - image_base_dir=image_base_dir, - image_token=spec.image_token, - image_token_id=spec.image_token_id, - max_length=cfg.sequence_len, - image_token_spec=spec, - ) - return strat diff --git a/src/axolotl/utils/collators/mm_pretrain.py b/src/axolotl/utils/collators/mm_pretrain.py index 142dd5f41a..efa0319c71 100644 --- a/src/axolotl/utils/collators/mm_pretrain.py +++ b/src/axolotl/utils/collators/mm_pretrain.py @@ -102,8 +102,7 @@ def torch_call(self, examples: list[dict]) -> dict[str, Any]: raise KeyError( f"MultiModalPretrainDataCollator: row {i} is missing " f"'_mm_text' or 'images'. Did you wire the multimodal CPT " - f"encoder (encode_streaming_multimodal or " - f"MultimodalPretrainTokenizationStrategy)?" + f"encoder (encode_streaming_multimodal)?" ) mm_text = ex["_mm_text"] if not isinstance(mm_text, str): diff --git a/src/axolotl/utils/schemas/datasets.py b/src/axolotl/utils/schemas/datasets.py index ad186d20db..367133ac38 100644 --- a/src/axolotl/utils/schemas/datasets.py +++ b/src/axolotl/utils/schemas/datasets.py @@ -272,54 +272,15 @@ class PretrainingDataset(BaseModel): ) -class MultiModalEvalDataset(BaseModel): +class MultiModalEvalDataset(PretrainingDataset): """Multimodal CPT eval dataset configuration (test_datasets entry). - Use type='multimodal_pretrain' (or multimodal=True). The dataset must - expose a text column and a list[str] image-paths column; their names - default to 'text' and 'images' and can be overridden per-entry. + Inherits all fields from :class:`PretrainingDataset` and only overrides + the `type` default (no `"pretrain"` fallback for eval entries) plus adds + a validator that requires an explicit MM marker. """ - path: str | None = None - name: str | None = None - split: str | None = "train" - data_files: str | list[str] | None = None - ds_type: str | None = Field( - default=None, - json_schema_extra={ - "description": "Dataset loader type when `path` points to local files (e.g. 'json', 'csv', 'parquet')." - }, - ) - skip: int | None = None type: str | None = None - trust_remote_code: bool | None = False - - multimodal: bool | None = Field( - default=None, - json_schema_extra={ - "description": "Opt in to multimodal eval. Auto-enabled when type='multimodal_pretrain'." - }, - ) - text_column: str | None = Field( - default="text", - json_schema_extra={"description": "Column holding the row's text."}, - ) - image_column: str | None = Field( - default="images", - json_schema_extra={ - "description": "Column holding a list of image paths per row." - }, - ) - image_base_dir: str | None = Field( - default=None, - json_schema_extra={"description": "Base directory for relative image paths."}, - ) - image_token: str | None = Field( - default=None, - json_schema_extra={ - "description": "Override the image placeholder token (autodetected from processor if unset)." - }, - ) @model_validator(mode="before") @classmethod diff --git a/tests/prompt_strategies/test_multimodal_pretrain.py b/tests/prompt_strategies/test_multimodal_pretrain.py index 3bdec1b273..da6e6a4985 100644 --- a/tests/prompt_strategies/test_multimodal_pretrain.py +++ b/tests/prompt_strategies/test_multimodal_pretrain.py @@ -1,24 +1,22 @@ -"""Multimodal CPT prompt strategy + safety gate tests.""" +"""Multimodal CPT helpers + safety gate tests. -from __future__ import annotations +The non-streaming strategy class and ``load()`` factory are deferred to a +follow-on PR (along with the matching ``build_collator`` routing for +``datasets:`` MM CPT batches), so only the helper-level surface is exercised +here in v1. +""" -from pathlib import Path -from typing import Any +from __future__ import annotations -import numpy as np import pytest -from PIL import Image from transformers import AutoProcessor from axolotl.prompt_strategies.multimodal_pretrain import ( _INCOMPATIBLE_PROCESSOR_REASONS, ImageTokenSpec, - MultimodalPretrainTokenizationStrategy, build_image_token_spec, check_processor_compatibility, - load, ) -from axolotl.prompt_strategies.pretrain import PretrainTokenizer from tests.hf_offline_utils import enable_hf_offline @@ -33,15 +31,6 @@ def fixture_smolvlm_processor( return AutoProcessor.from_pretrained(_SMOLVLM) -@pytest.fixture(scope="module", name="tiny_image_path") -def fixture_tiny_image_path(tmp_path_factory) -> Path: - d = tmp_path_factory.mktemp("mm_pretrain_imgs") - p = d / "dummy.png" - arr = np.random.default_rng(0).integers(0, 255, (64, 64, 3)).astype("uint8") - Image.fromarray(arr).save(p) - return p - - # ---- build_image_token_spec ------------------------------------------------ @@ -124,144 +113,3 @@ class CustomUserProcessor(BaseMllama): def test_check_processor_compatibility_accepts_supported(smolvlm_processor): check_processor_compatibility(smolvlm_processor) - - -# ---- MultimodalPretrainTokenizationStrategy -------------------------------- - - -def _make_strategy( - smolvlm_processor: Any, - text_column: str = "text", - image_column: str = "images", -) -> MultimodalPretrainTokenizationStrategy: - spec = build_image_token_spec(smolvlm_processor) - return MultimodalPretrainTokenizationStrategy( - PretrainTokenizer(), - smolvlm_processor.tokenizer, - False, # train_on_inputs - 2048, # sequence_len - text_column=text_column, - image_column=image_column, - image_base_dir=None, - image_token=spec.image_token, - image_token_id=spec.image_token_id, - max_length=2048, - ) - - -def test_strategy_preserves_images_and_text(smolvlm_processor, tiny_image_path): - strat = _make_strategy(smolvlm_processor) - out = strat.tokenize_prompt( - { - "text": "\nsample transcription text", - "images": [str(tiny_image_path)], - } - ) - assert "input_ids" in out - assert "images" in out and "_mm_text" in out - assert len(out["input_ids"]) == 1 - assert len(out["images"]) == 1 - assert len(out["_mm_text"]) == 1 - assert out["images"][0] == [str(tiny_image_path)] - assert out["_mm_text"][0].startswith("") - - -def test_strategy_rejects_placeholder_count_mismatch( - smolvlm_processor, tiny_image_path -): - strat = _make_strategy(smolvlm_processor) - with pytest.raises(ValueError, match="occurrence"): - strat.tokenize_prompt( - { - "text": "\ntwo placeholders one image", - "images": [str(tiny_image_path)], - } - ) - - -def test_strategy_rejects_row_exceeding_max_length(smolvlm_processor, tiny_image_path): - spec = build_image_token_spec(smolvlm_processor) - strat = MultimodalPretrainTokenizationStrategy( - PretrainTokenizer(), - smolvlm_processor.tokenizer, - False, - 128, - text_column="text", - image_column="images", - image_base_dir=None, - image_token=spec.image_token, - image_token_id=spec.image_token_id, - max_length=128, - ) - huge = "word " * 5000 - with pytest.raises(ValueError, match="exceeds sequence_len"): - strat.tokenize_prompt( - { - "text": f"{spec.image_token} {huge}", - "images": [str(tiny_image_path)], - } - ) - - -def test_strategy_rejects_non_list_image_column(smolvlm_processor, tiny_image_path): - strat = _make_strategy(smolvlm_processor) - with pytest.raises(ValueError, match="list"): - strat.tokenize_prompt( - { - "text": "\nbad image field", - "images": str(tiny_image_path), # should be a list - } - ) - - -@pytest.mark.parametrize("bad_value", ["", 0, False]) -def test_strategy_rejects_falsy_non_none_image_column(smolvlm_processor, bad_value): - """Falsy non-None image cells (e.g. "") are rejected, not coerced to [].""" - strat = _make_strategy(smolvlm_processor) - with pytest.raises(ValueError, match="list"): - strat.tokenize_prompt( - { - "text": "no placeholder, but bad images cell", - "images": bad_value, - } - ) - - -def test_strategy_treats_none_image_column_as_empty(smolvlm_processor): - """images=None is the only falsy value treated as a text-only row.""" - strat = _make_strategy(smolvlm_processor) - out = strat.tokenize_prompt( - { - "text": "plain text-only row, no placeholder", - "images": None, - } - ) - assert out["images"][0] == [] - - -# ---- load() factory -------------------------------------------------------- - - -def test_load_requires_processor(smolvlm_processor): - class _Cfg: - train_on_inputs = False - sequence_len = 2048 - - with pytest.raises(ValueError, match="processor"): - load(smolvlm_processor.tokenizer, _Cfg(), ds_cfg={}, processor=None) - - -def test_load_returns_strategy_with_spec(smolvlm_processor): - class _Cfg: - train_on_inputs = False - sequence_len = 2048 - - strat = load( - smolvlm_processor.tokenizer, - _Cfg(), - ds_cfg={"text_column": "text", "image_column": "images"}, - processor=smolvlm_processor, - ) - assert isinstance(strat, MultimodalPretrainTokenizationStrategy) - assert hasattr(strat, "image_token_spec") - assert strat.image_token_spec.image_token == "" From 9867132f0dda36d5d65e530501642b2ff6fd4151 Mon Sep 17 00:00:00 2001 From: thad0ctor Date: Mon, 25 May 2026 02:00:33 -0700 Subject: [PATCH 14/14] fix(mm-cpt): raise ValueError when multimodal_pretrain is used under datasets: Previously failed with a raw AttributeError at strategy load time. Now raises a small ValueError pointing users to the supported entry point. --- src/axolotl/prompt_strategies/multimodal_pretrain.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/axolotl/prompt_strategies/multimodal_pretrain.py b/src/axolotl/prompt_strategies/multimodal_pretrain.py index f43eb02fec..c0bcb6bb50 100644 --- a/src/axolotl/prompt_strategies/multimodal_pretrain.py +++ b/src/axolotl/prompt_strategies/multimodal_pretrain.py @@ -17,6 +17,13 @@ LOG = get_logger(__name__) +def load(*_args, **_kwargs): + raise ValueError( + "multimodal_pretrain is only supported via pretraining_dataset " + "with streaming: true — see docs/multimodal.qmd" + ) + + def _get_incompatible_processor_classes() -> tuple[type, ...]: import importlib