diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd index 5197f48b10..a8e42d3ffc 100644 --- a/docs/multimodal.qmd +++ b/docs/multimodal.qmd @@ -364,6 +364,146 @@ Here is an example of a multi-modal dataset: ] ``` +## Continued Pre-training (CPT) with images {#sec-multimodal-cpt} + +Raw image+text continued pretraining — no chat template, no conversational +scaffolding. The model learns to emit raw text conditioned on visual patches. +Intended for use cases like OCR/transcription corpora where every row is a +tight `(image, target_text)` pair and any user/assistant framing would pollute +the learned signal. + +::: {.callout-note} +**Currently supported via the streaming `pretraining_dataset` route only.** +A non-streaming `datasets:` route (`type: multimodal_pretrain` under +`datasets:`) is intentionally not wired — the collator selection inside +`build_collator` only routes MM CPT batches under the pretraining branch. +Configure MM CPT under `pretraining_dataset:` with `streaming: true` as +shown below. +::: + +### Dataset format (JSONL) + +Two keys per row: `text` (the raw string) and `images` (list of local paths). +The `text` must contain the model's placeholder token **once per image**, +placed immediately before the text it describes, followed by a newline: + +```json +{"text": "\nInvoice number: 10427. Total due: 148.32 USD.", "images": ["/dataset/crops/doc_14_p2.png"]} +{"text": "\nChapter 3 begins with a description of the storm over the harbor.", "images": ["/dataset/crops/doc_14_p3.png"]} +``` + +Notes: + +- Never wrap the row in `User:` / `Assistant:` / `Transcribe this:` scaffolding. That is the whole point of the CPT path. +- Do not manually append an EOS token. Axolotl appends one during tokenization. +- The newline between the placeholder and the real text preserves the BPE + boundary. Without it, some tokenizers merge the visual-token boundary with + the first real character. + +### The placeholder token varies by model + +| Model family | Placeholder | Notes | +|---|---|---| +| LLaVA-1.5 / 1.6 | `` | | +| SmolVLM / SmolVLM2 / Idefics3 | `` | Processor expands to 1088 tokens (17 tiles × 64) | +| Qwen2-VL / Qwen2.5-VL / Qwen3-VL | `<\|image_pad\|>` | Processor autowraps with `<\|vision_start\|>` / `<\|vision_end\|>` | +| Gemma-3 | `` | Processor expands to 256 `` | +| Gemma-4 | `<\|image\|>` | Processor expands to 256 `<\|image\|>` | + +Axolotl autodetects the placeholder from the loaded processor. If autodetection +fails, supply `image_token: ` on the dataset entry. + +### YAML example + +```yaml +base_model: HuggingFaceTB/SmolVLM-500M-Instruct +processor_type: AutoProcessor + +pretraining_dataset: + - path: /path/to/shards/*.jsonl + ds_type: json + type: multimodal_pretrain + text_column: text + image_column: images + image_base_dir: /path/to/images # optional, for relative paths + # image_token: "" # optional override; autodetect by default + +streaming: true +sequence_len: 2048 +sample_packing: false # REQUIRED — see below +remove_unused_columns: false # auto-set by validator + +max_steps: 10000 +micro_batch_size: 1 +gradient_accumulation_steps: 8 +``` + +### Eval datasets + +`test_datasets` accepts multimodal entries (`type: multimodal_pretrain` or +`multimodal: true`). Per-entry `text_column` and `image_column` are honored +independently. When more than one multimodal entry is provided, +`image_base_dir` and `image_token` must be either unset on every entry or +identical across them, because the eval collator resolves both once for the +merged eval stream. + +### Mixed image / text-only batches + +Rows with no images are allowed. All-text batches bypass the processor and +tokenize via the base tokenizer (no `pixel_values`). Mixed batches are +handed to the processor as-is. + +Verified VLM families: + +- HuggingFace `SmolVLM` / `SmolVLM2` +- Google `Gemma-3` (4B), `Gemma-4` (E2B, 31B) +- Alibaba `Qwen2.5-VL`, `Qwen3-VL` + +For other families (LLaVA, Qwen2-VL, Idefics3), the collator's per-row +retry isolates and names the offending row on processor failure — failures +are loud, not silent. Smoke-test before long runs, or pre-filter rows +without images. + +### Gates and rejections + +The following combinations are rejected at config-load time with a clear error: + +- `sample_packing: true` — cross-row packing would break the 1-to-1 alignment + between text placeholders and `pixel_values`. +- `chat_template` set to anything — defeats the purpose of the CPT path. +- `processor_type` unset — no processor means no image tensors. +- Multiple MM `test_datasets` entries with mismatched `image_base_dir` or `image_token`. + +In addition, the following model families are **not supported** in v1 and will +be rejected when their processor is loaded: + +- **Llama-3.2-Vision (Mllama)** — uses cross-attention image injection, not + in-stream placeholders. Use chat-template SFT. +- **Pixtral** — requires `mistral_common` and a different API. +- **InternVL** — ships a custom processor that doesn't produce `pixel_values`. + +Per-row validation: at encode time the row's text is tokenized once and the +number of `image_token_id` occurrences in the resulting token-id list must +equal `len(images)`. Counting by token id (not by substring) avoids false +matches — e.g., `` would substring-match inside ``. +This is a critical guardrail — LLaVA and Qwen-VL processors silently +accept rows without placeholders and drop the image, which looks like +successful training but teaches nothing. If a row fails this check, +inspect the tokenized ids rather than the raw string. + +### Why masking image tokens in labels is automatic + +For this multimodal CPT path, the collator masks every image-family token id +(``, `<\|image_pad\|>`, `<\|vision_start\|>`, `<\|vision_end\|>`, +``, ``, ``, `<\|image\|>`, +etc.) to `-100` in the labels tensor. + +Supported processors expand image placeholders into vision-specific token ids. +If those ids contribute to loss, the model is trained to predict patch or +marker tokens rather than only the target text. On architectures like +Qwen-VL and SmolVLM, leaving them unmasked substantially increases loss and can +destabilize training. + ## FAQ 1. `PIL.UnidentifiedImageError: cannot identify image file ...` diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 93d751f8ba..51cb28bd3a 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -44,6 +44,7 @@ V2BatchSamplerDataCollatorForSeq2Seq, ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator +from axolotl.utils.collators.mm_pretrain import MultiModalPretrainDataCollator from axolotl.utils.import_helper import get_cls_from_module_str from axolotl.utils.logging import get_logger @@ -67,6 +68,27 @@ def _warn_if_num_workers_zero_for_mm(cfg, log) -> None: ) +def _is_multimodal_cpt(cfg) -> bool: + if not getattr(cfg, "pretraining_dataset", None): + return False + ds_first = cfg.pretraining_dataset[0] + ds_type = None + mm_flag = None + if hasattr(ds_first, "type"): + ds_type = getattr(ds_first, "type", None) + mm_flag = getattr(ds_first, "multimodal", None) + elif isinstance(ds_first, dict): + ds_type = ds_first.get("type") + mm_flag = ds_first.get("multimodal") + return (ds_type == "multimodal_pretrain") or bool(mm_flag) + + +def _mm_cpt_get(pt_cfg, key, default=None): + if isinstance(pt_cfg, dict): + return pt_cfg.get(key, default) + return getattr(pt_cfg, key, default) + + class HFCausalTrainerBuilder(TrainerBuilderBase): """ Build the HuggingFace training args/trainer for causal models and reward modeling @@ -464,6 +486,36 @@ def build(self, total_num_steps): return trainer + def _build_mm_pretrain_collator(self, pad_to_multiple_of=None, is_eval=False): + from axolotl.prompt_strategies.multimodal_pretrain import ( + build_image_token_spec, + ) + + if is_eval and self.cfg.test_datasets: + pt_cfg = self.cfg.test_datasets[0] + elif self.cfg.pretraining_dataset: + pt_cfg = self.cfg.pretraining_dataset[0] + else: + pt_cfg = {} + spec = build_image_token_spec( + self.processor, override=_mm_cpt_get(pt_cfg, "image_token") + ) + max_length = ( + self.cfg.eval_sequence_len + if is_eval and self.cfg.eval_sequence_len + else self.cfg.sequence_len + ) + collator_kwargs = { + "tokenizer": self.tokenizer, + "processor": self.processor, + "image_token_spec": spec, + "image_base_dir": _mm_cpt_get(pt_cfg, "image_base_dir"), + "max_length": max_length, + } + if pad_to_multiple_of is not None: + collator_kwargs["pad_to_multiple_of"] = pad_to_multiple_of + return MultiModalPretrainDataCollator(**collator_kwargs) + def build_collator( self, training_args, # type: "AxolotlTrainingArguments" # type: ignore @@ -471,6 +523,16 @@ def build_collator( **kwargs, ): if training_args.pretraining: + # Intercept MM CPT before the text-only pretraining branches. + if ( + self.cfg.processor_type + and self.processor + and _is_multimodal_cpt(self.cfg) + ): + return self._build_mm_pretrain_collator( + pad_to_multiple_of=kwargs.get("pad_to_multiple_of"), + is_eval=is_eval, + ) if ( self.cfg.pretraining_sample_concatenation is False or self.cfg.micro_batch_size > 1 @@ -532,6 +594,15 @@ def build_collator( else: collator = BatchSamplerDataCollatorForSeq2Seq else: + if ( + self.cfg.processor_type + and self.processor + and _is_multimodal_cpt(self.cfg) + ): + return self._build_mm_pretrain_collator( + pad_to_multiple_of=kwargs.get("pad_to_multiple_of"), + is_eval=is_eval, + ) if self.cfg.processor_type and self.processor: collator = MultiModalChatDataCollator # Mirror ChatTemplateStrategy: per-dataset masking knobs from first MM dataset, else global cfg. diff --git a/src/axolotl/prompt_strategies/multimodal_pretrain.py b/src/axolotl/prompt_strategies/multimodal_pretrain.py new file mode 100644 index 0000000000..c0bcb6bb50 --- /dev/null +++ b/src/axolotl/prompt_strategies/multimodal_pretrain.py @@ -0,0 +1,214 @@ +"""Multimodal CPT helpers (image-token autodetection + processor compat). + +Only the streaming `pretraining_dataset` route is wired in v1; the +non-streaming `datasets:` route (strategy class + `load()`) is deferred to a +follow-on PR that also wires `build_collator` to route MM CPT batches outside +the `training_args.pretraining` branch. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from transformers import ProcessorMixin + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def load(*_args, **_kwargs): + raise ValueError( + "multimodal_pretrain is only supported via pretraining_dataset " + "with streaming: true — see docs/multimodal.qmd" + ) + + +def _get_incompatible_processor_classes() -> tuple[type, ...]: + import importlib + + classes: list[type] = [] + for mod_path, name in ( + ("transformers.models.mllama", "MllamaProcessor"), + ("transformers.models.pixtral", "PixtralProcessor"), + ("transformers.models.internvl", "InternVLProcessor"), + ): + try: + mod = importlib.import_module(mod_path) + cls = getattr(mod, name, None) + if cls is not None: + classes.append(cls) + except ImportError: + continue + return tuple(classes) + + +_KNOWN_IMAGE_TOKEN_CANDIDATES: tuple[str, ...] = ( + "", + "<|image|>", + "<|image_pad|>", + "", + "", + "[IMG]", + "", +) + +# Without masking these in labels, loss blows up ~10x on Qwen/SmolVLM. +_IMAGE_FAMILY_TOKEN_CANDIDATES: tuple[str, ...] = ( + "", + "<|image|>", + "<|image_pad|>", + "", + "", + "", + "<|vision_start|>", + "<|vision_end|>", + "[IMG]", + "[IMG_END]", + "", +) + +_INCOMPATIBLE_PROCESSOR_REASONS: dict[str, str] = { + "MllamaProcessor": ( + "Llama-3.2-Vision (Mllama) uses cross-attention image injection, not " + "in-stream placeholder tokens. Multimodal CPT is incompatible with " + "this architecture; use chat-template SFT instead." + ), + "PixtralProcessor": ( + "Pixtral's tokenizer goes through mistral_common with a different " + "API surface than AutoProcessor. Multimodal CPT not supported in v1; " + "use chat-template SFT or Mistral-Small-3.1." + ), + "InternVLProcessor": ( + "InternVL ships a custom processing pipeline (AutoProcessor returns " + "text-only); no pixel_values are produced. Multimodal CPT not " + "supported in v1." + ), +} +_INCOMPATIBLE_PROCESSOR_CLASSES = _get_incompatible_processor_classes() + + +@dataclass +class ImageTokenSpec: + image_token: str + image_token_id: int + image_family_token_ids: set[int] + + +def build_image_token_spec( + processor: ProcessorMixin, override: str | None = None +) -> ImageTokenSpec: + tokenizer = getattr(processor, "tokenizer", None) + if tokenizer is None: + raise ValueError( + "Processor has no `tokenizer` attribute — multimodal CPT " + "requires a processor with a text tokenizer (e.g. one produced " + "by AutoProcessor.from_pretrained for a VLM)." + ) + + def resolve_id(tok: str) -> int | None: + tid = tokenizer.convert_tokens_to_ids(tok) + unk = getattr(tokenizer, "unk_token_id", None) + if tid is None or tid == unk: + return None + return tid + + known_special_tokens: set[str] = set() + try: + known_special_tokens |= set(tokenizer.get_added_vocab().keys()) + except Exception as exc: # noqa: BLE001 + LOG.debug( + "tokenizer.get_added_vocab() failed on %s: %s", + type(tokenizer).__name__, + exc, + ) + known_special_tokens |= set(getattr(tokenizer, "all_special_tokens", None) or []) + known_special_tokens |= set( + getattr(tokenizer, "additional_special_tokens", None) or [] + ) + + image_token: str | None = None + image_token_id: int | None = None + if override is not None: + # Reject plain words that BPE-tokenize cleanly but aren't placeholders. + if override not in known_special_tokens: + raise ValueError( + f"image_token override {override!r} is not a registered " + f"special token on this tokenizer. Pick one of the model's " + f"actual image tokens (e.g. '', '<|image_pad|>', " + f"''), or leave unset to autodetect." + ) + image_token_id = resolve_id(override) + if image_token_id is None: + raise ValueError( + f"image_token override {override!r} did not resolve to a " + f"token id (unk). Remove the override to autodetect." + ) + image_token = override + else: + proc_token = getattr(processor, "image_token", None) + # Gemma-3-style only: `image_token` is the post-expansion soft token + # (its name literally contains "soft_token"); the user-facing + # placeholder is `boi_token`. Gemma-4 reverses this — `image_token` + # IS the user-facing placeholder (`<|image|>`) and `boi_token` + # (`<|image>`) is just a bracket marker, so don't blindly swap. + boi_token = getattr(processor, "boi_token", None) + if ( + boi_token + and proc_token + and boi_token != proc_token + and boi_token in known_special_tokens + and "soft_token" in proc_token + ): + proc_token = boi_token + if proc_token is not None: + image_token_id = resolve_id(proc_token) + if image_token_id is not None: + image_token = proc_token + if image_token is None: + for cand in _KNOWN_IMAGE_TOKEN_CANDIDATES: + tid = resolve_id(cand) + if tid is not None: + image_token = cand + image_token_id = tid + break + if image_token is None: + raise ValueError( + "Could not autodetect the image placeholder token for this " + "processor. Set `image_token: ` in the dataset config " + "(e.g. '' for LLaVA, '<|image_pad|>' for Qwen-VL, " + "'' for Gemma-3)." + ) + + # Filter to registered tokens so BPE-fallback ids don't get masked. + family: set[int] = {image_token_id} # type: ignore[arg-type] + for cand in _IMAGE_FAMILY_TOKEN_CANDIDATES: + if cand != image_token and cand not in known_special_tokens: + continue + tid = resolve_id(cand) + if tid is not None: + family.add(tid) + return ImageTokenSpec( + image_token=image_token, + image_token_id=image_token_id, # type: ignore[arg-type] + image_family_token_ids=family, + ) + + +def check_processor_compatibility(processor: ProcessorMixin) -> None: + if _INCOMPATIBLE_PROCESSOR_CLASSES and isinstance( + processor, _INCOMPATIBLE_PROCESSOR_CLASSES + ): + for cls in _INCOMPATIBLE_PROCESSOR_CLASSES: + if isinstance(processor, cls): + raise ValueError( + f"Multimodal CPT is not supported for {cls.__name__}: " + f"{_INCOMPATIBLE_PROCESSOR_REASONS.get(cls.__name__, '')}" + ) + # MRO-name fallback for test fakes and unimportable concrete classes. + for base_cls in type(processor).__mro__: + reason = _INCOMPATIBLE_PROCESSOR_REASONS.get(base_cls.__name__) + if reason is not None: + raise ValueError( + f"Multimodal CPT is not supported for {base_cls.__name__}: {reason}" + ) diff --git a/src/axolotl/utils/collators/mm_pretrain.py b/src/axolotl/utils/collators/mm_pretrain.py new file mode 100644 index 0000000000..efa0319c71 --- /dev/null +++ b/src/axolotl/utils/collators/mm_pretrain.py @@ -0,0 +1,257 @@ +"""Collator for multimodal CPT.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass, field +from typing import Any, Literal, Optional, Union + +from PIL import Image +from torch import Tensor +from transformers import PreTrainedTokenizerBase, ProcessorMixin +from transformers.data.data_collator import DataCollatorMixin +from transformers.image_utils import load_image +from transformers.utils import PaddingStrategy + +from axolotl.prompt_strategies.multimodal_pretrain import ( + ImageTokenSpec, + check_processor_compatibility, +) +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +@dataclass +class MultiModalPretrainDataCollator(DataCollatorMixin): + tokenizer: PreTrainedTokenizerBase + processor: ProcessorMixin + image_token_spec: ImageTokenSpec + image_base_dir: Optional[str] = None + return_tensors: Literal["pt"] = "pt" + padding: Union[bool, str, PaddingStrategy] = True + pad_to_multiple_of: Optional[int] = None + max_length: Optional[int] = None + skip_bad_images: bool = False + add_eos_token: bool = True + + _image_family_token_ids: set[int] = field(init=False, default_factory=set) + + def __post_init__(self) -> None: + if self.return_tensors != "pt": + raise ValueError( + "MultiModalPretrainDataCollator only supports " + "return_tensors='pt' (in-place torch ops are used downstream)." + ) + check_processor_compatibility(self.processor) + # All-text batches use self.tokenizer; image batches use self.processor. + # If they don't share the same tokenizer instance, the two paths can + # tokenize the same text differently. + proc_tokenizer = getattr(self.processor, "tokenizer", None) + if proc_tokenizer is not None and proc_tokenizer is not self.tokenizer: + LOG.warning( + "MultiModalPretrainDataCollator.tokenizer is not " + "processor.tokenizer; all-text and image batches may " + "tokenize inconsistently." + ) + self._image_family_token_ids = set(self.image_token_spec.image_family_token_ids) + + def _resolve_image_source(self, src: Any) -> Any: + # Only join base_dir for relative string paths; pass everything else + # (PIL images, URLs, base64, absolute paths) through to load_image. + if ( + self.image_base_dir + and isinstance(src, str) + and not os.path.isabs(src) + and "://" not in src + ): + return os.path.join(self.image_base_dir, src) + return src + + def _load_images_for_row(self, sources: list, row_index: int) -> list[Image.Image]: + out: list[Image.Image] = [] + for raw in sources: + try: + img = load_image(self._resolve_image_source(raw)) + except Exception as exc: + label = ( + os.path.basename(raw) + if isinstance(raw, str) + else type(raw).__name__ + ) + msg = ( + f"Row {row_index}: failed to load image {label!r} " + f"({type(exc).__name__})" + ) + LOG.debug("failed image full source: %r; error: %s", raw, exc) + if self.skip_bad_images: + LOG.warning("%s — skipping", msg) + continue + raise RuntimeError(msg) from exc + out.append(img) + return out + + def torch_call(self, examples: list[dict]) -> dict[str, Any]: + if not examples: + raise ValueError("Empty batch passed to MultiModalPretrainDataCollator.") + + texts: list[str] = [] + images: list[list[Image.Image]] = [] + for i, ex in enumerate(examples): + if "_mm_text" not in ex or "images" not in ex: + raise KeyError( + f"MultiModalPretrainDataCollator: row {i} is missing " + f"'_mm_text' or 'images'. Did you wire the multimodal CPT " + f"encoder (encode_streaming_multimodal)?" + ) + mm_text = ex["_mm_text"] + if not isinstance(mm_text, str): + raise TypeError( + f"Row {i}: `_mm_text` must be str, got " + f"{type(mm_text).__name__}. Check dataset encoding " + f"(Parquet BINARY columns may surface as bytes)." + ) + raw = ex["images"] + if raw is None: + raw_sources: list = [] + elif isinstance(raw, (list, tuple)): + raw_sources = list(raw) + else: + raise TypeError( + f"Row {i}: `images` must be a list (or None), got " + f"{type(raw).__name__}." + ) + # Processor re-tokenizes below, discarding the encoder's EOS — re-append. + if self.add_eos_token and self.tokenizer.eos_token: + if not mm_text.endswith(self.tokenizer.eos_token): + mm_text = mm_text + self.tokenizer.eos_token + texts.append(mm_text) + loaded = self._load_images_for_row(raw_sources, row_index=i) + if self.skip_bad_images and len(loaded) != len(raw_sources): + # Drop the row to avoid silent placeholder/image count mismatch. + LOG.warning( + "Row %d: %d/%d images failed to load; dropping row.", + i, + len(raw_sources) - len(loaded), + len(raw_sources), + ) + texts.pop() + continue + images.append(loaded) + + if not texts: + raise RuntimeError( + "All rows in the batch were dropped due to image load " + "failures. Check dataset integrity." + ) + + # All-text batch: bypass the processor and tokenize directly. + if all(len(im) == 0 for im in images): + LOG.debug( + "MultiModalPretrainDataCollator: all-text batch (%d rows); " + "using tokenizer-only fallback (no pixel_values).", + len(texts), + ) + tok_kwargs: dict[str, Any] = { + "text": texts, + "return_tensors": self.return_tensors, + "padding": self.padding, + } + if self.pad_to_multiple_of is not None: + tok_kwargs["pad_to_multiple_of"] = self.pad_to_multiple_of + batch = self.tokenizer(**tok_kwargs) + tok_input_ids: Tensor = batch["input_ids"] + tok_labels = tok_input_ids.clone() + pad_id = getattr(self.tokenizer, "pad_token_id", None) + if pad_id is not None: + tok_labels[tok_labels == pad_id] = -100 + for tid in self._image_family_token_ids: + tok_labels[tok_labels == tid] = -100 + batch["labels"] = tok_labels + return batch + + # No truncation: it chops input_ids mid-placeholder while pixel_values + # keep every image — silent text/pixel mismatch. We warn post-hoc instead. + proc_kwargs: dict[str, Any] = { + "text": texts, + "images": images, + "return_tensors": self.return_tensors, + "padding": self.padding, + } + if self.pad_to_multiple_of is not None: + proc_kwargs["pad_to_multiple_of"] = self.pad_to_multiple_of + try: + batch = self.processor(**proc_kwargs) + except Exception as exc: + # Pinpoint the bad row; bail to "inconclusive" if retry raises a different class. + LOG.warning( + "MultiModalPretrainDataCollator: processor failed on a batch " + "of %d rows (%s); retrying each row individually to locate " + "the offender. This adds up to %d extra processor calls.", + len(texts), + type(exc).__name__, + len(texts), + ) + offender_idx: Optional[int] = None + retry_ok = True + retry_kwargs: dict[str, Any] = { + "return_tensors": self.return_tensors, + "padding": self.padding, + } + if self.pad_to_multiple_of is not None: + retry_kwargs["pad_to_multiple_of"] = self.pad_to_multiple_of + for i, (t, imgs) in enumerate(zip(texts, images, strict=True)): + try: + if len(imgs) == 0: + # Some processors reject `images=[[]]` — would mislabel + # text-only rows as the offender. + self.tokenizer(text=[t], **retry_kwargs) + else: + self.processor(text=[t], images=[imgs], **retry_kwargs) + except Exception as retry_exc: + if isinstance(retry_exc, type(exc)) or isinstance( + exc, type(retry_exc) + ): + offender_idx = i + else: + retry_ok = False + break + if offender_idx is not None: + location = f"row {offender_idx}" + elif retry_ok: + location = ( + f"batch of {len(texts)} rows " + f"(individual rows all succeed; see __cause__ for details)" + ) + else: + location = f"batch of {len(texts)} rows (retry inconclusive)" + raise RuntimeError( + f"MultiModalPretrainDataCollator: processor call failed on " + f"{location} ({type(exc).__name__}: {exc}). Common causes: " + f"placeholder token absent from the row's text, image count " + f"mismatch, or an unsupported processor class." + ) from exc + + input_ids_len = batch["input_ids"].shape[-1] + if self.max_length is not None and input_ids_len > self.max_length: + LOG.warning( + "Batch input_ids length %d exceeds configured sequence_len %d " + "(image placeholder expansion). Raise sequence_len if this " + "fires repeatedly.", + input_ids_len, + self.max_length, + ) + + input_ids: Tensor = batch["input_ids"] + labels = input_ids.clone() + + pad_id = getattr(self.tokenizer, "pad_token_id", None) + if pad_id is not None: + labels[labels == pad_id] = -100 + + # Without this, image-family ids dominate loss and blow it up ~10x. + for tid in self._image_family_token_ids: + labels[labels == tid] = -100 + + batch["labels"] = labels + return batch diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 0b2ec2b5fb..51d9db9d28 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -10,6 +10,7 @@ DatasetDict, IterableDataset, IterableDatasetDict, + concatenate_datasets, load_dataset, ) from transformers import PreTrainedTokenizer, ProcessorMixin @@ -134,7 +135,9 @@ def _prepare_streaming_dataset( """ if cfg.pretraining_dataset: dataset_config = _extract_pretraining_config(cfg) - train_dataset = _load_streaming_dataset(dataset_config, cfg, tokenizer) + train_dataset = _load_streaming_dataset( + dataset_config, cfg, tokenizer, processor=processor + ) elif cfg.sample_packing: # TODO(djsaunde): Implement for multiple datasets dataset_config = DictDefault(cfg.datasets[0]) @@ -142,7 +145,9 @@ def _prepare_streaming_dataset( # Ensure we have a split set - default to 'train' if not specified if not hasattr(dataset_config, "split") or not dataset_config.split: dataset_config.split = "train" - train_dataset = _load_streaming_dataset(dataset_config, cfg, tokenizer) + train_dataset = _load_streaming_dataset( + dataset_config, cfg, tokenizer, processor=processor + ) else: # Use legacy loading function for non-packed streaming datasets train_dataset, eval_dataset, prompters = _load_and_prepare_datasets( @@ -160,35 +165,69 @@ def _prepare_streaming_dataset( # Load evaluation dataset if specified eval_dataset = None if cfg.test_datasets: - _, eval_dataset, _ = _load_and_prepare_datasets( - tokenizer, - cfg, - split="test", - processor=processor, - streaming=False, + test_dicts = [t if isinstance(t, dict) else dict(t) for t in cfg.test_datasets] + is_mm_cpt_eval = any( + t.get("type") == "multimodal_pretrain" or bool(t.get("multimodal")) + for t in test_dicts ) + if is_mm_cpt_eval: + # Modality homogeneity is enforced by check_multimodal_cpt at config + # parse time; every entry here is guaranteed to be MM. + eval_streams = [ + _load_streaming_dataset( + _pretraining_config_from_entry(entry), + cfg, + tokenizer, + processor=processor, + is_eval=True, + ) + for entry in test_dicts + ] + eval_dataset = ( + eval_streams[0] + if len(eval_streams) == 1 + else concatenate_datasets(eval_streams) + ) + else: + _, eval_dataset, _ = _load_and_prepare_datasets( + tokenizer, + cfg, + split="test", + processor=processor, + streaming=False, + ) # For streaming, we return max_steps directly from config or -1 if not set total_num_steps = cfg.max_steps if cfg.max_steps else -1 return train_dataset, eval_dataset, total_num_steps, [] +def _pretraining_config_from_entry(entry: dict) -> DictDefault: + return DictDefault( + { + "path": entry["path"], + "name": entry.get("name"), + "skip": entry.get("skip"), + "split": entry.get("split", "train"), + "data_files": entry.get("data_files"), + "ds_type": entry.get("ds_type"), + "type": entry.get("type", "pretrain"), + "text_column": entry.get("text_column", "text"), + "multimodal": entry.get("multimodal"), + "image_column": entry.get("image_column", "images"), + "image_base_dir": entry.get("image_base_dir"), + "image_token": entry.get("image_token"), + "trust_remote_code": entry.get("trust_remote_code", False), + } + ) + + def _extract_pretraining_config(cfg: DictDefault) -> DictDefault: """Extract pretraining configuration from the main config.""" if isinstance(cfg.pretraining_dataset, list) and isinstance( cfg.pretraining_dataset[0], dict ): - config = cfg.pretraining_dataset[0] - return DictDefault( - { - "path": config["path"], - "name": config["name"], - "skip": config["skip"], - "split": config.get("split", "train"), - "data_files": config.get("data_files"), - "type": config.get("type", "pretrain"), - } - ) + return _pretraining_config_from_entry(cfg.pretraining_dataset[0]) # Simple string path case return DictDefault( { @@ -197,13 +236,24 @@ def _extract_pretraining_config(cfg: DictDefault) -> DictDefault: "skip": 0, "split": "train", "data_files": None, + "ds_type": None, "type": "pretrain", + "text_column": "text", + "multimodal": None, + "image_column": "images", + "image_base_dir": None, + "image_token": None, # nosec + "trust_remote_code": False, } ) def _load_streaming_dataset( - pretraining_config: DictDefault, cfg: DictDefault, tokenizer: PreTrainedTokenizer + pretraining_config: DictDefault, + cfg: DictDefault, + tokenizer: PreTrainedTokenizer, + processor: ProcessorMixin | None = None, + is_eval: bool = False, ) -> IterableDataset: """Load and prepare a streaming dataset for pretraining.""" # Create dataset wrapper partial function @@ -213,6 +263,7 @@ def _load_streaming_dataset( tokenizer=tokenizer, cfg=cfg, dataset_base_type=pretraining_config["type"], + processor=processor, ) # Load the actual dataset @@ -221,15 +272,30 @@ def _load_streaming_dataset( and cfg.accelerator_config.dispatch_batches and not is_local_main_process() ): - iter_dataset = _create_placeholder_dataset() + iter_dataset = _create_placeholder_dataset(pretraining_config) else: - iter_dataset = load_dataset( - pretraining_config["path"], - streaming=True, - split=pretraining_config["split"], - name=pretraining_config["name"], - data_files=pretraining_config["data_files"], - ) + ds_type = pretraining_config.get("ds_type") + if ds_type: + # ds_type names the loader (e.g. 'json'); path is the data_files glob. + iter_dataset = load_dataset( + ds_type, + streaming=True, + split=pretraining_config["split"], + name=pretraining_config["name"], + data_files=( + pretraining_config["data_files"] or pretraining_config["path"] + ), + trust_remote_code=pretraining_config.get("trust_remote_code", False), + ) + else: + iter_dataset = load_dataset( + pretraining_config["path"], + streaming=True, + split=pretraining_config["split"], + name=pretraining_config["name"], + data_files=pretraining_config["data_files"], + trust_remote_code=pretraining_config.get("trust_remote_code", False), + ) # Apply skip if specified if pretraining_config["skip"]: @@ -242,19 +308,40 @@ def _load_streaming_dataset( tokenizer, cfg, dataset_wrapper_partial, + processor=processor, + pretraining_config=pretraining_config, + is_eval=is_eval, ) # Format for PyTorch return train_dataset.with_format("torch") -def _create_placeholder_dataset() -> IterableDataset: +def _create_placeholder_dataset( + pretraining_config: DictDefault | None = None, +) -> IterableDataset: """Create a minimal placeholder dataset for non-main processes.""" - with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f: - f.write("text\n") - f.write("lorem ipsum dolor sit amet\n") - f.seek(0) - return load_dataset("csv", data_files=f.name, split="train", streaming=True) + text_column = "text" + image_column: str | None = None + if pretraining_config is not None: + text_column = pretraining_config.get("text_column") or "text" + is_mm = pretraining_config.get("type") == "multimodal_pretrain" or bool( + pretraining_config.get("multimodal") + ) + if is_mm: + image_column = pretraining_config.get("image_column") or "images" + + if image_column is None: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f: + f.write(f"{text_column}\n") + f.write("lorem ipsum dolor sit amet\n") + f.seek(0) + return load_dataset("csv", data_files=f.name, split="train", streaming=True) + + def _gen(): + yield {text_column: "lorem ipsum dolor sit amet", image_column: []} + + return IterableDataset.from_generator(_gen) def _load_tokenized_prepared_datasets( diff --git a/src/axolotl/utils/data/streaming.py b/src/axolotl/utils/data/streaming.py index 8b6b8a439b..c1e1d3be63 100644 --- a/src/axolotl/utils/data/streaming.py +++ b/src/axolotl/utils/data/streaming.py @@ -7,7 +7,7 @@ import torch from datasets import Dataset from torch.utils.data import RandomSampler -from transformers import PreTrainedTokenizerBase +from transformers import PreTrainedTokenizerBase, ProcessorMixin from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq from axolotl.utils.logging import get_logger @@ -176,12 +176,101 @@ def encode_streaming( return ret +def encode_streaming_multimodal( + examples: Dict[str, List], + tokenizer: PreTrainedTokenizerBase, + max_tokens: int, + image_token: str, + image_token_id: int, + text_column: str = "text", + image_column: str = "images", +) -> Dict[str, List]: + texts: List[str] = examples[text_column] + imgs_list: List[List[str]] = examples[image_column] + + if len(texts) != len(imgs_list): + raise ValueError( + f"encode_streaming_multimodal: text column has {len(texts)} rows " + f"but image column has {len(imgs_list)}" + ) + + input_ids: List[List[int]] = [] + labels: List[List[int]] = [] + attention_mask: List[List[int]] = [] + keep_images: List[List[str]] = [] + keep_text: List[str] = [] + + for text, imgs in zip(texts, imgs_list, strict=True): + if not isinstance(text, str): + raise TypeError( + f"encode_streaming_multimodal: `{text_column}` must be str, " + f"got {type(text).__name__}." + ) + if imgs is None: + imgs = [] + if not isinstance(imgs, (list, tuple)): + raise ValueError( + f"encode_streaming_multimodal: row's `{image_column}` must be " + f"a list; got {type(imgs).__name__}" + ) + for j, ip in enumerate(imgs): + if not isinstance(ip, str): + raise TypeError( + f"encode_streaming_multimodal: image {j} in row must be " + f"str, got {type(ip).__name__}." + ) + # No truncation: counting on truncated ids and storing untruncated text + # (which the collator re-tokenizes without truncation) silently produces + # oversize batches and confusing placeholder/image-count mismatches. + enc = tokenizer(text, add_special_tokens=True) + ids = list(enc["input_ids"]) + [tokenizer.eos_token_id] + mask = list(enc["attention_mask"]) + [1] + # Count by id — `text.count` substring-matches `` in ``. + n_placeholders = sum(1 for t in ids if t == image_token_id) + if n_placeholders != len(imgs): + raise ValueError( + f"Multimodal CPT row has {n_placeholders} occurrence(s) of " + f"{image_token!r} in text but {len(imgs)} image path(s). " + f"Text and image count must match (one placeholder per image)." + ) + if len(ids) > max_tokens: + raise ValueError( + f"Multimodal CPT row tokenizes to {len(ids)} tokens which " + f"exceeds sequence_len={max_tokens}. Pre-chunk your text or " + f"raise sequence_len (image patch expansion at the processor " + f"may push the final length even higher)." + ) + # Labels = ids; collator masks image-family ids after re-tokenization. + input_ids.append(ids) + labels.append(list(ids)) + attention_mask.append(mask) + keep_images.append(list(imgs)) + keep_text.append(text) + + return { + "input_ids": input_ids, + "labels": labels, + "attention_mask": attention_mask, + "images": keep_images, + "_mm_text": keep_text, + } + + def wrap_streaming_dataset( dataset, tokenizer, cfg, ds_wrapper_fn, + processor: Optional[ProcessorMixin] = None, + pretraining_config=None, + is_eval: bool = False, ): + # Eval streams honor cfg.eval_sequence_len when set, else cfg.sequence_len. + effective_seq_len = ( + cfg.eval_sequence_len + if is_eval and getattr(cfg, "eval_sequence_len", None) + else cfg.sequence_len + ) if cfg.sample_packing: # For SFT (non-pretraining) datasets, always use multipack_attn=True to ensure # attention isolation between packed sequences @@ -213,17 +302,66 @@ def wrap_streaming_dataset( # NOTE: This is not reachable for SFT datasets since we use the pre-existing # loading function for non-packed streaming datasets. Refer to # _prepare_streaming_datasets in sft.py for that code path. - text_column = ( - getattr(cfg.pretraining_dataset[0], "text_column", "text") or "text" + # Prefer the resolved per-entry config so eval (test_datasets) doesn't + # silently inherit the training entry's columns/image_token. + if pretraining_config is not None: + ds_first = pretraining_config + elif cfg.pretraining_dataset: + ds_first = cfg.pretraining_dataset[0] + else: + ds_first = {} + # Plain dicts need `.get`; pydantic/DictDefault need `getattr`. + get_ds_value = ( + ds_first.get + if isinstance(ds_first, dict) + else lambda key, default=None: getattr(ds_first, key, default) ) - encode = functools.partial( - encode_streaming, - tokenizer=tokenizer, - max_tokens=cfg.sequence_len, - text_column=text_column, - concatenate=cfg.pretraining_sample_concatenation is True, + text_column = get_ds_value("text_column", "text") or "text" + ds_type = (get_ds_value("type", None) or "").strip() + is_mm_cpt = ds_type == "multimodal_pretrain" or bool( + get_ds_value("multimodal", False) ) + if is_mm_cpt: + if processor is None: + raise ValueError( + "Multimodal CPT (type: multimodal_pretrain) requires a " + "processor. Set `processor_type: AutoProcessor` (or the " + "concrete processor class) in your config." + ) + from axolotl.prompt_strategies.multimodal_pretrain import ( + build_image_token_spec, + check_processor_compatibility, + ) + + check_processor_compatibility(processor) + spec = build_image_token_spec( + processor, + override=get_ds_value("image_token", None), + ) + image_column = get_ds_value("image_column", None) or "images" + LOG.info( + f"multimodal streaming CPT: placeholder={spec.image_token!r} " + f"(id={spec.image_token_id})" + ) + encode = functools.partial( + encode_streaming_multimodal, + tokenizer=tokenizer, + max_tokens=effective_seq_len, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + text_column=text_column, + image_column=image_column, + ) + else: + encode = functools.partial( + encode_streaming, + tokenizer=tokenizer, + max_tokens=effective_seq_len, + text_column=text_column, + concatenate=cfg.pretraining_sample_concatenation is True, + ) + if cfg.shuffle_merged_datasets: dataset = dataset.shuffle( seed=cfg.seed, buffer_size=cfg.streaming_multipack_buffer_size diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index dbd73b66f6..b84284e1f2 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -22,6 +22,7 @@ DatasetConfig, DPODataset, KTODataset, + MultiModalEvalDataset, PretrainingDataset, SFTDataset, StepwiseSupervisedDataset, @@ -353,7 +354,8 @@ class AxolotlInputConfig( test_datasets: ( Annotated[ list[ - SFTDataset + MultiModalEvalDataset + | SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset diff --git a/src/axolotl/utils/schemas/datasets.py b/src/axolotl/utils/schemas/datasets.py index 97ed71631d..367133ac38 100644 --- a/src/axolotl/utils/schemas/datasets.py +++ b/src/axolotl/utils/schemas/datasets.py @@ -232,12 +232,70 @@ class PretrainingDataset(BaseModel): name: str | None = None path: str | None = None split: str | None = "train" - text_column: str | None = "text" + text_column: str | None = Field( + default="text", + json_schema_extra={"description": "Column holding the row's text."}, + ) type: str | None = "pretrain" trust_remote_code: bool | None = False - data_files: str | None = None + data_files: str | list[str] | None = None + ds_type: str | None = Field( + default=None, + json_schema_extra={ + "description": "Dataset loader type when `path` points to local files (e.g. 'json', 'csv', 'parquet')." + }, + ) skip: int | None = None + # Multimodal CPT fields. Opt-in via `type: multimodal_pretrain` or `multimodal: true`. + multimodal: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Opt in to multimodal CPT. Auto-enabled when type='multimodal_pretrain'." + }, + ) + image_column: str | None = Field( + default="images", + json_schema_extra={ + "description": "Column holding a list of image paths per row." + }, + ) + image_base_dir: str | None = Field( + default=None, + json_schema_extra={"description": "Base directory for relative image paths."}, + ) + image_token: str | None = Field( + default=None, + json_schema_extra={ + "description": "Override the image placeholder token (autodetected from processor if unset)." + }, + ) + + +class MultiModalEvalDataset(PretrainingDataset): + """Multimodal CPT eval dataset configuration (test_datasets entry). + + Inherits all fields from :class:`PretrainingDataset` and only overrides + the `type` default (no `"pretrain"` fallback for eval entries) plus adds + a validator that requires an explicit MM marker. + """ + + type: str | None = None + + @model_validator(mode="before") + @classmethod + def _require_mm_markers(cls, data): + if isinstance(data, BaseModel): + data = data.model_dump() + if not isinstance(data, dict): + return data + if data.get("type") != "multimodal_pretrain" and not data.get("multimodal"): + raise ValueError( + "MultiModalEvalDataset requires type='multimodal_pretrain' " + "or multimodal=True" + ) + return data + class UserDefinedDPOType(BaseModel): """User defined typing for DPO""" diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 181ddaecbc..9b1ecc9409 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1383,6 +1383,117 @@ def check_streaming_w_multiple_datasets(cls, data): ) return data + @model_validator(mode="before") + @classmethod + def check_multimodal_cpt(cls, data): + def _entry_is_mm(entry) -> bool: + if isinstance(entry, dict): + ds_type_ = entry.get("type") + mm_flag_ = entry.get("multimodal") + else: + ds_type_ = getattr(entry, "type", None) + mm_flag_ = getattr(entry, "multimodal", None) + return ds_type_ == "multimodal_pretrain" or bool(mm_flag_) + + pd = data.get("pretraining_dataset") + pd_list = pd if isinstance(pd, list) else ([pd] if pd else []) + train_is_mm = ( + bool(pd_list) and isinstance(pd_list[0], dict) and _entry_is_mm(pd_list[0]) + ) + + test_datasets = data.get("test_datasets") or [] + test_dicts = [t for t in test_datasets if isinstance(t, dict)] + mm_flags = [_entry_is_mm(t) for t in test_dicts] + if mm_flags: + if any(mm_flags) != all(mm_flags): + raise ValueError( + "Mixing multimodal and non-multimodal entries in " + "`test_datasets` is not supported. All eval entries " + "must share modality." + ) + if all(mm_flags) and not train_is_mm: + raise ValueError( + "Multimodal `test_datasets` require multimodal CPT " + "training (set `pretraining_dataset[0].type` to " + "'multimodal_pretrain' or `multimodal: true`)." + ) + if not any(mm_flags) and train_is_mm: + raise ValueError( + "Multimodal CPT training requires multimodal " + "`test_datasets` entries (type='multimodal_pretrain' " + "or multimodal: true)." + ) + + if not pd_list: + return data + + # MM config resolves from entry[0] only; multi-entry runs miscollate or silently demote. + if len(pd_list) > 1 and any(_entry_is_mm(e) for e in pd_list): + raise ValueError( + "Multimodal CPT supports exactly one `pretraining_dataset` " + f"entry (found {len(pd_list)}). Image settings " + "(`image_base_dir`, `image_token`) and MM-mode detection " + "both resolve from entry[0] only, so additional entries " + "would be silently miscollated or drop their MM config. " + "Split multimodal CPT into its own run." + ) + + first = pd_list[0] + if not isinstance(first, dict): + return data + + if not train_is_mm: + return data + + if not data.get("processor_type"): + raise ValueError( + "Multimodal CPT (type: multimodal_pretrain) requires " + "`processor_type` to be set — e.g. `processor_type: AutoProcessor`. " + "Without a processor, images in the dataset cannot be turned " + "into pixel tensors." + ) + if data.get("sample_packing"): + raise ValueError( + "Multimodal CPT is incompatible with `sample_packing: true`. " + "Each image's placeholder token expands to a variable number " + "of patch tokens at the processor, so cross-row packing would " + "break the 1-to-1 alignment between text placeholders and " + "pixel_values. Set `sample_packing: false`." + ) + if data.get("chat_template"): + raise ValueError( + "Multimodal CPT (raw image+text pretraining) is incompatible " + "with `chat_template`. The point of the CPT path is to avoid " + "conversational scaffolding entirely. Remove `chat_template` " + "or switch to chat-template SFT." + ) + # Keep `images` and `_mm_text` columns alive for the collator. + prev_remove_unused = data.get("remove_unused_columns") + if prev_remove_unused is not False: + LOG.info( + "Auto-set `remove_unused_columns: false` for multimodal CPT " + "to preserve `images` and `_mm_text` columns (previous value: %r)", + prev_remove_unused, + ) + data["remove_unused_columns"] = False + + test_datasets = data.get("test_datasets") or [] + mm_test = [t for t in test_datasets if isinstance(t, dict) and _entry_is_mm(t)] + if len(mm_test) > 1: + for key in ("image_base_dir", "image_token"): + values = {t.get(key) for t in mm_test} + if len(values) > 1: + raise ValueError( + f"Multimodal CPT eval requires `{key}` to be either " + f"unset on all `test_datasets` entries or identical " + f"across them. The eval collator resolves " + f"`image_base_dir` and `image_token` once from the " + f"first entry, so heterogeneous values would silently " + f"miscollate later entries. Got: {sorted(map(str, values))}." + ) + + return data + class ModelCompatibilityValidationMixin: """Validation methods for specific model compatibility.""" diff --git a/tests/conftest.py b/tests/conftest.py index 16a01f8aa3..25912ced4c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -112,6 +112,24 @@ def download_smollm2_135m_instruct_model(): snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M-Instruct", repo_type="model") +@pytest.fixture(scope="session", autouse=True) +def download_smolvlm_500m_instruct_model(): + # Processor/tokenizer only — skip ~1 GB of weight shards. + snapshot_download_w_retry( + "HuggingFaceTB/SmolVLM-500M-Instruct", + repo_type="model", + allow_patterns=[ + "*.json", + "*.txt", + "*.model", + "*.jinja", + "tokenizer*", + "vocab*", + "merges*", + ], + ) + + @pytest.fixture(scope="session", autouse=True) def download_smollm2_135m_gptq_model(): # download the model diff --git a/tests/prompt_strategies/test_multimodal_pretrain.py b/tests/prompt_strategies/test_multimodal_pretrain.py new file mode 100644 index 0000000000..da6e6a4985 --- /dev/null +++ b/tests/prompt_strategies/test_multimodal_pretrain.py @@ -0,0 +1,115 @@ +"""Multimodal CPT helpers + safety gate tests. + +The non-streaming strategy class and ``load()`` factory are deferred to a +follow-on PR (along with the matching ``build_collator`` routing for +``datasets:`` MM CPT batches), so only the helper-level surface is exercised +here in v1. +""" + +from __future__ import annotations + +import pytest +from transformers import AutoProcessor + +from axolotl.prompt_strategies.multimodal_pretrain import ( + _INCOMPATIBLE_PROCESSOR_REASONS, + ImageTokenSpec, + build_image_token_spec, + check_processor_compatibility, +) + +from tests.hf_offline_utils import enable_hf_offline + +_SMOLVLM = "HuggingFaceTB/SmolVLM-500M-Instruct" + + +@pytest.fixture(scope="module", name="smolvlm_processor") +@enable_hf_offline +def fixture_smolvlm_processor( + download_smolvlm_500m_instruct_model, # pylint: disable=unused-argument +): + return AutoProcessor.from_pretrained(_SMOLVLM) + + +# ---- build_image_token_spec ------------------------------------------------ + + +def test_build_image_token_spec_autodetects_smolvlm(smolvlm_processor): + spec = build_image_token_spec(smolvlm_processor) + assert isinstance(spec, ImageTokenSpec) + assert spec.image_token == "" + assert spec.image_token_id > 0 + assert spec.image_token_id in spec.image_family_token_ids + + +def test_build_image_token_spec_honors_override(smolvlm_processor): + spec = build_image_token_spec(smolvlm_processor, override="") + assert spec.image_token == "" + + +def test_build_image_token_spec_rejects_bad_override(smolvlm_processor): + with pytest.raises(ValueError, match="not a registered special token"): + build_image_token_spec(smolvlm_processor, override="") + + +def test_build_image_token_spec_rejects_plain_word_override(smolvlm_processor): + # Plain words BPE-tokenize but aren't placeholders. + with pytest.raises(ValueError, match="not a registered special token"): + build_image_token_spec(smolvlm_processor, override="image") + + +def test_build_image_token_spec_keeps_image_token_when_no_soft_token_in_name( + smolvlm_processor, +): + """Non-Gemma-3 processors: the boi-swap heuristic only fires when + `image_token` name contains "soft_token" (Gemma-3 convention). Otherwise + `image_token` IS the user-facing placeholder (Gemma-4 convention) and + must not be silently replaced by `boi_token`.""" + tok = smolvlm_processor.tokenizer + image_id = tok.convert_tokens_to_ids("") + boi_id = tok.convert_tokens_to_ids("") + assert boi_id != image_id, ( + "fixture assumption broken: SmolVLM tokenizer should map these to distinct ids" + ) + + class _FakeGemma4Like: + image_token = "" # no 'soft_token' in name → must not swap + boi_token = "" + tokenizer = tok + + spec = build_image_token_spec(_FakeGemma4Like()) + assert spec.image_token == "" + assert spec.image_token_id == image_id + assert spec.image_token_id != boi_id + + +# ---- check_processor_compatibility (startup-time gate) --------------------- + + +@pytest.mark.parametrize("cls_name", list(_INCOMPATIBLE_PROCESSOR_REASONS.keys())) +def test_check_processor_compatibility_rejects_incompatible(cls_name): + fake = type(cls_name, (), {})() + with pytest.raises(ValueError) as exc: + check_processor_compatibility(fake) + assert cls_name in str(exc.value) + assert _INCOMPATIBLE_PROCESSOR_REASONS[cls_name] in str(exc.value) + + +def test_check_processor_compatibility_rejects_subclass(): + # MRO-name fallback must catch user-defined subclasses. + class BaseMllama: + pass + + BaseMllama.__name__ = "MllamaProcessor" + + class CustomUserProcessor(BaseMllama): + pass + + CustomUserProcessor.__name__ = "CustomUserProcessor" + + with pytest.raises(ValueError, match="MllamaProcessor"): + check_processor_compatibility(CustomUserProcessor()) + + +def test_check_processor_compatibility_accepts_supported(smolvlm_processor): + check_processor_compatibility(smolvlm_processor) diff --git a/tests/test_multimodal_streaming.py b/tests/test_multimodal_streaming.py new file mode 100644 index 0000000000..dca32d66b8 --- /dev/null +++ b/tests/test_multimodal_streaming.py @@ -0,0 +1,612 @@ +"""Multimodal CPT streaming encoder + collator tests.""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest +import torch +from PIL import Image +from transformers import AutoProcessor + +from axolotl.prompt_strategies.multimodal_pretrain import build_image_token_spec +from axolotl.utils.collators.mm_pretrain import MultiModalPretrainDataCollator +from axolotl.utils.data.streaming import ( + encode_streaming_multimodal, + wrap_streaming_dataset, +) +from axolotl.utils.dict import DictDefault + +from tests.hf_offline_utils import enable_hf_offline + +_SMOLVLM = "HuggingFaceTB/SmolVLM-500M-Instruct" + + +@pytest.fixture(scope="module", name="smolvlm_processor") +@enable_hf_offline +def fixture_smolvlm_processor( + download_smolvlm_500m_instruct_model, # pylint: disable=unused-argument +): + return AutoProcessor.from_pretrained(_SMOLVLM) + + +@pytest.fixture(scope="module", name="two_tiny_images") +def fixture_two_tiny_images(tmp_path_factory) -> list[Path]: + d = tmp_path_factory.mktemp("mm_stream_imgs") + out = [] + for i in range(2): + p = d / f"dummy_{i}.png" + arr = np.random.default_rng(i).integers(0, 255, (64, 64, 3)).astype("uint8") + Image.fromarray(arr).save(p) + out.append(p) + return out + + +# ---- encode_streaming_multimodal ------------------------------------------ + + +def test_encode_preserves_images_and_text(smolvlm_processor, two_tiny_images): + spec = build_image_token_spec(smolvlm_processor) + examples = { + "text": [ + f"{spec.image_token}\nrow one", + f"{spec.image_token}\nrow two slightly longer", + ], + "images": [[str(two_tiny_images[0])], [str(two_tiny_images[1])]], + } + out = encode_streaming_multimodal( + examples, + tokenizer=smolvlm_processor.tokenizer, + max_tokens=2048, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + ) + assert set(out) >= {"input_ids", "labels", "attention_mask", "images", "_mm_text"} + assert len(out["input_ids"]) == 2 + assert out["images"] == [[str(two_tiny_images[0])], [str(two_tiny_images[1])]] + # EOS appended -> input_ids len equals attention_mask len and > text + for ids, mask in zip(out["input_ids"], out["attention_mask"], strict=True): + assert len(ids) == len(mask) and len(ids) > 0 + # CPT: labels == input_ids pre-masking. + for ids, lbls in zip(out["input_ids"], out["labels"], strict=True): + assert ids == lbls + + +def test_encode_rejects_mismatch(smolvlm_processor, two_tiny_images): + spec = build_image_token_spec(smolvlm_processor) + examples = { + "text": [f"{spec.image_token}{spec.image_token}\ntwo placeholders one image"], + "images": [[str(two_tiny_images[0])]], + } + with pytest.raises(ValueError, match="occurrence"): + encode_streaming_multimodal( + examples, + tokenizer=smolvlm_processor.tokenizer, + max_tokens=2048, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + ) + + +def test_encode_rejects_row_without_list(smolvlm_processor, two_tiny_images): + spec = build_image_token_spec(smolvlm_processor) + with pytest.raises(ValueError, match="list"): + encode_streaming_multimodal( + { + "text": [f"{spec.image_token}\nrow one"], + "images": [str(two_tiny_images[0])], # scalar, not a list + }, + tokenizer=smolvlm_processor.tokenizer, + max_tokens=2048, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + ) + + +def test_encode_counts_placeholders_on_full_text(smolvlm_processor, two_tiny_images): + # The last placeholder must remain countable even when it's hundreds of + # tokens deep — guards against a regression that adds tokenizer + # truncation and silently drops trailing placeholders. + spec = build_image_token_spec(smolvlm_processor) + long_filler = "lorem ipsum " * 400 + text = f"{spec.image_token} {long_filler} {spec.image_token} {long_filler} {spec.image_token}" + examples = { + "text": [text], + "images": [[str(two_tiny_images[0])] * 3], + } + out = encode_streaming_multimodal( + examples, + tokenizer=smolvlm_processor.tokenizer, + max_tokens=4096, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + ) + ids = out["input_ids"][0] + # Sanity: the input is genuinely long, so a truncating regression would + # have to cut into it to drop the last placeholder. + assert len(ids) > 2000 + assert sum(1 for t in ids if t == spec.image_token_id) == 3 + + +def test_encode_rejects_row_exceeding_max_tokens(smolvlm_processor, two_tiny_images): + spec = build_image_token_spec(smolvlm_processor) + huge = "word " * 5000 + examples = { + "text": [f"{spec.image_token} {huge}"], + "images": [[str(two_tiny_images[0])]], + } + with pytest.raises(ValueError, match="exceeds sequence_len"): + encode_streaming_multimodal( + examples, + tokenizer=smolvlm_processor.tokenizer, + max_tokens=512, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + ) + + +# ---- build_image_token_spec autodetection -------------------------------- + + +class _StubTokenizer: + """Minimal tokenizer stub for autodetection tests.""" + + def __init__(self, vocab: dict[str, int], unk_id: int = 0): + self._vocab = vocab + self.unk_token_id = unk_id + self.all_special_tokens = list(vocab.keys()) + self.additional_special_tokens: list[str] = [] + + def get_added_vocab(self): + return dict(self._vocab) + + def convert_tokens_to_ids(self, tok): + return self._vocab.get(tok, self.unk_token_id) + + +class _StubProcessor: + def __init__(self, tokenizer, image_token=None, boi_token=None): + self.tokenizer = tokenizer + if image_token is not None: + self.image_token = image_token + if boi_token is not None: + self.boi_token = boi_token + + +def test_build_image_token_spec_gemma4_uses_image_token_not_boi(): + """Gemma-4: `image_token` is the user-facing placeholder; don't swap to boi_token.""" + tok = _StubTokenizer({"<|image|>": 258880, "<|image>": 255999}) + proc = _StubProcessor(tok, image_token="<|image|>", boi_token="<|image>") + spec = build_image_token_spec(proc) + assert spec.image_token == "<|image|>" + assert spec.image_token_id == 258880 + + +def test_build_image_token_spec_gemma3_swaps_to_boi_token(): + """Gemma-3: `image_token` is the post-expansion soft token; placeholder is `boi_token`.""" + tok = _StubTokenizer({"": 262144, "": 255999}) + proc = _StubProcessor( + tok, image_token="", boi_token="" + ) + spec = build_image_token_spec(proc) + assert spec.image_token == "" + assert spec.image_token_id == 255999 + + +def test_build_image_token_spec_override_not_special_rejected(): + """Override that isn't a registered special token is rejected (would BPE-tokenize).""" + tok = _StubTokenizer({"<|image|>": 258880}) + proc = _StubProcessor(tok, image_token="<|image|>") + with pytest.raises(ValueError, match="not a registered special token"): + build_image_token_spec(proc, override="not_a_real_token") + + +def test_build_image_token_spec_override_resolves_to_unk_rejected(): + """Override that resolves to unk is rejected with a clear error.""" + tok = _StubTokenizer({"<|image|>": 258880, "<|fake|>": 0}, unk_id=0) + proc = _StubProcessor(tok, image_token="<|image|>") + with pytest.raises(ValueError, match="did not resolve"): + build_image_token_spec(proc, override="<|fake|>") + + +def test_build_image_token_spec_no_candidates_raises(): + """If neither processor attrs nor any known candidate resolve, raise a clear error.""" + tok = _StubTokenizer({}) # nothing registered + proc = _StubProcessor(tok) # no image_token, no boi_token + with pytest.raises(ValueError, match="Could not autodetect"): + build_image_token_spec(proc) + + +# ---- wrap_streaming_dataset routing -------------------------------------- + + +def test_wrap_streaming_dataset_uses_pretraining_config_arg( + smolvlm_processor, monkeypatch +): + # Eval path passes a per-entry config that may differ from cfg.pretraining_dataset[0]. + # The MM-CPT branch must read from that arg, not re-resolve from cfg. + captured = {} + + def fake_partial(fn, **kwargs): + captured["encode_fn"] = fn + captured["kwargs"] = kwargs + return lambda batch: batch + + monkeypatch.setattr("axolotl.utils.data.streaming.functools.partial", fake_partial) + + class _Dataset: + features = {"text": None, "images": None} + + def shuffle(self, **_): + return self + + def map(self, *_args, **_kwargs): + return self + + cfg = DictDefault( + { + "sample_packing": False, + "pretraining_dataset": [ + { + "path": "train/ds", + "type": "multimodal_pretrain", + "text_column": "wrong_train_col", + "image_column": "wrong_train_imgs", + } + ], + "sequence_len": 256, + "shuffle_merged_datasets": False, + "streaming_multipack_buffer_size": 1000, + "seed": 42, + } + ) + eval_entry = DictDefault( + { + "path": "test/ds", + "type": "multimodal_pretrain", + "text_column": "eval_text", + "image_column": "eval_imgs", + } + ) + + wrap_streaming_dataset( + _Dataset(), + smolvlm_processor.tokenizer, + cfg, + ds_wrapper_fn=None, + processor=smolvlm_processor, + pretraining_config=eval_entry, + ) + + assert captured["encode_fn"] is encode_streaming_multimodal + assert captured["kwargs"]["text_column"] == "eval_text" + assert captured["kwargs"]["image_column"] == "eval_imgs" + + +def test_wrap_streaming_dataset_eval_honors_eval_sequence_len( + smolvlm_processor, monkeypatch +): + """is_eval=True with cfg.eval_sequence_len set caps encoder at eval_sequence_len.""" + captured = {} + + def fake_partial(fn, **kwargs): + captured["encode_fn"] = fn + captured["kwargs"] = kwargs + return lambda batch: batch + + monkeypatch.setattr("axolotl.utils.data.streaming.functools.partial", fake_partial) + + class _Dataset: + features = {"text": None, "images": None} + + def shuffle(self, **_): + return self + + def map(self, *_args, **_kwargs): + return self + + cfg = DictDefault( + { + "sample_packing": False, + "pretraining_dataset": [ + {"path": "train/ds", "type": "multimodal_pretrain"} + ], + "sequence_len": 4096, + "eval_sequence_len": 1024, + "shuffle_merged_datasets": False, + "streaming_multipack_buffer_size": 1000, + "seed": 42, + } + ) + + wrap_streaming_dataset( + _Dataset(), + smolvlm_processor.tokenizer, + cfg, + ds_wrapper_fn=None, + processor=smolvlm_processor, + pretraining_config=DictDefault( + {"path": "test/ds", "type": "multimodal_pretrain"} + ), + is_eval=True, + ) + assert captured["kwargs"]["max_tokens"] == 1024 + + captured.clear() + wrap_streaming_dataset( + _Dataset(), + smolvlm_processor.tokenizer, + cfg, + ds_wrapper_fn=None, + processor=smolvlm_processor, + pretraining_config=DictDefault( + {"path": "train/ds", "type": "multimodal_pretrain"} + ), + is_eval=False, + ) + assert captured["kwargs"]["max_tokens"] == 4096 + + # eval_sequence_len unset -> eval falls back to sequence_len. + captured.clear() + cfg_no_eval = DictDefault( + { + "sample_packing": False, + "pretraining_dataset": [ + {"path": "train/ds", "type": "multimodal_pretrain"} + ], + "sequence_len": 4096, + "shuffle_merged_datasets": False, + "streaming_multipack_buffer_size": 1000, + "seed": 42, + } + ) + wrap_streaming_dataset( + _Dataset(), + smolvlm_processor.tokenizer, + cfg_no_eval, + ds_wrapper_fn=None, + processor=smolvlm_processor, + pretraining_config=DictDefault( + {"path": "test/ds", "type": "multimodal_pretrain"} + ), + is_eval=True, + ) + assert captured["kwargs"]["max_tokens"] == 4096 + + +# ---- MultiModalPretrainDataCollator --------------------------------------- + + +def test_collator_builds_batch_and_masks_labels(smolvlm_processor, two_tiny_images): + spec = build_image_token_spec(smolvlm_processor) + encoded = encode_streaming_multimodal( + { + "text": [ + f"{spec.image_token}\nrow one", + f"{spec.image_token}\nrow two slightly longer", + ], + "images": [[str(two_tiny_images[0])], [str(two_tiny_images[1])]], + }, + tokenizer=smolvlm_processor.tokenizer, + max_tokens=2048, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + ) + rows = [ + { + k: encoded[k][i] + for k in ("input_ids", "labels", "attention_mask", "images", "_mm_text") + } + for i in range(2) + ] + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + batch = collator.torch_call(rows) + # Expected keys + for k in ("input_ids", "attention_mask", "pixel_values", "labels"): + assert k in batch, f"missing batch key {k}" + assert isinstance(batch["input_ids"], torch.Tensor) + # Label masking check: no image-family ids remaining as valid labels. + for tid in spec.image_family_token_ids: + assert int((batch["labels"] == tid).sum().item()) == 0, ( + f"label masking left id={tid} in labels" + ) + # Pad is also masked. + pad_id = smolvlm_processor.tokenizer.pad_token_id + if pad_id is not None: + assert int((batch["labels"] == pad_id).sum().item()) == 0 + + +def test_collator_raises_on_missing_columns(smolvlm_processor): + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + with pytest.raises(KeyError, match="encode_streaming_multimodal"): + collator.torch_call([{"input_ids": [1, 2, 3]}]) # no _mm_text / images + + +# ---- input validation ----------------------------------------------------- + + +def test_collator_rejects_bytes_mm_text(smolvlm_processor, two_tiny_images): + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + rows = [ + { + "_mm_text": f"{spec.image_token}\nrow".encode(), + "images": [str(two_tiny_images[0])], + } + ] + with pytest.raises(TypeError, match="`_mm_text` must be str"): + collator.torch_call(rows) + + +def test_collator_sanitizes_error_message(smolvlm_processor, tmp_path): + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + missing = tmp_path / "subdir_with_secret_name" / "nope.png" + with pytest.raises(RuntimeError) as exc: + collator._load_images_for_row([str(missing)], row_index=3) + # basename appears, full directory path does NOT + assert "nope.png" in str(exc.value) + assert "subdir_with_secret_name" not in str(exc.value) + assert "Row 3" in str(exc.value) + + +def test_collator_skip_bad_images_drops_row_and_continues( + smolvlm_processor, two_tiny_images, tmp_path +): + """skip_bad_images=True: bad row drops, batch survives on remaining rows.""" + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + skip_bad_images=True, + ) + rows = [ + { + "_mm_text": f"{spec.image_token}\ngood row", + "images": [str(two_tiny_images[0])], + }, + { + "_mm_text": f"{spec.image_token}\nbad row", + "images": [str(tmp_path / "missing.png")], + }, + ] + batch = collator.torch_call(rows) + # Surviving row produced a batch with pixel_values from the good image. + assert "input_ids" in batch and "pixel_values" in batch + assert batch["input_ids"].shape[0] == 1 + + +def test_collator_all_rows_dropped_raises(smolvlm_processor, tmp_path): + """skip_bad_images=True with every row failing surfaces a RuntimeError.""" + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + skip_bad_images=True, + ) + rows = [ + { + "_mm_text": f"{spec.image_token}\nrow", + "images": [str(tmp_path / f"missing_{i}.png")], + } + for i in range(2) + ] + with pytest.raises(RuntimeError, match="All rows in the batch were dropped"): + collator.torch_call(rows) + + +# ---- mixed / all-text batches -------------------------------------------- + + +def test_collator_warns_when_tokenizer_diverges_from_processor_tokenizer( + smolvlm_processor, caplog, monkeypatch +): + """Construct-time warning when self.tokenizer is not processor.tokenizer.""" + import logging as _logging + + # `axolotl` logger has propagate=False (logging_config.py); flip it so + # caplog's root handler receives the record. + monkeypatch.setattr(_logging.getLogger("axolotl"), "propagate", True) + spec = build_image_token_spec(smolvlm_processor) + + # Same tokenizer: no warning. + with caplog.at_level( + _logging.WARNING, logger="axolotl.utils.collators.mm_pretrain" + ): + MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + assert not any("tokenize inconsistently" in r.getMessage() for r in caplog.records) + + caplog.clear() + + # Different tokenizer instance (a stand-in object): warning fires. + class _OtherTokenizer: + pad_token_id = None + + with caplog.at_level( + _logging.WARNING, logger="axolotl.utils.collators.mm_pretrain" + ): + MultiModalPretrainDataCollator( + tokenizer=_OtherTokenizer(), + processor=smolvlm_processor, + image_token_spec=spec, + ) + assert any("tokenize inconsistently" in r.getMessage() for r in caplog.records) + + +def test_collator_all_text_batch_uses_tokenizer_fallback(smolvlm_processor): + """A batch where every row has images=[] tokenizes via the tokenizer; no pixel_values.""" + spec = build_image_token_spec(smolvlm_processor) + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + rows = [ + {"_mm_text": "first text-only row", "images": []}, + {"_mm_text": "second text-only row, slightly longer", "images": []}, + ] + batch = collator.torch_call(rows) + for k in ("input_ids", "attention_mask", "labels"): + assert k in batch, f"missing batch key {k}" + assert "pixel_values" not in batch + assert isinstance(batch["input_ids"], torch.Tensor) + pad_id = smolvlm_processor.tokenizer.pad_token_id + if pad_id is not None: + assert int((batch["labels"] == pad_id).sum().item()) == 0 + + +def test_collator_mixed_batch_still_succeeds(smolvlm_processor, two_tiny_images): + """A batch with one imaged row and one text-only row still produces pixel_values.""" + spec = build_image_token_spec(smolvlm_processor) + encoded = encode_streaming_multimodal( + { + "text": [ + f"{spec.image_token}\nimaged row", + "text-only row", + ], + "images": [[str(two_tiny_images[0])], []], + }, + tokenizer=smolvlm_processor.tokenizer, + max_tokens=2048, + image_token=spec.image_token, + image_token_id=spec.image_token_id, + ) + rows = [ + { + k: encoded[k][i] + for k in ("input_ids", "labels", "attention_mask", "images", "_mm_text") + } + for i in range(2) + ] + collator = MultiModalPretrainDataCollator( + tokenizer=smolvlm_processor.tokenizer, + processor=smolvlm_processor, + image_token_spec=spec, + ) + batch = collator.torch_call(rows) + for k in ("input_ids", "attention_mask", "pixel_values", "labels"): + assert k in batch, f"missing batch key {k}" diff --git a/tests/utils/data/test_mm_cpt_eval.py b/tests/utils/data/test_mm_cpt_eval.py new file mode 100644 index 0000000000..73aa1015ec --- /dev/null +++ b/tests/utils/data/test_mm_cpt_eval.py @@ -0,0 +1,302 @@ +"""Multimodal CPT eval-path tests.""" + +from __future__ import annotations + +from axolotl.utils.data.sft import ( + _create_placeholder_dataset, + _prepare_streaming_dataset, +) +from axolotl.utils.dict import DictDefault + +# ---- placeholder dataset for dispatch_batches ---------------------------- + + +def test_placeholder_text_only_keeps_existing_shape(): + """Without an MM config, the placeholder is a single-column text dataset.""" + ds = _create_placeholder_dataset() + row = next(iter(ds)) + assert "text" in row + assert "images" not in row + + +def test_placeholder_mm_emits_image_column(): + """MM placeholder rows carry the configured image column as an empty list.""" + pt_cfg = DictDefault( + { + "type": "multimodal_pretrain", + "text_column": "text", + "image_column": "images", + "multimodal": True, + } + ) + ds = _create_placeholder_dataset(pt_cfg) + row = next(iter(ds)) + assert "text" in row + assert "images" in row + assert row["images"] == [] + + +def test_placeholder_mm_honors_custom_columns(): + """Custom text_column / image_column on the MM config are reflected in the placeholder row.""" + pt_cfg = DictDefault( + { + "type": "multimodal_pretrain", + "text_column": "doc", + "image_column": "imgs", + } + ) + ds = _create_placeholder_dataset(pt_cfg) + row = next(iter(ds)) + assert "doc" in row + assert "imgs" in row + assert row["imgs"] == [] + + +def test_pretraining_config_from_entry_preserves_trust_remote_code(): + """trust_remote_code on the dataset entry survives normalization.""" + from axolotl.utils.data.sft import _pretraining_config_from_entry + + cfg = _pretraining_config_from_entry( + {"path": "ds", "type": "multimodal_pretrain", "trust_remote_code": True} + ) + assert cfg["trust_remote_code"] is True + + cfg = _pretraining_config_from_entry({"path": "ds", "type": "multimodal_pretrain"}) + assert cfg["trust_remote_code"] is False + + +def test_pretraining_config_from_entry_preserves_ds_type(): + """ds_type on the dataset entry survives normalization.""" + from axolotl.utils.data.sft import _pretraining_config_from_entry + + cfg = _pretraining_config_from_entry( + {"path": "/data/*.jsonl", "type": "multimodal_pretrain", "ds_type": "json"} + ) + assert cfg["ds_type"] == "json" + + cfg = _pretraining_config_from_entry({"path": "ds", "type": "multimodal_pretrain"}) + assert cfg["ds_type"] is None + + +def test_load_streaming_dataset_routes_ds_type_to_loader(monkeypatch): + """When ds_type is set, load_dataset is called with the loader name and + path becomes data_files.""" + from axolotl.utils.data.sft import _load_streaming_dataset + + captured = {} + + def fake_load_dataset(*args, **kwargs): + captured["args"] = args + captured["kwargs"] = kwargs + + class _Stub: + def skip(self, *_a, **_kw): + return self + + return _Stub() + + def fake_wrap(ds, *_a, **_kw): + return ds + + class _StubFormat: + def with_format(self, *_a, **_kw): + return self + + monkeypatch.setattr("axolotl.utils.data.sft.load_dataset", fake_load_dataset) + monkeypatch.setattr( + "axolotl.utils.data.sft.wrap_streaming_dataset", + lambda *a, **kw: _StubFormat(), + ) + + pretraining_config = DictDefault( + { + "path": "/data/shards/*.jsonl", + "name": None, + "skip": 0, + "split": "train", + "data_files": None, + "ds_type": "json", + "type": "multimodal_pretrain", + "text_column": "text", + "multimodal": True, + "image_column": "images", + "image_base_dir": None, + "image_token": None, + "trust_remote_code": False, + } + ) + cfg = DictDefault({"sequence_len": 2048, "accelerator_config": None}) + + _load_streaming_dataset(pretraining_config, cfg, tokenizer=None, processor=None) + + assert captured["args"] == ("json",) + assert captured["kwargs"]["data_files"] == "/data/shards/*.jsonl" + assert captured["kwargs"]["split"] == "train" + + +# ---- multiple MM eval datasets are loaded -------------------------------- + + +def test_mm_eval_iterates_all_test_datasets(monkeypatch): + """All MM entries in test_datasets are loaded and concatenated into the eval stream.""" + cfg = DictDefault( + { + "streaming": True, + "pretraining_dataset": [ + {"path": "train/ds", "type": "multimodal_pretrain"} + ], + "test_datasets": [ + {"path": "eval/a", "type": "multimodal_pretrain"}, + {"path": "eval/b", "type": "multimodal_pretrain"}, + {"path": "eval/c", "type": "multimodal_pretrain"}, + ], + "max_steps": 10, + } + ) + + seen_eval_paths: list[str] = [] + + def fake_load_streaming(pretraining_config, *_a, **_kw): + path = pretraining_config["path"] + if path.startswith("eval/"): + seen_eval_paths.append(path) + return f"" + + def fake_concat(streams): + return tuple(streams) + + monkeypatch.setattr( + "axolotl.utils.data.sft._load_streaming_dataset", fake_load_streaming + ) + monkeypatch.setattr("axolotl.utils.data.sft.concatenate_datasets", fake_concat) + + _train, eval_ds, _, _ = _prepare_streaming_dataset( + cfg, tokenizer=None, processor=None + ) + + assert seen_eval_paths == ["eval/a", "eval/b", "eval/c"] + assert eval_ds == ("", "", "") + + +# Mixed MM / non-MM test_datasets is rejected at config-load time by +# check_multimodal_cpt (see tests/utils/schemas/validation/test_multimodal_cpt.py). + + +# ---- eval collator pulls image settings from test_datasets --------------- + + +def test_eval_collator_uses_eval_image_settings(monkeypatch): + """Eval collator pulls image_base_dir / image_token from test_datasets[0]; train collator from pretraining_dataset[0].""" + from axolotl.core.builders.causal import HFCausalTrainerBuilder + + captured = {} + + class _FakeSpec: + image_token = "" + image_token_id = 7 + image_family_token_ids = (7,) + + def fake_build_image_token_spec(processor, override=None): + captured["override"] = override + return _FakeSpec() + + monkeypatch.setattr( + "axolotl.prompt_strategies.multimodal_pretrain.build_image_token_spec", + fake_build_image_token_spec, + ) + + class _FakeCollator: + def __init__(self, **kw): + captured["kwargs"] = kw + + monkeypatch.setattr( + "axolotl.core.builders.causal.MultiModalPretrainDataCollator", _FakeCollator + ) + + builder = HFCausalTrainerBuilder.__new__(HFCausalTrainerBuilder) + builder.tokenizer = object() + builder.processor = object() + builder.cfg = DictDefault( + { + "pretraining_dataset": [ + { + "type": "multimodal_pretrain", + "image_base_dir": "/train_images", + "image_token": "", + } + ], + "test_datasets": [ + { + "type": "multimodal_pretrain", + "image_base_dir": "/eval_images", + "image_token": "", + } + ], + "sequence_len": 2048, + } + ) + + builder._build_mm_pretrain_collator(is_eval=True) + assert captured["override"] == "" + assert captured["kwargs"]["image_base_dir"] == "/eval_images" + + captured.clear() + builder._build_mm_pretrain_collator(is_eval=False) + assert captured["override"] == "" + assert captured["kwargs"]["image_base_dir"] == "/train_images" + + +def test_eval_collator_honors_eval_sequence_len(monkeypatch): + """Eval collator uses cfg.eval_sequence_len when set; train collator uses cfg.sequence_len.""" + from axolotl.core.builders.causal import HFCausalTrainerBuilder + + captured = {} + + class _FakeSpec: + image_token = "" + image_token_id = 7 + image_family_token_ids = (7,) + + monkeypatch.setattr( + "axolotl.prompt_strategies.multimodal_pretrain.build_image_token_spec", + lambda processor, override=None: _FakeSpec(), + ) + + class _FakeCollator: + def __init__(self, **kw): + captured["kwargs"] = kw + + monkeypatch.setattr( + "axolotl.core.builders.causal.MultiModalPretrainDataCollator", _FakeCollator + ) + + builder = HFCausalTrainerBuilder.__new__(HFCausalTrainerBuilder) + builder.tokenizer = object() + builder.processor = object() + builder.cfg = DictDefault( + { + "pretraining_dataset": [{"type": "multimodal_pretrain"}], + "test_datasets": [{"type": "multimodal_pretrain"}], + "sequence_len": 4096, + "eval_sequence_len": 1024, + } + ) + + builder._build_mm_pretrain_collator(is_eval=True) + assert captured["kwargs"]["max_length"] == 1024 + + captured.clear() + builder._build_mm_pretrain_collator(is_eval=False) + assert captured["kwargs"]["max_length"] == 4096 + + # eval_sequence_len unset -> eval falls back to sequence_len + builder.cfg = DictDefault( + { + "pretraining_dataset": [{"type": "multimodal_pretrain"}], + "test_datasets": [{"type": "multimodal_pretrain"}], + "sequence_len": 4096, + } + ) + captured.clear() + builder._build_mm_pretrain_collator(is_eval=True) + assert captured["kwargs"]["max_length"] == 4096 diff --git a/tests/utils/schemas/validation/test_multimodal_cpt.py b/tests/utils/schemas/validation/test_multimodal_cpt.py new file mode 100644 index 0000000000..f110d6c7e1 --- /dev/null +++ b/tests/utils/schemas/validation/test_multimodal_cpt.py @@ -0,0 +1,336 @@ +"""Multimodal CPT config validation gates.""" + +from __future__ import annotations + +import logging + +import pytest + +from axolotl.utils.config import validate_config +from axolotl.utils.dict import DictDefault + + +def _mm_cpt_cfg(min_base_cfg, **overrides) -> DictDefault: + base = DictDefault( + **( + min_base_cfg + | { + "datasets": None, + "pretraining_dataset": [ + { + "path": "some/ds", + "type": "multimodal_pretrain", + "image_column": "images", + } + ], + "streaming": True, + "max_steps": 10, + "processor_type": "AutoProcessor", + "sequence_len": 2048, + } + ) + ) + return base | DictDefault(overrides) + + +class TestMultimodalCPTGates: + def test_missing_processor_type_raises(self, min_base_cfg): + cfg = _mm_cpt_cfg(min_base_cfg) + cfg.pop("processor_type", None) + with pytest.raises(ValueError, match="processor_type"): + validate_config(cfg) + + def test_sample_packing_rejected(self, min_base_cfg): + cfg = _mm_cpt_cfg(min_base_cfg, sample_packing=True) + with pytest.raises(ValueError, match="sample_packing"): + validate_config(cfg) + + def test_chat_template_rejected(self, min_base_cfg): + cfg = _mm_cpt_cfg(min_base_cfg, chat_template="tokenizer_default") + with pytest.raises(ValueError, match="chat_template"): + validate_config(cfg) + + def test_multiple_pretraining_dataset_entries_rejected(self, min_base_cfg): + cfg = _mm_cpt_cfg(min_base_cfg) + cfg.pretraining_dataset.append({"path": "other/ds", "type": "pretrain"}) + with pytest.raises(ValueError, match="exactly one `pretraining_dataset`"): + validate_config(cfg) + + def test_multimodal_entry_in_non_first_slot_rejected(self, min_base_cfg): + cfg = DictDefault( + **( + min_base_cfg + | { + "datasets": None, + "pretraining_dataset": [ + {"path": "text/ds", "type": "pretrain"}, + { + "path": "mm/ds", + "type": "multimodal_pretrain", + "image_column": "images", + }, + ], + "streaming": True, + "max_steps": 10, + "processor_type": "AutoProcessor", + "sequence_len": 2048, + } + ) + ) + with pytest.raises(ValueError, match="exactly one `pretraining_dataset`"): + validate_config(cfg) + + def test_valid_cfg_passes_and_disables_remove_unused_columns(self, min_base_cfg): + cfg = _mm_cpt_cfg(min_base_cfg) + validated = validate_config(cfg) + assert validated.remove_unused_columns is False + pd = validated.pretraining_dataset[0] + assert pd.type == "multimodal_pretrain" + assert pd.image_column == "images" + + def test_multimodal_flag_triggers_gates(self, min_base_cfg): + cfg = _mm_cpt_cfg(min_base_cfg) + cfg.pretraining_dataset[0]["type"] = "pretrain" + cfg.pretraining_dataset[0]["multimodal"] = True + cfg.pop("processor_type", None) + with pytest.raises(ValueError, match="processor_type"): + validate_config(cfg) + + def test_non_mm_pretraining_dataset_unaffected(self, min_base_cfg): + cfg = DictDefault( + **( + min_base_cfg + | { + "datasets": None, + "pretraining_dataset": [{"path": "some/ds", "type": "pretrain"}], + "streaming": True, + "max_steps": 10, + "sequence_len": 2048, + } + ) + ) + validate_config(cfg) + + def test_mm_eval_dataset_keys_preserved_through_validation(self, min_base_cfg): + """MM-specific keys on a test_datasets entry survive validate_config.""" + cfg = _mm_cpt_cfg( + min_base_cfg, + test_datasets=[ + { + "path": "eval/ds", + "type": "multimodal_pretrain", + "text_column": "eval_text", + "image_column": "eval_imgs", + "image_base_dir": "/eval/images", + "image_token": "", + } + ], + ) + validated = validate_config(cfg) + td = validated.test_datasets[0] + assert td["text_column"] == "eval_text" + assert td["image_column"] == "eval_imgs" + assert td["image_base_dir"] == "/eval/images" + assert td["image_token"] == "" + + def test_mm_eval_dataset_via_multimodal_flag(self, min_base_cfg): + """`multimodal: true` (without type='multimodal_pretrain') opts an eval entry into MM.""" + cfg = _mm_cpt_cfg( + min_base_cfg, + test_datasets=[ + { + "path": "eval/ds", + "multimodal": True, + "image_column": "imgs2", + } + ], + ) + validated = validate_config(cfg) + td = validated.test_datasets[0] + assert td["image_column"] == "imgs2" + assert td["multimodal"] is True + + def test_non_mm_eval_entry_does_not_match_mm_model(self, min_base_cfg): + """SFT eval entries (no MM markers) still validate as SFTDataset.""" + cfg = DictDefault( + **( + min_base_cfg + | { + "test_datasets": [ + {"path": "eval/ds", "type": "alpaca", "split": "test"} + ], + "sequence_len": 2048, + } + ) + ) + validated = validate_config(cfg) + td = validated.test_datasets[0] + assert "message_property_mappings" in td + assert td["type"] == "alpaca" + + def test_mm_eval_rejects_mismatched_image_base_dir(self, min_base_cfg): + """Multiple MM eval entries with different image_base_dir are rejected.""" + cfg = _mm_cpt_cfg( + min_base_cfg, + test_datasets=[ + { + "path": "eval/a", + "type": "multimodal_pretrain", + "image_base_dir": "/images/a", + }, + { + "path": "eval/b", + "type": "multimodal_pretrain", + "image_base_dir": "/images/b", + }, + ], + ) + with pytest.raises(ValueError, match="image_base_dir"): + validate_config(cfg) + + def test_mm_eval_rejects_mismatched_image_token(self, min_base_cfg): + """Multiple MM eval entries with different image_token overrides are rejected.""" + cfg = _mm_cpt_cfg( + min_base_cfg, + test_datasets=[ + { + "path": "eval/a", + "type": "multimodal_pretrain", + "image_token": "", + }, + { + "path": "eval/b", + "type": "multimodal_pretrain", + "image_token": "", + }, + ], + ) + with pytest.raises(ValueError, match="image_token"): + validate_config(cfg) + + def test_mm_eval_accepts_matching_image_base_dir(self, min_base_cfg): + """Multiple MM eval entries sharing image_base_dir validate cleanly.""" + cfg = _mm_cpt_cfg( + min_base_cfg, + test_datasets=[ + { + "path": "eval/a", + "type": "multimodal_pretrain", + "image_base_dir": "/images/shared", + }, + { + "path": "eval/b", + "type": "multimodal_pretrain", + "image_base_dir": "/images/shared", + }, + ], + ) + validated = validate_config(cfg) + assert len(validated.test_datasets) == 2 + + def test_mm_eval_accepts_all_unset_image_settings(self, min_base_cfg): + """Multiple MM eval entries with image_base_dir / image_token unset everywhere validate.""" + cfg = _mm_cpt_cfg( + min_base_cfg, + test_datasets=[ + {"path": "eval/a", "type": "multimodal_pretrain"}, + {"path": "eval/b", "type": "multimodal_pretrain"}, + ], + ) + validated = validate_config(cfg) + assert len(validated.test_datasets) == 2 + + def test_mixed_modality_test_datasets_rejected_at_validation(self, min_base_cfg): + """A test_datasets list mixing MM and non-MM entries fails at config-load.""" + cfg = _mm_cpt_cfg( + min_base_cfg, + test_datasets=[ + {"path": "eval/a", "type": "multimodal_pretrain"}, + {"path": "eval/b", "type": "alpaca", "split": "test"}, + ], + ) + with pytest.raises(ValueError) as exc: + validate_config(cfg) + msg = str(exc.value) + assert "Mixing multimodal and non-multimodal" in msg + assert "test_datasets" in msg + assert "share modality" in msg + + def test_mm_test_datasets_with_text_training_rejected(self, min_base_cfg): + """MM test_datasets paired with non-MM training fails at config-load.""" + cfg = DictDefault( + **( + min_base_cfg + | { + "datasets": None, + "pretraining_dataset": [{"path": "text/ds", "type": "pretrain"}], + "test_datasets": [ + {"path": "eval/a", "type": "multimodal_pretrain"} + ], + "streaming": True, + "max_steps": 10, + "sequence_len": 2048, + "processor_type": "AutoProcessor", + } + ) + ) + with pytest.raises(ValueError) as exc: + validate_config(cfg) + msg = str(exc.value) + assert "Multimodal `test_datasets`" in msg + assert "multimodal CPT training" in msg + assert "multimodal_pretrain" in msg + + def test_text_test_datasets_with_mm_training_rejected(self, min_base_cfg): + """Non-MM test_datasets paired with MM training fails at config-load.""" + cfg = _mm_cpt_cfg( + min_base_cfg, + test_datasets=[{"path": "eval/a", "type": "alpaca", "split": "test"}], + ) + with pytest.raises(ValueError) as exc: + validate_config(cfg) + msg = str(exc.value) + assert "Multimodal CPT training" in msg + assert "multimodal `test_datasets`" in msg + assert "multimodal_pretrain" in msg + + def test_remove_unused_columns_auto_set_emits_info_log( + self, min_base_cfg, caplog, monkeypatch + ): + """Auto-setting `remove_unused_columns: false` for MM CPT logs an INFO record naming the previous value.""" + # `axolotl` logger has propagate=False (logging_config.py); flip it so + # caplog's root handler receives the record. + monkeypatch.setattr(logging.getLogger("axolotl"), "propagate", True) + cfg = _mm_cpt_cfg(min_base_cfg) + cfg.pop("remove_unused_columns", None) + with caplog.at_level(logging.INFO, logger="axolotl.utils.schemas.validation"): + validated = validate_config(cfg) + assert validated.remove_unused_columns is False + matches = [ + r + for r in caplog.records + if r.levelno == logging.INFO and "Auto-set" in r.getMessage() + ] + assert matches, ( + "expected an INFO record about auto-setting remove_unused_columns" + ) + msg = matches[0].getMessage() + assert "remove_unused_columns" in msg + assert "previous value: None" in msg + + def test_remove_unused_columns_already_false_does_not_log( + self, min_base_cfg, caplog, monkeypatch + ): + """When the user already set `remove_unused_columns: false`, no auto-set log fires.""" + # Mirror the positive test: flip propagate on the parent `axolotl` + # logger (the one with propagate=False), not the leaf — otherwise + # caplog never sees axolotl.* records and this negative assertion + # is vacuous (it would pass even if the auto-set log fired). + monkeypatch.setattr(logging.getLogger("axolotl"), "propagate", True) + cfg = _mm_cpt_cfg(min_base_cfg, remove_unused_columns=False) + with caplog.at_level(logging.INFO, logger="axolotl.utils.schemas.validation"): + validate_config(cfg) + assert not any( + "Auto-set" in r.getMessage() and "remove_unused_columns" in r.getMessage() + for r in caplog.records + )