From 168b5dd4dbb9a11b7668775a44ce0c0710355e7e Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 30 Oct 2024 17:03:10 +0100 Subject: [PATCH 1/3] First prototype --- .pre-commit-config.yaml | 0 torchtitan/datasets/__init__.py | 5 +- torchtitan/datasets/hf_datasets.py | 12 +- torchtitan/datasets/mm_datasets.py | 144 ++++++ torchtitan/datasets/multimodal/__init__.py | 21 + torchtitan/datasets/multimodal/clip.py | 195 ++++++++ torchtitan/datasets/multimodal/collator.py | 289 +++++++++++ .../datasets/multimodal/llama3_transform.py | 100 ++++ torchtitan/datasets/multimodal/utils.py | 458 ++++++++++++++++++ .../multimodal/vision_attention_mask.py | 171 +++++++ torchtitan/datasets/tokenizer/tiktoken.py | 32 ++ 11 files changed, 1423 insertions(+), 4 deletions(-) mode change 100644 => 100755 .pre-commit-config.yaml mode change 100644 => 100755 torchtitan/datasets/__init__.py mode change 100644 => 100755 torchtitan/datasets/hf_datasets.py create mode 100755 torchtitan/datasets/mm_datasets.py create mode 100755 torchtitan/datasets/multimodal/__init__.py create mode 100755 torchtitan/datasets/multimodal/clip.py create mode 100755 torchtitan/datasets/multimodal/collator.py create mode 100755 torchtitan/datasets/multimodal/llama3_transform.py create mode 100755 torchtitan/datasets/multimodal/utils.py create mode 100755 torchtitan/datasets/multimodal/vision_attention_mask.py mode change 100644 => 100755 torchtitan/datasets/tokenizer/tiktoken.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml old mode 100644 new mode 100755 diff --git a/torchtitan/datasets/__init__.py b/torchtitan/datasets/__init__.py old mode 100644 new mode 100755 index 75ea6b66..16db876e --- a/torchtitan/datasets/__init__.py +++ b/torchtitan/datasets/__init__.py @@ -4,10 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtitan.datasets.hf_datasets import build_hf_data_loader +from torchtitan.datasets.hf_datasets import build_hf_data_loader, DPAwareDataLoader +from torchtitan.datasets.mm_datasets import build_mm_data_loader from torchtitan.datasets.tokenizer import build_tokenizer __all__ = [ "build_hf_data_loader", "build_tokenizer", + "DPAwareDataLoader", + "build_mm_data_loader", ] diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py old mode 100644 new mode 100755 index 9db036b0..62f2fa09 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import pickle -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch from torch.distributed.checkpoint.stateful import Stateful @@ -156,8 +156,14 @@ class DPAwareDataLoader(StatefulDataLoader, Stateful): A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank. 
""" - def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int): - super().__init__(hf_ds, batch_size) + def __init__( + self, + dp_rank: int, + hf_ds: IterableDataset, + batch_size: int, + collator_fn: Callable, + ): + super().__init__(dataset=hf_ds, batch_size=batch_size, collate_fn=collator_fn) self._dp_rank = dp_rank self._rank_id = f"dp_rank_{dp_rank}" diff --git a/torchtitan/datasets/mm_datasets.py b/torchtitan/datasets/mm_datasets.py new file mode 100755 index 00000000..ad07ff86 --- /dev/null +++ b/torchtitan/datasets/mm_datasets.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch.distributed.checkpoint.stateful import Stateful +from torch.utils.data import IterableDataset + +from torchtitan.datasets import DPAwareDataLoader +from torchtitan.datasets.multimodal import ( + format_obelics, + Llama3VisionTransform, + MultiModalCollator, +) +from torchtitan.datasets.tokenizer import Tokenizer +from torchtitan.logging import logger + +from datasets import Dataset, load_dataset +from datasets.distributed import split_dataset_by_node + +# map from dataset name to a dataset repository on the HF hub +_supported_datasets = { + "OBELICS": "HuggingFaceM4/OBELICS", +} + + +class MultiModalDataset(IterableDataset, Stateful): + """PyTorch MultiModal Dataset. + + Args: + dataset_name (str): name of the dataset to load + tokenizer (Tokenizer): + Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method. + world_size (int): number of data parallel processes participating in training + rank (int): rank of the current data parallel process + infinite (bool): whether to loop infinitely over the dataset + + We currently ONLY support the OBELICS dataset + + Example use: + >>> ds = MultiModalDataset(dataset_name="OBELICS", tokenizer=tokenizer) + >>> for batch in Dataloader(ds, batch_size=8): + print(f"Batch size: {len(batch)}") + Batch size: 8 + """ + + def __init__( + self, + dataset_name: str, + tokenizer: Tokenizer, + world_size: int = 1, + rank: int = 0, + infinite: bool = False, + ) -> None: + # Do NOT allow user to pass any other dataset which is not OBELICS + if dataset_name not in _supported_datasets: + raise ValueError( + f"Dataset {dataset_name} is not supported. " + f"Supported datasets are: {list(_supported_datasets.keys())}" + ) + + dataset_path = _supported_datasets[dataset_name] + logger.info(f"Preparing {dataset_name} dataset from {dataset_path}") + ds = load_dataset(dataset_path, split="train", streaming=True) + + # TODO: support shuffling + self.dataset_name = dataset_name + self._data = split_dataset_by_node(ds, rank, world_size) + self._tokenizer = tokenizer + self.infinite = infinite + self.image_token = "<|image|>" # TODO(tj.solergibert) Hardcoded! 
+ + self.format = format_obelics + self.transform = Llama3VisionTransform( + tokenizer=tokenizer, + tile_size=448, + patch_size=14, + max_num_tiles=4, + image_mean=(0.48145466, 0.4578275, 0.40821073), + image_std=(0.26862954, 0.26130258, 0.27577711), + ) + # NOTE(tj.solergibert) 560 for Instruct models, 448 for pretrain + # https://github.com/pytorch/torchtune/blob/0cc1b1f6a2a9c54ca640c4eb0a4d0b94ba94bb04/torchtune/models/llama3_2_vision/_model_builders.py#L92 + # https://huggingface.co/meta-llama/Llama-3.2-11B-Vision/blob/3f2e93603aaa5dd142f27d34b06dfa2b6e97b8be/preprocessor_config.json#L22 + + # variables for checkpointing + self._sample_idx = 0 + + def __iter__(self): + + while True: + for sample in self._get_data_iter(): + # Format sample into `Llama3VisionTransform` format + try: + processed_sample = self.format(sample, image_token=self.image_token) + except Exception: + continue + assert len(processed_sample["images"]) == processed_sample[ + "text" + ].count(self.image_token) + self._sample_idx += 1 + # Transform sample + processed_sample = self.transform(processed_sample) + yield processed_sample + + if not self.infinite: + logger.warning(f"Dataset {self.dataset_name} has run out of data") + break + else: + # Reset offset for the next iteration + self._sample_idx = 0 + logger.warning(f"Dataset {self.dataset_name} is being re-looped") + + def _get_data_iter(self): + if self._sample_idx == 0: + return iter(self._data) + + # As skipping to the end throws an error in case of map-style dataset, return an empty iterator + if isinstance(self._data, Dataset) and self._sample_idx == len(self._data): + return iter([]) + + return iter(self._data.skip(self._sample_idx)) + + def load_state_dict(self, state_dict): + self._sample_idx = state_dict["sample_idx"] + + def state_dict(self): + return {"sample_idx": self._sample_idx} + + +def build_mm_data_loader( + dataset_name: str, + tokenizer: Tokenizer, + batch_size: int, + world_size, + rank, + infinite: bool = True, +): + mm_ds = MultiModalDataset(dataset_name, tokenizer, world_size, rank, infinite) + + collator = MultiModalCollator(padding_idx=0, pad_max_tiles=4) + + return DPAwareDataLoader(rank, mm_ds, batch_size=batch_size, collator_fn=collator) diff --git a/torchtitan/datasets/multimodal/__init__.py b/torchtitan/datasets/multimodal/__init__.py new file mode 100755 index 00000000..d3bcefef --- /dev/null +++ b/torchtitan/datasets/multimodal/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.datasets.multimodal.clip import CLIPPreprocess +from torchtitan.datasets.multimodal.collator import MultiModalCollator +from torchtitan.datasets.multimodal.llama3_transform import Llama3VisionTransform +from torchtitan.datasets.multimodal.utils import format_obelics +from torchtitan.datasets.multimodal.vision_attention_mask import ( + VisionCrossAttentionMask, +) + +__all__ = [ + "CLIPPreprocess", + "MultiModalCollator", + "Llama3VisionTransform", + "format_obelics", + "VisionCrossAttentionMask", +] diff --git a/torchtitan/datasets/multimodal/clip.py b/torchtitan/datasets/multimodal/clip.py new file mode 100755 index 00000000..0eba9d33 --- /dev/null +++ b/torchtitan/datasets/multimodal/clip.py @@ -0,0 +1,195 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Mapping, Optional, Tuple + +import torch + +import torchvision + +from PIL import Image +from torchtitan.datasets.multimodal.utils import ( + find_supported_resolutions, + get_canvas_best_fit, + resize_with_pad, + tile_crop, +) + +from torchtitan.logging import logger +from torchvision.transforms.v2 import functional as F + +# NOTE Inspired from torchtune.models.clip._transform.py +class CLIPPreprocess: + """ + This class accepts images of any size and dynamically resizes, pads, normalizes and tiles it + based on the image aspect ratio and the number of image tiles we allow. + + The algorithm will NOT distort the image to fit a certain aspect ratio, because + that leads to a significant degradation in image quality. + + The user can choose if they want to allow upscaling by using the flag ``resize_to_max_canvas``. + + For example, if an input image is of size 300x800, and we want to allow + a maximum of 16 image tiles, with side 224px, then: + + If ``resize_to_max_canvas=False``, then: + best_resolution = (448, 896) -> smallest canvas, up to 16 tiles, that doesn't require downscaling + image is NOT resized + image is padded (300, 800) -> 448,896 + Image is tiled 2x4, for a final output shape of (8, 3, 224, 224) + + If ``resize_to_max_canvas=True``, then: + best_resolution = (448, 1344) # canvas that allows maximum upscaling, with minimum padding, up to 16 tiles + image is resized without distortion (300,800) -> (448, 1194) #448 is the limiting side for the resize + image is padded (448, 1194) -> (448, 1344) + Image is tiled 2x5, for a final output shape of (10, 3, 224, 224) + + Args: + image_mean (Optional[List[float]]): Mean values of each channel, used for normalization. + Should be the same used for the pre-trained model. If None, no normalization is performed. Default None. + image_std (Optional[List[float]]): Standard deviation values of each channel, used for normalization. + Should be the same used for the pre-trained model. If None, no normalization is performed. Default None. + possible_resolutions (Optional[List[Tuple[int, int]]]): List of possible resolutions as tuples (height, width). + where each tuple represents a possible canvas to fit the image into when calling ``get_canvas_best_fit``. + If None, this will be calculated using max_num_tiles and tile_size. Default None. + tile_size (int): Size of the tiles to divide the image into. Default 224. + max_num_tiles (Optional[int]): Only used if possible_resolutions is NOT given. + Maximum number of tiles to break an image into. + This will be used to generate possible_resolutions, + e.g. [(224, 224), (224, 448), (448, 224)] if max_num_tiles = 2 and tile_size = 224. + Default 4. + dtype (torch.dtype): Data type of the output image. Default torch.bfloat16. + resample (str): Resampling method used when resizing images. Supports any enum of + ``torchvision.transforms.InterpolationMode``, e.g. "nearest", "nearest_exact", "bilinear", "bicubic". + Default 'bilinear'. + resize_to_max_canvas (bool): "If True, the image will be upscaled without distortion to fit the largest possible + resolution from possible_resolutions. + If False, it will pick the resolution that minimizes downscaling, including no downscaling at all. + In this case, the image will only be upscaled if it's size < tile_size. Default False. + + Examples: + >>> image_transform = CLIPImageTransform( + ... 
image_mean=None, + ... image_std=None, + ... tile_size=224, + ... possible_resolutions=None, + ... max_num_tiles=4, + ... resample="bilinear", + ... resize_to_max_canvas=True, + ...) + >>> # create random image + >>> image = (np.random.rand(100,200,3) * 255).astype(np.uint8) + >>> image = PIL.Image.fromarray(image) + >>> output = image_transform(image) + >>> output['image'].shape # [num_tiles, num_channels, tile_size, tile_size] + torch.Size([2, 3, 224, 224]) + >>> output['ar'] # image best fits the canvas 224x448 + torch.tensor([1,2]) + """ + + def __init__( + self, + *, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, + possible_resolutions: Optional[List[Tuple[int, int]]] = None, + tile_size: int = 224, + max_num_tiles: Optional[int] = 4, + dtype: torch.dtype = torch.bfloat16, + resample: str = "bilinear", + resize_to_max_canvas: bool = False, + ) -> None: + + # get_canvas_best_fit + assert ( + possible_resolutions is not None or max_num_tiles is not None + ), f"Either possible_resolutions or max_num_tiles must be given. Got {possible_resolutions=} and {max_num_tiles=}" + + # If possible_resolutions are not given, then calculate possible ones based on max_num_tiles + if not possible_resolutions and max_num_tiles: + possible_resolutions = find_supported_resolutions( + max_num_tiles=max_num_tiles, tile_size=tile_size + ) + else: + possible_resolutions = possible_resolutions + + self.possible_resolutions = torch.tensor(possible_resolutions).reshape(-1, 2) + logger.debug( + f"Found possible_resolutions: {self.possible_resolutions}. Will fit the images into the canvas with best fit." + ) + + self.resize_to_max_canvas = resize_to_max_canvas + + # normalize + assert (image_mean is None) == ( + image_std is None + ), f"Need to provide both or none of image_mean and image_std. Got {image_mean=} and {image_std=}" + self.mean = image_mean + self.std = image_std + + # resize_with_pad + self.max_size = None if resize_to_max_canvas else tile_size + self.dtype = dtype + self.resample = torchvision.transforms.InterpolationMode[resample.upper()] + + # tile_crop + self.tile_size = tile_size + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Apply image decoding and transformations to the "image" field in the sample. + + Args: + sample (Mapping[str, Any]): A sample with an "image" field containing + a List[Message] to tokenize + inference (bool): Whether the template is being used for inference or not. + + Returns: + Mapping[str, Any]: The sample with an updated "image" filed and added + "aspect_ratio" field. + """ + image = sample["image"] + assert isinstance(image, Image.Image), "Input image must be a PIL image." 
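+        # The steps below follow the class docstring: convert to an RGB float tensor,
+        # pick the best-fit canvas, resize + pad without distortion, optionally
+        # normalize, then split the canvas into tile_size x tile_size tiles.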
+ + # Make image torch.tensor((3, H, W), dtype=dtype), 0<=values<=1 + if hasattr(image, "mode") and image.mode == "RGBA": + image = image.convert("RGB") + image = F.to_image(image) + image = F.grayscale_to_rgb_image(image) + image = F.to_dtype(image, dtype=self.dtype, scale=True) + + # Find the best canvas to fit the image without distortion + best_resolution = get_canvas_best_fit( + image=image, + possible_resolutions=self.possible_resolutions, + resize_to_max_canvas=self.resize_to_max_canvas, + ) + + # resize without distortion + pad to fit best_resolution + image = resize_with_pad( + image=image, + target_size=best_resolution, + resample=self.resample, + max_size=self.max_size, + ) + + # Normalize + if self.mean: + image = F.normalize(image, mean=self.mean, std=self.std) + + # Divide the image into equally sized tiles + image = tile_crop(image=image, tile_size=self.tile_size) + + aspect_ratio = torch.tensor(best_resolution).reshape(-1) // self.tile_size + + sample.update( + { + "image": image, + "aspect_ratio": aspect_ratio, + } + ) + + return sample diff --git a/torchtitan/datasets/multimodal/collator.py b/torchtitan/datasets/multimodal/collator.py new file mode 100755 index 00000000..1e4cc836 --- /dev/null +++ b/torchtitan/datasets/multimodal/collator.py @@ -0,0 +1,289 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F + +from torch.nn.utils.rnn import pad_sequence + + +def padded_collate_sft( + batch: List[Dict[str, List[int]]], + padding_idx: int = 0, + ignore_idx: int = -100, # NOTE(tj.solergibert) Hardcoded! +) -> Dict[str, torch.Tensor]: + """Pad a batch of sequences to the longest sequence length in the batch, and + convert integer lists to tensors. + + Args: + batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs. + padding_idx (int): Padding index for input ids. Defaults to 0. + ignore_idx (int): Padding index for labels. Defaults to -100. + + Returns: + Dict[str, torch.Tensor]: Collated input and label tensors. 
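+
+    Note:
+        ``MultiModalCollator`` below reuses this helper to pad the text portion of
+        multimodal batches before handling images and masks.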
+ + Example: + >>> token_pairs = [ + >>> {"tokens": [1, 2, 3], "labels": [4, 5, 6]}, + >>> {"tokens": [7,], "labels": [10,]}, + >>> ] + >>> collated = padded_collate( + >>> batch=token_pairs, + >>> padding_idx=padding_idx, + >>> ignore_idx=ignore_idx, + >>> ) + >>> collated["tokens"] + >>> tensor([[1, 2, 3], [7, 0, 0]]) + >>> collated["labels"] + >>> tensor([[4, 5, 6], [10, -100, -100]]) + """ + input_ids = pad_sequence( + [torch.tensor(x["tokens"]) for x in batch], + batch_first=True, + padding_value=padding_idx, + ) + labels = pad_sequence( + [torch.tensor(x["labels"]) for x in batch], + batch_first=True, + padding_value=ignore_idx, + ) + + input_ids_seq_len = input_ids.shape[-1] + labels_seq_len = labels.shape[-1] + + # Hack to pad correctly and not use max_seq_len, which is costly + if input_ids_seq_len > labels_seq_len: + labels = F.pad( + labels, (0, input_ids_seq_len - labels_seq_len), value=ignore_idx + ) + elif labels_seq_len > input_ids_seq_len: + input_ids = F.pad( + input_ids, + (0, labels_seq_len - input_ids_seq_len), + value=padding_idx, + ) + return {"tokens": input_ids.long(), "labels": labels.long()} + + +# NOTE Inspired from torchtune.data._collate.py +@dataclass +class MultiModalCollator: + padding_idx: int = 0 + ignore_idx: int = -100 # NOTE(tj.solergibert) Hardcoded! + pad_max_tiles: Optional[int] = None + pad_max_images: Optional[int] = None + + def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + """Pad a batch of text sequences, tiled image tensors, aspect ratios, + and cross attention masks. This can be used for both training and inference. + + ``batch`` is expected to be a list of sample dicts containing the following:: + - "tokens": List[int] of length text_seq_len, varies across samples + - "labels": List[int] of length text_seq_len, varies across samples + - "encoder_input": Dict[str, List[torch.Tensor]] + - "images": List[torch.Tensor], each with shape (n_tiles, c, h, w) + - "aspect_ratio": List[torch.Tensor], each with shape (2, ) to indicate h_ratio, w_ratio + - "encoder_mask": List[Tensor], each with shape (text_seq_len, image_seq_len) + + Shape notation: + - c = channel dim + - h = height dim + - w = weight dim + + Note: + For each element in the batch, ``len(images) == len(encoder_mask) == len(aspect_ratio)``. + + This collater does the following: + (1) Pad text sequence and encoder mask to the longest sequence length in the batch + (2) Pad image tensors in the tile dimension with zeros to the largest number + of tiles in the batch + (3) Add empty images of zeros to samples up to max number of images in the batch + (4) Pad aspect ratios with (1,1) for all added padding images + + Args: + batch (List[Dict[str, Any]]): A list of sample dicts containing tokens, + labels, images, encoder_mask, and aspect_ratio. + padding_idx (int): Padding index for input token ids. Defaults to 0. + ignore_idx (int): Padding index for labels. Defaults to -100. + pad_max_tiles (Optional[int]): Maximum number of tiles to pad to. If None, will pad to the largest number of tiles + in the batch. Defaults to None. + pad_max_images (Optional[int]): Maximum number of images to pad to. If None, will pad to the largest number of images + in the batch. Defaults to None. + + Returns: + Dict[str, Tensor]: Collated tokens, labels, images, encoder_mask, aspect_ratio tensors. 
+ - tokens: Tensor of shape (bsz, max_seq_len) + - labels: Tensor of shape (bsz, max_seq_len) + - images: Tensor of shape (bsz, max_num_images, max_num_tiles, c, h, w) + - encoder_mask: Tensor of shape (bsz, max_seq_len, tokens_per_tile * max_num_tiles * max_num_images) + - aspect_ratio: Tensor of shape (bsz, max_num_images, 2) + + Raises: + ValueError: if pad_max_tiles is set to a value less than the largest number of tiles in an image. + + Example: + >>> image_id = 1 + >>> tokens_per_tile = 5 + >>> c, h, w = 1, 1, 1 + >>> batch = [ + ... { + ... "tokens": [1, 2, 1, 3], "labels": [4, 5, 6, 7], + ... "encoder_input": { + ... # One image with two tiles, one image with three tiles + ... "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)], + ... "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])], + ... }, + ... # Mask is shape (text_seq_len, tokens_per_tile * n_tiles) + ... "encoder_mask": [torch.ones(4, 5 * 2), torch.ones(4, 5 * 3)], + ... }, + ... { + ... "tokens": [1, 4], "labels": [8, 9], + ... "encoder_input": { + ... # One image with four tiles + ... "images": [torch.ones(4, c, h, w)], + ... "aspect_ratio": [torch.tensor([2, 2])], + ... }, + ... # Mask is shape (text_seq_len, tokens_per_tile * n_tiles) + ... "encoder_mask": [torch.ones(2, 5 * 4)], + ... }, + ... ] + >>> model_inputs = padded_collate_tiled_images_and_mask(batch=batch) + >>> print(model_inputs["tokens"]) + tensor([[1, 2, 1, 3], + [1, 4, 0, 0]]) + >>> print(model_inputs["labels"]) + tensor([[4, 5, 6, 7], + [8, 9, -100, -100]]) + >>> print(model_inputs["encoder_input"]["images"].shape) # (bsz, max_num_images, max_num_tiles, c, h, w) + torch.Size([2, 2, 4, 1, 1, 1]) + >>> print(model_inputs["encoder_mask"].shape) + >>> # (bsz, max_text_seq_len, tokens_per_tile * max_num_tiles * max_num_images) + torch.Size([2, 4, 40]) + >>> print(model_inputs["encoder_input"]["aspect_ratio"].shape) # (bsz, max_num_images, 2) + torch.Size([2, 2, 2]) + >>> print(model_inputs["encoder_input"]["images"][0, 0, ...]) # Image with two tiles got padded to four + tensor([[[[1.]]], [[[1.]]], [[[0.]]], [[[0.]]]]) + >>> print(model_inputs["encoder_input"]["images"][0, 1, ...]) # Image with three tiles got padded to four + tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[0.]]]]) + >>> print(model_inputs["encoder_input"]["images"][1, 0, ...]) # Image with four tiles did not get padded + tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[1.]]]]) + >>> print(model_inputs["encoder_input"]["images"][1, 1, ...]) # Extra padding image was added to second sample + tensor([[[[0.]]], [[[0.]]], [[[0.]]], [[[0.]]]]) + """ + # Text tokens can be handled independently by existing collaters + text_only = [ + {"tokens": sample["tokens"], "labels": sample["labels"]} for sample in batch + ] + collated_text = padded_collate_sft(text_only, self.padding_idx, self.ignore_idx) + + max_seq_len = collated_text["tokens"].shape[-1] + bsz = len(batch) + + # TODO: Figure out how to make this more efficient or vectorized. 
Setting + # max_num_tiles beforehand will save one nested for loop but may incur more + # memory and compute costs in attention if max_num_tiles > batch_max_num_tiles + + if self.pad_max_tiles is None: + # Get max number of tiles in batch + max_num_tiles = max( + image.shape[0] + for sample in batch + for image in sample["encoder_input"]["images"] + ) + if self.pad_max_tiles < max_num_tiles: + raise ValueError( + f"More tiles in image {max_num_tiles}, than pad_max_tiles {self.pad_max_tiles}" + ) + max_num_tiles = self.pad_max_tiles + + # Second loop: pad images and masks to max number of tiles, max text seq len in batch + batch_images = [] + batch_masks = [] + batch_aspect_ratios = [] + token_len = [] # DEBUG(tj.solergibert) + image_len = [] # DEBUG(tj.solergibert) + tile_len = [] # DEBUG(tj.solergibert) + for sample in batch: + sample_images = [] + sample_masks = [] + token_len.append(len(sample["tokens"])) # DEBUG(tj.solergibert) + image_len.append( + len(sample["encoder_input"]["images"]) + ) # DEBUG(tj.solergibert) + tmp_tile_len = [] # DEBUG(tj.solergibert) + for image, mask in zip( + sample["encoder_input"]["images"], sample["encoder_mask"] + ): + # Single image in each sample has shape (n_tiles, c, h, w) + n_tiles = image.shape[0] + tmp_tile_len.append(n_tiles) # DEBUG(tj.solergibert) + # Single mask in each sample corresponds to a single image and has shape (text_seq_len, image_seq_len) + # where image_seq_len = n_tiles * tokens_per_tile + text_seq_len, image_seq_len = mask.shape + tokens_per_tile = image_seq_len // n_tiles + padding_tiles = max_num_tiles - n_tiles + + # Image should now have shape (max_num_tiles, c, h, w) + padded_image = F.pad( + image, (0, 0, 0, 0, 0, 0, 0, padding_tiles), value=0 + ) + # Mask should now have shape (max_seq_len, max_image_seq_len), where + # max_image_seq_len = max_num_tiles * tokens_per_tile + padded_mask = F.pad( + mask, + ( + 0, + padding_tiles * tokens_per_tile, + 0, + max_seq_len - text_seq_len, + ), + value=0, + ) + + sample_images.append(padded_image) + sample_masks.append(padded_mask) + tile_len.append(tmp_tile_len) # DEBUG(tj.solergibert) + # Stack multiple images and masks per sample in num_images dimension + batch_images.append(torch.stack(sample_images)) + batch_masks.append(torch.stack(sample_masks)) + batch_aspect_ratios.append( + torch.stack(sample["encoder_input"]["aspect_ratio"]) + ) + # Finally, pad images, masks, aspect ratios to max number of images in batch + # (bsz, max_num_images, max_num_tiles, c, h, w) + collated_images = pad_sequence(batch_images, batch_first=True, padding_value=0) + # (bsz, max_num_images, max_seq_len, max_image_seq_len) + collated_masks = pad_sequence(batch_masks, batch_first=True, padding_value=0) + # (bsz, max_num_images, 2) + collated_aspect_ratios = pad_sequence( + batch_aspect_ratios, batch_first=True, padding_value=1 + ) + + # Concatenate masks for multiple images across image_seq_len dimension + concat_masks = collated_masks.view(bsz, max_seq_len, -1) + if self.pad_max_images is not None: + _, _, img_seq = concat_masks.shape + concat_masks = F.pad( + concat_masks, (0, self.pad_max_images * image_seq_len - img_seq) + ) + + batch_dict = { + "tokens": collated_text["tokens"], + "labels": collated_text["labels"], + "encoder_input": { + "images": collated_images, + "aspect_ratio": collated_aspect_ratios, + }, + "encoder_mask": concat_masks, + "token_len": torch.tensor(token_len), # DEBUG(tj.solergibert) + "image_len": torch.tensor(image_len), # DEBUG(tj.solergibert) + "tile_len": tile_len, # 
DEBUG(tj.solergibert) + } + + return batch_dict diff --git a/torchtitan/datasets/multimodal/llama3_transform.py b/torchtitan/datasets/multimodal/llama3_transform.py new file mode 100755 index 00000000..59d22a7d --- /dev/null +++ b/torchtitan/datasets/multimodal/llama3_transform.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Mapping, Optional, Tuple + +from torchtitan.datasets.tokenizer import Tokenizer + +from ..multimodal import CLIPPreprocess, VisionCrossAttentionMask + +# NOTE Inspired from torchtune.models.llama3_2_vision._transform.py +class Llama3VisionTransform: + """ + This class combines the transforms for the different modalities of Llama 3.2 Vision. It + performs the following transforms: + - Tokenizing the text field using :class:`torchtitan.datasets.tokenizer.titoken.TikTokenizer` + - Preprocessing the images for the CLIP encoder using :class:`torchtitan.datasets.multimodal.clip.ClipPreprocess` + - Generating the Vision Cross Attention mask for the Fused layers + using :class:`torchtitan.datasets.multimodal.utils.VisionCrossAttentionMask` + + Args: + tokenizer (Tokenizer): + Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method. + tile_size (int): Size of the tiles to divide the image into. + patch_size (int): Size of the patches used in the CLIP vision tranformer model. This is + used to calculate the number of image embeddings per image. + max_num_tiles (int): Only used if possible_resolutions is NOT given. + Maximum number of tiles to break an image into. + This will be used to generate possible_resolutions, + e.g. [(224, 224), (224, 448), (448, 224)] if max_num_tiles = 2 and tile_size = 224. + Default 4. + image_mean (Optional[Tuple[float, float, float]]): Mean values of each channel, used for normalization. + image_std (Optional[Tuple[float, float, float]]): Standard deviations for each channel, used for normalization. + + Examples: + >>> model_transform = Llama3VisionTransform("/path/to/tokenizer.model", tile_size=224, patch_size=14) + >>> transformed_data = model_transform({"messages": user_message, "images": [img1, img2]}) + >>> print(transformed_data["tokens"]) + [1, 31587, 29644, 102, 2] + >>> print(transformed_data["images"][0].shape) + torch.Size([4, 3, 224, 224]) + """ + + def __init__( + self, + tokenizer: Tokenizer, + tile_size: int, + patch_size: int, + max_num_tiles: int = 4, + image_mean: Optional[Tuple[float, float, float]] = None, + image_std: Optional[Tuple[float, float, float]] = None, + ): + self.tokenizer = tokenizer + + self.transform_image = CLIPPreprocess( + image_mean=image_mean, + image_std=image_std, + tile_size=tile_size, + possible_resolutions=None, + max_num_tiles=max_num_tiles, + resample="bilinear", + resize_to_max_canvas=False, + ) + self.xattn_mask = VisionCrossAttentionMask( + tile_size=tile_size, + patch_size=patch_size, + image_token_id=128256, # TODO(tj.solergibert) Hardcoded? + max_num_tiles=max_num_tiles, + ) + + self.image_seq_len = max_num_tiles * (self.xattn_mask.patches_per_tile + 1) + # TODO(tj.solergibert) self.pad_id = self.tokenizer.pad_id + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Apply image decoding, transformations and tokenization to messages in the sample. + + Args: + sample (Mapping[str, Any]): A sample with a "messages" field. 
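+                In this prototype the sample is expected to carry the "images" and
+                "text" fields produced by ``format_obelics``.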
+ + Returns: + Mapping[str, Any]: The transformed sample with the following fields: + - tokens: List[int] of tokenized messages + - encoder_input: Dict[str, Any] of transformed images + - encoder_mask: List[bool] of masks for the transformed images + """ + encoder_input = {"images": [], "aspect_ratio": []} + for image in sample["images"]: + out = self.transform_image({"image": image}) + encoder_input["images"].append(out["image"]) + encoder_input["aspect_ratio"].append(out["aspect_ratio"]) + + sample["encoder_input"] = encoder_input + sample = self.tokenizer.encode_multimodal(sample) + # TODO(tj.solergibert) What should we do (Include y/n & Mask y/n) with both bos & eos + # TODO(tj.solergibert) allowed_special to this fancy set OR set it to "all"? + sample = self.xattn_mask(sample) + return sample diff --git a/torchtitan/datasets/multimodal/utils.py b/torchtitan/datasets/multimodal/utils.py new file mode 100755 index 00000000..e0f60b0e --- /dev/null +++ b/torchtitan/datasets/multimodal/utils.py @@ -0,0 +1,458 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from collections import defaultdict + +from pathlib import Path +from typing import Dict, List, Optional, Set, Tuple, Union +from urllib import request + +import torch +import torchvision +from torchvision.transforms.v2 import functional as F + +# NOTE Copied from torchtune.modules.transforms.vision_utils.tile_crop.py +def tile_crop(image: torch.Tensor, tile_size: int) -> torch.Tensor: + """ + Divides a tensor into equally sized tiles. The tensor should be divisible by tile_size. + + Args: + image (torch.Tensor): Input image to crop into tiles. + tile_size (int): Size of each tile. + + Returns: + torch.Tensor: torch.Tensor of shape [num_tiles, channel_size, tile_size, tile_size] + + Examples: + >>> image = torch.rand(3, 200, 300) + >>> tiles = tile_crop(image, tile_size=50) + >>> tiles.shape # 4x6 = 24 tiles + torch.Size([24, 3, 50, 50]) + + >>> image = torch.rand(3, 400, 600) + >>> tiles = tile_crop(image, tile_size=200) + >>> tiles.shape # 2x3 = 6 tiles + torch.Size([6, 3, 200, 200]) + """ + + channel_size, height, width = image.shape + + # assert sizes are divisible + assert ( + height % tile_size == 0 and width % tile_size == 0 + ), f"Image size {height}x{width} is not divisible by tile size {tile_size}" + + # Reshape to split height and width into tile_size blocks + tiles_height = height // tile_size + tiles_width = width // tile_size + + reshaped = image.view(channel_size, tiles_height, tile_size, tiles_width, tile_size) + + # Transpose to bring tiles together + # We want [tiles_height, tiles_width, channel_size, tile_size, tile_size] + transposed = reshaped.permute(1, 3, 0, 2, 4) + + # Flatten the tiles + tiles = transposed.contiguous().view( + tiles_height * tiles_width, channel_size, tile_size, tile_size + ) + + return tiles + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py +def resize_with_pad( + image: torch.Tensor, + target_size: Tuple[int, int], + resample: torchvision.transforms.InterpolationMode, + max_size: Optional[int] = None, +) -> torch.Tensor: + """ + Resizes and pads an image to target_size without causing distortion. + The user can set max_size to limit upscaling when target_size exceeds image_size. + + Args: + image (torch.Tensor): The input image tensor in the format [..., H, W]. 
+ target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width]. + resample (torchvision.transforms.InterpolationMode): Resampling method used when resizing images. + Supports torchvision.transforms.InterpolationMode.NEAREST, InterpolationMode.NEAREST_EXACT, + InterpolationMode.BILINEAR and InterpolationMode.BICUBIC. + max_size (Optional[int]): The maximum size to upscale the image to. + If None, will upscale up to target_size. + + Returns: + torch.Tensor: The resized and padded image tensor in the format [..., H, W]. + + Examples: + + Example 1: The image will be upscaled from (300, 800) to (448, 1194), since 448 is the limiting side, + and then padded from (448, 1194) to (448, 1344). + + >>> max_size = None + >>> image = torch.rand([3, 300, 800]) + >>> target_size = (448, 1344) + >>> resample = torchvision.transforms.InterpolationMode.BILINEAR + >>> output = resize_with_pad(image, target_size, resample, max_size) + + Example 2: The image will stay as is, since 800 > 600, and then padded from (300, 800) to (448, 1344). + + >>> max_size = 600 + >>> image = torch.rand([3, 300, 800]) + >>> target_size = (448, 1344) + >>> resample = torchvision.transforms.InterpolationMode.BILINEAR + >>> output = resize_with_pad(image, target_size, resample, max_size) + + Example 3: The image will be downscaled from (500, 1000) to (224, 448), + and padded from (224, 448) to (448, 448). + + >>> max_size = 600 + >>> image = torch.rand([3, 500, 1000]) + >>> target_size = (448, 488) + >>> resample = torchvision.transforms.InterpolationMode.BILINEAR + >>> output = resize_with_pad(image, target_size, resample, max_size) + + """ + + image_height, image_width = image.shape[-2:] + image_size = (image_height, image_width) + + # If target_size requires upscaling, we might want to limit the upscaling to max_size + if max_size is not None: + new_target_height = min(max(image_height, max_size), target_size[0]) + new_target_width = min(max(image_width, max_size), target_size[1]) + target_size_resize = (new_target_height, new_target_width) + else: + target_size_resize = target_size + + # resize to target_size while preserving aspect ratio + new_size_preserving_aspect_ratio = _get_max_res_without_distortion( + image_size=image_size, + target_size=target_size_resize, + ) + + image = F.resize( + inpt=image, + size=list(new_size_preserving_aspect_ratio), + interpolation=resample, + antialias=True, + ) + + image = _pad_image_top_left(image=image, target_size=target_size) + + return image + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py +def _pad_image_top_left( + image: torch.Tensor, + target_size: Tuple[int, int], +) -> torch.Tensor: + """ + Places the image at the top left of the canvas and pads with 0 the right and bottom + to fit to the target resolution. If target_size < image_size, it will crop the image. + + Args: + image (torch.Tensor): The input image tensor in the format [..., H, W]. + target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width]. + + Returns: + torch.Tensor: The padded image tensor in the format [..., H, W]. 
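+
+    Example (illustrative; the shape follows from the padding logic below):
+        >>> image = torch.rand(3, 200, 300)
+        >>> padded = _pad_image_top_left(image, target_size=(448, 448))
+        >>> padded.shape
+        torch.Size([3, 448, 448])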
+ """ + + image_size = image.shape[-2:] + + height, width = image_size + target_height, target_width = target_size + + pad_x = target_width - width + pad_y = target_height - height + + padding = [0, 0, pad_x, pad_y] + return F.pad(inpt=image, padding=padding) + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py +def _get_max_res_without_distortion( + image_size: Tuple[int, int], + target_size: Tuple[int, int], +) -> Tuple[int, int]: + """ + Determines the maximum resolution to which an image can be resized to without distorting its + aspect ratio, based on the target resolution. + + For example, if image_size = (200,400) and target_size = (600,800), + scale_h = 600/200 = 3 + scale_w = 800/400 = 2 + So the maximum that we can upscale without distortion is min(scale_h, scale_w) = 2 + + Since scale_w is the limiting side, then new_w = target_w, and new_h = old_h*scale_w + + Args: + image_size (Tuple[int, int]): The original resolution of the image. + target_size (Tuple[int, int]): The desired resolution to fit the image into. + Returns: + Tuple[int, int]: The optimal dimensions to which the image should be resized. + Examples: + >>> _get_max_res_without_distortion([200, 300], target_size = (450, 200)) + (133, 200) + >>> _get_max_res_without_distortion([800, 600], target_size = (450, 1300)) + (450, 337) + """ + + original_height, original_width = image_size + target_height, target_width = target_size + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.floor(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.floor(original_width * scale_h), target_width) + + return new_height, new_width + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py +def _get_factors(n: int) -> Set[int]: + """ + Calculate all factors of a given number, i.e. a divisor that leaves no remainder. + + Args: + n (int): The number to find factors for. + + Returns: + set: A set containing all factors of the number. + + Examples: + >>> _get_factors(n=12) + {1, 2, 3, 4, 6, 12} + """ + factors_set = set() + + for i in range(1, int(n**0.5) + 1): + if n % i == 0: + factors_set.add(i) + factors_set.add(n // i) + return factors_set + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py +def get_canvas_best_fit( + image: torch.Tensor, possible_resolutions: torch.Tensor, resize_to_max_canvas: bool +) -> Tuple[int, int]: + """ + Determines the best canvas possible from a list of possible resolutions to + resize an image to, without distortion. + + For each possible resolution, calculates the scaling factors for + width and height, and selects the smallest one, which is the limiting side. + E.g. if to match a canvas shape you have to upscale an image's height by 2x, and width by 1.5x, + then the maximum upscaling without distortion is min(2, 1.5) = 1.5. + + If there are multiple canvases that satisfy the conditions, + we pick the one with the lowest area to minimize padding. + + Args: + image (torch.Tensor): The image we want to fit into a canvas. + possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each + row represents a possible canvas. + resize_to_max_canvas (bool): If True, pick the canvas that allows maximum scaling. + If False, pick the canvas that minimizes downscaling, including no downscaling at all. 
+ + Returns: + Tuple[int, int]: The best resolution to fit the image into. + + Examples: + >>> image = torch.rand(3, 200, 300) + >>> possible_resolutions = torch.tensor([ + ... [224, 672], + ... [672, 224], + ... [224, 448], + ... [448, 224], + ... [224, 224] + ... ]) + >>> get_canvas_best_fit(image, possible_resolutions, resize_to_max_canvas=False) + (224, 448) + + In the example above, we calculate the scaling factors for each possible resolution + + >>> scale_height = torch.tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200]) + >>> scale_width = torch.tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467]) + >>> scales = torch.tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467]) + + Two options have scaling_factor > 1, since resize_to_max_canvas is False, we pick the smallest + + >>> upscaling_options = torch.tensor([1.1200, 1.1200]) + >>> selected_scale = torch.tensor(1.1200) + + There are two possible options, so we pick the one with the smallest area + + >>> areas = torch.tensor([150528, 100352]) # for resolutions [672, 224] and [224, 448], respectively + >>> optimal_canvas = torch.tensor([224, 448]) # resolution with the smallest area + """ + + original_height, original_width = image.shape[-2:] + + # possible resolutions heights/widths + target_heights, target_widths = ( + possible_resolutions[:, 0], + possible_resolutions[:, 1], + ) + + # scaling factors to resize the image without distortion + scale_w = target_widths / original_width + scale_h = target_heights / original_height + + # get limiting side scaling -> no distortion + scales = torch.where(scale_w > scale_h, scale_h, scale_w) + + # filter only scales that allow upscaling + upscaling_options = scales[scales >= 1] + if len(upscaling_options) > 0: + if resize_to_max_canvas: + selected_scale = torch.max(upscaling_options) + else: + selected_scale = torch.min(upscaling_options) + else: + # no upscaling possible, + # get the minimum downscaling (max scale for scales<1) + downscaling_options = scales[scales < 1] + selected_scale = torch.max(downscaling_options) + + # get all resolutions that support this scaling factor, + # e.g. you can upscale to 224x224, 224x448, 224x672 without distortion + chosen_canvas = possible_resolutions[scales == selected_scale] + + # if there are multiple resolutions, + # get the one with minimum area to reduce padding + if len(chosen_canvas) > 1: + areas = chosen_canvas[:, 0] * chosen_canvas[:, 1] + optimal_idx = torch.argmin(areas) + optimal_canvas = chosen_canvas[optimal_idx] + else: + optimal_canvas = chosen_canvas[0] + + return tuple(optimal_canvas.tolist()) + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py +def find_supported_resolutions( + max_num_tiles: int, tile_size: int +) -> List[Tuple[int, int]]: + """ + Computes all combinations of resolutions, multiple of tile_size, + that contain up to max_num_tiles. Useful for when dividing an image into tiles. + + For example, if we want at most 2 tiles per image, then we can support the + following resolutions: (1x1, 1x2, 2x1) * tile_size + + Args: + max_num_tiles (int): Maximum number of tiles. + tile_size (int): Size of the side of the tile. + + Returns: + List[Tuple[int, int]]: List of possible resolutions as tuples (height, width). 
+ + Examples: + + >>> max_num_tiles = 4 + >>> tile_size = 224 + >>> find_supported_resolutions(max_num_tiles, tile_size) + [(224, 896), (448, 448), (224, 224), (896, 224), (224, 672), (672, 224), (224, 448), (448, 224)] + """ + + # create dictionary {aspect_ratio: [resolution1, ..., resolution n]} + # example {0.25: [(1,4)], 1.0: [(2,2), (1,1)], 4.0: [(4,1)]} + asp_dict = defaultdict(list) + for _tile_size in range(max_num_tiles, 0, -1): + factors = sorted(_get_factors(_tile_size)) + asp_ratios = [(factor, _tile_size // factor) for factor in factors] + for height, width in asp_ratios: + ratio_float = height / width + asp_dict[ratio_float].append((height, width)) + + # get the resolutions multiplied by the tile_size + possible_resolutions = [] + for ar, resolution in asp_dict.items(): + for height, width in resolution: + possible_resolutions.append((height * tile_size, width * tile_size)) + + return possible_resolutions + + +# NOTE Copied from torchtune.data._utils.py +def load_image(image_loc: Union[Path, str]) -> "PIL.Image.Image": + """ + Convenience method to load an image in PIL format from a local file path or remote source. + + Args: + image_loc (Union[Path, str]): Local file path or remote source pointing to the image + which will be loaded in PIL format. + + Note: + If loading an image from a remote source, the function expects the URL provided in ``image_loc`` + to start with "http" or "https" e.g. "https://www.wikipedia.org/en/bird.jpg". + + Raises: + ValueError: If the image cannot be loaded from remote source. + ValueError: If the image cannot be opened as a :class:`~PIL.Image.Image`. + + Examples: + >>> # Load from remote source + >>> image = load_image("https://www.wikipedia.org/en/bird.jpg") + + >>> # Load from local file path + >>> image = load_image(Path("/home/user/bird.jpg")) + + Returns: + PIL.Image.Image: The loaded image. + """ + # Hackily import PIL to avoid burdensome import in the main module + # TODO: Fix this + from PIL import Image + + # If pointing to remote source, try to load to local + if isinstance(image_loc, str) and image_loc.startswith("http"): + try: + image_loc = request.urlopen(image_loc) + except Exception as e: + raise ValueError(f"Failed to load image from {image_loc}") from e + + # Open the local image as a PIL image + try: + image = Image.open(image_loc) + except Exception as e: + raise ValueError(f"Failed to open image as PIL Image from {image_loc}") from e + + return image + + +def format_obelics(sample: Dict, image_token: str = "<|image|>") -> Dict: + """ + This function formats samples from the OBELICS dataset to be processed with `Llama3VisionTransform` + Returns: + Dict[str, Any]: The transformed sample with the following fields: + - images: List[PIL.Image.Image] with the loaded images + - text: str with the text of the sample ready to be tokenized including the image tokens + Example: + >>> formatted_sample = format_obelics(sample, image_token="<|image|>") + >>> print(formatted_sample["text"]) + ... "<|image|><|image|><|image|> The elephant look cute!<|image|><|image|> The cats are sad :(" + """ + # TODO(tj.solergibert) Optimization: Drop images at the end as they are useless! 
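+    # The formatting below assumes OBELICS-style samples: parallel "images" and
+    # "texts" lists where each position holds either an image or a text, so None
+    # texts are replaced by the image token and None images are dropped.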
+ sample_images = [image for image in sample["images"] if image is not None] + sample_text = [ + text if text is not None else image_token for text in sample["texts"] + ] + return { + "images": [load_image(image) for image in sample_images], + "text": "".join(map(str, sample_text)), + } diff --git a/torchtitan/datasets/multimodal/vision_attention_mask.py b/torchtitan/datasets/multimodal/vision_attention_mask.py new file mode 100755 index 00000000..ebee7f14 --- /dev/null +++ b/torchtitan/datasets/multimodal/vision_attention_mask.py @@ -0,0 +1,171 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Mapping, Optional + +import torch + +# NOTE Inspired from torchtune.modules.transforms._transforms.py +class VisionCrossAttentionMask: + """ + Computes the cross-attention mask for text + image inputs. Text tokens that + participate in cross-attention with an image token will show True in the mask + and follow the interleaved structure laid out in Fig. 7 of the Flamingo paper + (https://arxiv.org/pdf/2204.14198): + + (1) Text tokens immediately following the image token up until the next image token + (2) Consecutive image tokens attend to subsequent text tokens + + :: + + ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ + img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ │ │ │ │ │ │ │ │ │ + └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ + ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ + img2 │ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ │ │ │ │ │ │ │ │ │ + └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ + ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ + img3 │ │ │ │ │ │ │ │ │ │ │ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ + └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ + These are two dogs. This is a cat. + + + + Resultant mask is constructed per image and is of shape (text_seq_len, image_seq_len), + where True indicates that the token outputted from the image encoder attends + to the token in the text sequence in cross-attention. A list of these masks + are returned with length equal to number of images in the sample. + + Args: + tile_size (int): The size of the image tiles from the image transform + patch_size (int): The size of each patch. Used to divide the tiles into patches. + E.g. for patch_size = 40, a tile of shape (400, 400) will have 10x10 grid of patches + with shape (40, 40) each. + image_token_id (int): Token ID of the image special token. + max_num_tiles (Optional[int]): Maximum number of tiles in an image, used to + pad mask during inference. Defaults to None + """ + + def __init__( + self, + tile_size: int, + patch_size: int, + image_token_id: int, + max_num_tiles: Optional[int] = None, + ): + patch_grid_size = tile_size // patch_size + self.patches_per_tile = patch_grid_size**2 + self.image_token_id = image_token_id + self.max_num_tiles = max_num_tiles + + def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int]]: + """ + Returns a list of lists of the form [start, end) where start is the index + of the current image token and end is the index of the next image token, exclusive. 
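+        Consecutive image tokens are collapsed into intervals that share the same
+        end index, as illustrated in the example below.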
+ + Args: + tokens (List[int]): List of token IDs in the text sequence + + Returns: + List[List[int]]: List of lists of the form [start, end) indicating + range of positions in text sequence that should attend to the image + + Example: + >>> text = "These are two dogs. This is a cat." + >>> image_token_id = 1 + >>> tokens = [1, 1, 9673, 527, 1403, 12875, 13, 1, 1115, 374, 264, 8415] + >>> transform = VisionCrossAttentionMask(tile_size=400, patch_size=40, image_token_id=1) + >>> intervals = transform._get_image_attention_intervals(tokens) + >>> print(intervals) + [[0, 7], [1, 7], [7, 12]] + """ + end = len(tokens) + vision_token_locations = [ + i for i, token in enumerate(tokens) if token == self.image_token_id + ] + # Return empty list if there are no images + if len(vision_token_locations) == 0: + return [] + # If there is only one image, it will attend to subsequent text until end + if len(vision_token_locations) == 1: + return [[vision_token_locations[0], end]] + + # Construct intervals from previous image token to next image token + vision_masks = [ + [tok_idx_prev, tok_idx_next] + # Offset by one to get consecutive indices + for tok_idx_prev, tok_idx_next in zip( + vision_token_locations[:-1], vision_token_locations[1:] + ) + ] + # Last image will attend to subsequent text until end + vision_masks.append([vision_token_locations[-1], end]) + + # If there are consecutive vision tokens, they should all attend to the + # same subsequent text + last_mask_end = vision_masks[-1][1] + for vision_mask in vision_masks[::-1]: + if vision_mask[0] == vision_mask[1] - 1: + vision_mask[1] = last_mask_end + last_mask_end = vision_mask[1] + return vision_masks + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Generates the vision cross-attention mask for the given sample based on + the image token locations interleaved in the text sequence. + + Args: + sample (Mapping[str, Any]): Sample dict containing the following keys: + - tokens (List[int]): List of token IDs in the text sequence. Number of + image token IDs in the sequence must match the number of images. + - images (List[torch.Tensor]): List of image Tensors post-tiling of shape + (n_tiles, c, h, w) each. + + Returns: + Mapping[str, Any]: sample with a new key encoder_mask, with a mask per image with shape + (text_seq_len, image_seq_len) where text_seq_len == len(tokens) and + image_seq_len == max_tiles * (patches_per_tile + 1). These masks get padded and concatenated + in the batch collator. + + Raises: + RuntimeError: if the number of images in the batch does not match the number of image tokens in the batch. + """ + tokens, images = sample["tokens"], sample["encoder_input"]["images"] + # One sample can have multiple images - verify the number of image tokens + # is the same + n_img = len(images) + intervals = self._get_image_attention_intervals(tokens) + if len(intervals) != n_img: + raise RuntimeError( + f"The number of image tokens ({len(intervals)}) does not match the number of images ({n_img})." + ) + + # Create mask for each individual image based on its number of tokens, + # which can vary based on number of tiles since they are not yet tile padded. 
+ # The masks are padded and concatenated together in the batch collator + text_seq_len = len(tokens) + max_image_size = None + masks = [] + for image_num, interval in enumerate(intervals): + # Identify what part of text sequence should be attended + start, end = interval + # Compute this image's number of tokens based on num tiles, patches per tile + n_tiles = images[image_num].shape[0] + image_seq_len = n_tiles * (self.patches_per_tile + 1) # +1 for CLS token + # Mask will be block of 1s at the corresponding interval in the text. + # It is not a causal block because all the image tokens correspond + # to a single image, so text tokens attend to all the image's tokens. + # The mask is text_seq_len x mask_image_size if defined, otherwise + # it uses current text/image sequence lengths. + mask = torch.zeros( + text_seq_len, max_image_size or image_seq_len, dtype=torch.bool + ) + mask[start:end, :image_seq_len] = True + masks.append(mask) + + sample.update({"encoder_mask": masks}) + return sample diff --git a/torchtitan/datasets/tokenizer/tiktoken.py b/torchtitan/datasets/tokenizer/tiktoken.py old mode 100644 new mode 100755 index c879e7f3..dc68af58 --- a/torchtitan/datasets/tokenizer/tiktoken.py +++ b/torchtitan/datasets/tokenizer/tiktoken.py @@ -11,17 +11,20 @@ from pathlib import Path from typing import ( AbstractSet, + Any, cast, Collection, Dict, Iterator, List, Literal, + Mapping, Optional, Sequence, Union, ) +import numpy as np import tiktoken from tiktoken.load import load_tiktoken_bpe @@ -67,6 +70,7 @@ def __init__(self, model_path: str): self.special_tokens = { token: num_base_tokens + i for i, token in enumerate(special_tokens) } + self.special_tokens["<|image|>"] = 128256 # TODO(tj.solergibert) Hardcoded! self.model = tiktoken.Encoding( name=Path(model_path).name, pat_str=self.pat_str, @@ -79,6 +83,7 @@ def __init__(self, model_path: str): self.bos_id: int = self.special_tokens["<|begin_of_text|>"] self.eos_id: int = self.special_tokens["<|end_of_text|>"] self.pad_id: int = -1 + self.image_id = 128256 # TODO(tj.solergibert) Hardcoded! self.stop_tokens = { self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"], @@ -190,3 +195,30 @@ def _split_whitespaces_or_nonwhitespaces( slice_start = i current_slice_len = 1 yield s[slice_start:] + + def encode_multimodal(self, sample: Mapping[str, Any]) -> List[int]: + """ + Tokenizes a `str` of text and creates `labels` masking BOS, EOS and `image_id` tokens. + """ + # TODO(tj.solergibert) Should we keep `input_ids` OR `tokens` across this class, VisionCrossAttentionMask & the collator? + # For me it makes more sense to split `tokens` between `input_ids` & `labels` as in train.py BUT the `MultimodalDecoder` + # & everything else expects `tokens` + text = sample["text"] + tokens = self.encode( + text, bos=True, eos=True, allowed_special=set(["<|image|>"]) + ) + input_ids = tokens[:-1] + labels = tokens[1:] + labels = list( + np.where( + np.isin(labels, [self.bos_id, self.eos_id, self.image_id]), + -100, # TODO(tj.solergibert) Hardcoded! 
+ labels, + ) + ) + + assert len(input_ids) == len(labels) # TODO(tj.solergibert) Delete + + sample.update({"tokens": input_ids, "labels": labels}) + + return sample From 07a7a12af64075225874b18c99ab158776c2d254 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 30 Oct 2024 17:41:55 +0100 Subject: [PATCH 2/3] Added padding check script --- scripts/check_padding_mm.py | 70 +++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100755 scripts/check_padding_mm.py diff --git a/scripts/check_padding_mm.py b/scripts/check_padding_mm.py new file mode 100755 index 00000000..948c83c0 --- /dev/null +++ b/scripts/check_padding_mm.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.datasets import build_mm_data_loader, build_tokenizer + +PATH_TO_TOKENIZER = "/workspace/mm/tokenizer.model" +BATCH_SIZE = 16 +BATCH_NUMBER = 4 + + +def main(): + tokenizer = build_tokenizer("tiktoken", PATH_TO_TOKENIZER) + dl = build_mm_data_loader("OBELICS", tokenizer, BATCH_SIZE, 2, 0) + dl_iter = iter(dl) + for _ in range(BATCH_NUMBER): + batch = next(dl_iter) + + # Analyze Batch + ## input_ids + total_input_ids = sum(batch["token_len"].tolist()) + input_ids_pad_length = max(batch["token_len"].tolist()) + total_tokens_in_batch = input_ids_pad_length * BATCH_SIZE + total_input_ids_padded_tokens = sum( + (batch["token_len"] - input_ids_pad_length) * -1 + ) + print( + f"Unpadded tokens: {total_input_ids}, Total tokens in batch: {total_tokens_in_batch}" + ) + print( + f"Padded text tokens: {total_input_ids_padded_tokens}, {total_input_ids_padded_tokens/total_tokens_in_batch*100:.2f}%" + ) + print(40 * "#") + ## image_ids + total_images = sum(batch["image_len"].tolist()) + image_pad_length = max(batch["image_len"].tolist()) + total_images_in_batch = image_pad_length * BATCH_SIZE + total_images_padded_tokens = sum((batch["image_len"] - image_pad_length) * -1) + print( + f"Unpadded images: {total_images}, Total images in batch: {total_images_in_batch}" + ) + print( + f'Padded images: {total_images_padded_tokens}, {total_images_padded_tokens/total_images_in_batch*100:.2f}% (Each image with shape {list(batch["encoder_input"]["images"][0,0].shape)})' + ) + print(40 * "#") + # Tiles + total_number_of_tiles = sum([sum(sample) for sample in batch["tile_len"]]) + print( + f"Unpadded number of tiles: {total_number_of_tiles}, Total number of tiles: {total_images_in_batch*4}" + ) + print( + f'Padded tiles: {total_images_in_batch*4-total_number_of_tiles}, {(1-(total_number_of_tiles/(total_images_in_batch*4-total_number_of_tiles)))*100:.2f}% (Each with shape {list(batch["encoder_input"]["images"][0,0,0].shape)})' + ) + print(40 * "#") + # CrossAttentionMask + original_cross_attention_mask_elements = ( + total_number_of_tiles * 1025 * total_input_ids + ) # NOTE(tj.solergibert) We have 1024+1 image tokens per tile + print( + f"Unpadded cross attention mask elements: {original_cross_attention_mask_elements}, Total cross attention mask elements: {total_images_in_batch*4*1025*total_tokens_in_batch}" + ) # TODO(tj.solergibert) Each element is a `bool` + print( + f"Padded cross attention mask elements: {total_images_in_batch*4*1025*total_tokens_in_batch-original_cross_attention_mask_elements}, 
{100*((total_images_in_batch*4*1025*total_tokens_in_batch-original_cross_attention_mask_elements)/(total_images_in_batch*4*1025*total_tokens_in_batch)):.2f}%" + ) + + +if __name__ == "__main__": + main() From 9a02575a6f0c703ef8f7f95602570a05ee1c3128 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 6 Nov 2024 13:41:03 +0100 Subject: [PATCH 3/3] Added torchvision dependency --- .ci/docker/requirements.txt | 1 + .pre-commit-config.yaml | 51 ------------------------------------- pyproject.toml | 3 +++ 3 files changed, 4 insertions(+), 51 deletions(-) mode change 100644 => 100755 .ci/docker/requirements.txt delete mode 100755 .pre-commit-config.yaml mode change 100644 => 100755 pyproject.toml diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt old mode 100644 new mode 100755 index 2321627e..f7b0979e --- a/.ci/docker/requirements.txt +++ b/.ci/docker/requirements.txt @@ -6,3 +6,4 @@ sentencepiece tiktoken blobfile tabulate +torchvision diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100755 index 318f7ef2..00000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,51 +0,0 @@ -exclude: 'build' - -default_language_version: - python: python3 - -repos: -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: 6306a48f7dae5861702d573c9c247e4e9498e867 - hooks: - - id: trailing-whitespace - - id: check-ast - - id: check-merge-conflict - - id: no-commit-to-branch - args: ['--branch=main'] - - id: check-added-large-files - args: ['--maxkb=500'] - - id: end-of-file-fixer - exclude: '^(.*\.svg)$' - -- repo: https://github.com/Lucas-C/pre-commit-hooks - rev: v1.5.4 - hooks: - - id: insert-license - files: \.py$ - args: - - --license-filepath - - docs/license_header.txt - -- repo: https://github.com/pycqa/flake8 - rev: 34cbf8ef3950f43d09b85e2e45c15ae5717dc37b - hooks: - - id: flake8 - additional_dependencies: - - flake8-bugbear == 22.4.25 - - pep8-naming == 0.12.1 - - torchfix - args: ['--config=.flake8'] - -- repo: https://github.com/omnilib/ufmt - rev: v2.3.0 - hooks: - - id: ufmt - additional_dependencies: - - black == 22.12.0 - - usort == 1.0.5 - -- repo: https://github.com/jsh9/pydoclint - rev: d88180a8632bb1602a4d81344085cf320f288c5a - hooks: - - id: pydoclint - args: [--config=pyproject.toml] diff --git a/pyproject.toml b/pyproject.toml old mode 100644 new mode 100755 index a5c1b72f..484f239e --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,9 @@ dependencies = [ "sentencepiece", "tiktoken", + # Multimodality + "torchvision", + # Miscellaneous "tomli>=1.1.0" ]