[WIP] Adding OBELICS DataLoader #663

Open · wants to merge 3 commits into main

1 change: 1 addition & 0 deletions .ci/docker/requirements.txt
100644 → 100755
@@ -6,3 +6,4 @@ sentencepiece
tiktoken
blobfile
tabulate
torchvision
51 changes: 0 additions & 51 deletions .pre-commit-config.yaml

This file was deleted.

3 changes: 3 additions & 0 deletions pyproject.toml
100644 → 100755
@@ -18,6 +18,9 @@ dependencies = [
"sentencepiece",
"tiktoken",

# Multimodality
"torchvision",

# Miscellaneous
"tomli>=1.1.0"
]
70 changes: 70 additions & 0 deletions scripts/check_padding_mm.py
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.datasets import build_mm_data_loader, build_tokenizer

PATH_TO_TOKENIZER = "/workspace/mm/tokenizer.model"
BATCH_SIZE = 16
BATCH_NUMBER = 4


def main():
Contributor: maybe we can make this a unit test? WDYT?

Author: I would add as a unit test some checks of shapes & types on the DP axis, rather than this script, which just checks the amount of padding in each batch.

    tokenizer = build_tokenizer("tiktoken", PATH_TO_TOKENIZER)
    dl = build_mm_data_loader("OBELICS", tokenizer, BATCH_SIZE, 2, 0)
    dl_iter = iter(dl)
    for _ in range(BATCH_NUMBER):
        batch = next(dl_iter)

        # Analyze Batch
        ## input_ids
        total_input_ids = sum(batch["token_len"].tolist())
        input_ids_pad_length = max(batch["token_len"].tolist())
        total_tokens_in_batch = input_ids_pad_length * BATCH_SIZE
        total_input_ids_padded_tokens = sum(
            (batch["token_len"] - input_ids_pad_length) * -1
        )
        print(
            f"Unpadded tokens: {total_input_ids}, Total tokens in batch: {total_tokens_in_batch}"
        )
        print(
            f"Padded text tokens: {total_input_ids_padded_tokens}, {total_input_ids_padded_tokens/total_tokens_in_batch*100:.2f}%"
        )
        print(40 * "#")
        ## image_ids
        total_images = sum(batch["image_len"].tolist())
        image_pad_length = max(batch["image_len"].tolist())
        total_images_in_batch = image_pad_length * BATCH_SIZE
        total_images_padded_tokens = sum((batch["image_len"] - image_pad_length) * -1)
        print(
            f"Unpadded images: {total_images}, Total images in batch: {total_images_in_batch}"
        )
        print(
            f'Padded images: {total_images_padded_tokens}, {total_images_padded_tokens/total_images_in_batch*100:.2f}% (Each image with shape {list(batch["encoder_input"]["images"][0,0].shape)})'
        )
        print(40 * "#")
        # Tiles
        total_number_of_tiles = sum([sum(sample) for sample in batch["tile_len"]])
        print(
            f"Unpadded number of tiles: {total_number_of_tiles}, Total number of tiles: {total_images_in_batch*4}"
        )
        print(
            f'Padded tiles: {total_images_in_batch*4-total_number_of_tiles}, {(1-total_number_of_tiles/(total_images_in_batch*4))*100:.2f}% (Each with shape {list(batch["encoder_input"]["images"][0,0,0].shape)})'
        )
        print(40 * "#")
        # CrossAttentionMask
        original_cross_attention_mask_elements = (
            total_number_of_tiles * 1025 * total_input_ids
        )  # NOTE(tj.solergibert) We have 1024+1 image tokens per tile
        print(
            f"Unpadded cross attention mask elements: {original_cross_attention_mask_elements}, Total cross attention mask elements: {total_images_in_batch*4*1025*total_tokens_in_batch}"
        )  # TODO(tj.solergibert) Each element is a `bool`
        print(
            f"Padded cross attention mask elements: {total_images_in_batch*4*1025*total_tokens_in_batch-original_cross_attention_mask_elements}, {100*((total_images_in_batch*4*1025*total_tokens_in_batch-original_cross_attention_mask_elements)/(total_images_in_batch*4*1025*total_tokens_in_batch)):.2f}%"
        )


if __name__ == "__main__":
    main()
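Following the suggestion in the thread above, a minimal sketch of what such a unit test could look like, checking shapes and dtypes per DP rank rather than padding statistics. This is a hypothetical test, not part of the PR; it assumes `build_mm_data_loader`, the batch keys used in this script, and a reachable tokenizer file, and it infers the padded image layout from the indexing used above.

# Hypothetical unit-test sketch (not part of this PR): checks per-rank batch
# shapes and dtypes on the DP axis instead of padding percentages.
from torchtitan.datasets import build_mm_data_loader, build_tokenizer

PATH_TO_TOKENIZER = "/workspace/mm/tokenizer.model"  # placeholder path, as above
BATCH_SIZE = 4


def test_mm_batch_shapes_across_dp_ranks(world_size: int = 2):
    tokenizer = build_tokenizer("tiktoken", PATH_TO_TOKENIZER)
    for rank in range(world_size):
        dl = build_mm_data_loader("OBELICS", tokenizer, BATCH_SIZE, world_size, rank)
        batch = next(iter(dl))

        # Length tensors are per-sample integer counts.
        assert batch["token_len"].shape == (BATCH_SIZE,)
        assert not batch["token_len"].is_floating_point()
        assert batch["image_len"].shape == (BATCH_SIZE,)

        # Images are assumed to be padded to [batch, n_images, n_tiles, C, H, W],
        # as implied by the indexing in the padding script above.
        images = batch["encoder_input"]["images"]
        assert images.shape[0] == BATCH_SIZE
        assert images.dim() == 6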
5 changes: 4 additions & 1 deletion torchtitan/datasets/__init__.py
100644 → 100755
@@ -4,10 +4,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.hf_datasets import build_hf_data_loader, DPAwareDataLoader
from torchtitan.datasets.mm_datasets import build_mm_data_loader
from torchtitan.datasets.tokenizer import build_tokenizer

__all__ = [
"build_hf_data_loader",
"build_tokenizer",
"DPAwareDataLoader",
"build_mm_data_loader",
]
12 changes: 9 additions & 3 deletions torchtitan/datasets/hf_datasets.py
100644 → 100755
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import pickle
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.distributed.checkpoint.stateful import Stateful
@@ -156,8 +156,14 @@ class DPAwareDataLoader(StatefulDataLoader, Stateful):
    A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
    """

    def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int):
        super().__init__(hf_ds, batch_size)
    def __init__(
        self,
        dp_rank: int,
        hf_ds: IterableDataset,
        batch_size: int,
        collator_fn: Callable,
    ):
        super().__init__(dataset=hf_ds, batch_size=batch_size, collate_fn=collator_fn)
Contributor: where is this collate_fn being used or called?

        self._dp_rank = dp_rank
        self._rank_id = f"dp_rank_{dp_rank}"

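On the reviewer question above: the `collate_fn` passed to `super().__init__` is not called explicitly by `DPAwareDataLoader`; the base `torch.utils.data.DataLoader` (which `StatefulDataLoader` extends) calls it on every list of samples it assembles into a batch. A standalone sketch of that mechanism, using toy names rather than torchtitan code:

# Standalone illustration (toy dataset/collator, not torchtitan code) of how a
# DataLoader uses the collate_fn passed to its constructor: it gathers
# `batch_size` samples from the dataset and hands the list to collate_fn,
# which returns the batch object the loader yields.
from torch.utils.data import DataLoader, IterableDataset


class ToyIterableDataset(IterableDataset):
    def __iter__(self):
        for i in range(6):
            yield {"value": i}


def toy_collator(samples):
    # `samples` is a list of `batch_size` items produced by the dataset.
    return {"values": [s["value"] for s in samples]}


loader = DataLoader(ToyIterableDataset(), batch_size=3, collate_fn=toy_collator)
for batch in loader:
    print(batch)  # {'values': [0, 1, 2]} then {'values': [3, 4, 5]}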
144 changes: 144 additions & 0 deletions torchtitan/datasets/mm_datasets.py
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset

from torchtitan.datasets import DPAwareDataLoader
from torchtitan.datasets.multimodal import (
    format_obelics,
    Llama3VisionTransform,
    MultiModalCollator,
)
from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging import logger

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node

# map from dataset name to a dataset repository on the HF hub
_supported_datasets = {
"OBELICS": "HuggingFaceM4/OBELICS",
}


class MultiModalDataset(IterableDataset, Stateful):
"""PyTorch MultiModal Dataset.

Args:
dataset_name (str): name of the dataset to load
tokenizer (Tokenizer):
Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method.
world_size (int): number of data parallel processes participating in training
rank (int): rank of the current data parallel process
infinite (bool): whether to loop infinitely over the dataset

We currently ONLY support the OBELICS dataset

Example use:
>>> ds = MultiModalDataset(dataset_name="OBELICS", tokenizer=tokenizer)
>>> for batch in Dataloader(ds, batch_size=8):
print(f"Batch size: {len(batch)}")
Batch size: 8
"""

    def __init__(
        self,
        dataset_name: str,
        tokenizer: Tokenizer,
        world_size: int = 1,
        rank: int = 0,
        infinite: bool = False,
    ) -> None:
        # Do NOT allow user to pass any other dataset which is not OBELICS
        if dataset_name not in _supported_datasets:
            raise ValueError(
                f"Dataset {dataset_name} is not supported. "
                f"Supported datasets are: {list(_supported_datasets.keys())}"
            )

        dataset_path = _supported_datasets[dataset_name]
        logger.info(f"Preparing {dataset_name} dataset from {dataset_path}")
        ds = load_dataset(dataset_path, split="train", streaming=True)

        # TODO: support shuffling
        self.dataset_name = dataset_name
        self._data = split_dataset_by_node(ds, rank, world_size)
        self._tokenizer = tokenizer
        self.infinite = infinite
        self.image_token = "<|image|>"  # TODO(tj.solergibert) Hardcoded!
Contributor: I think this hardcoding should be fine?

        self.format = format_obelics
        self.transform = Llama3VisionTransform(
            tokenizer=tokenizer,
            tile_size=448,
            patch_size=14,
            max_num_tiles=4,
            image_mean=(0.48145466, 0.4578275, 0.40821073),
            image_std=(0.26862954, 0.26130258, 0.27577711),
        )
        # NOTE(tj.solergibert) 560 for Instruct models, 448 for pretrain
        # https://github.com/pytorch/torchtune/blob/0cc1b1f6a2a9c54ca640c4eb0a4d0b94ba94bb04/torchtune/models/llama3_2_vision/_model_builders.py#L92
        # https://huggingface.co/meta-llama/Llama-3.2-11B-Vision/blob/3f2e93603aaa5dd142f27d34b06dfa2b6e97b8be/preprocessor_config.json#L22

        # variables for checkpointing
        self._sample_idx = 0

    def __iter__(self):

        while True:
            for sample in self._get_data_iter():
                # Format sample into `Llama3VisionTransform` format
                try:
                    processed_sample = self.format(sample, image_token=self.image_token)
                except Exception:
                    continue
                assert len(processed_sample["images"]) == processed_sample[
                    "text"
                ].count(self.image_token)
                self._sample_idx += 1
                # Transform sample
                processed_sample = self.transform(processed_sample)
Contributor: I know you get this from TorchTune, but can we give it a name that doesn't read like "transformer"? Maybe `preproc`? We will have transform/transformer in later stages of the model as well, and this part is not trainable, so a different name would better differentiate it.

                yield processed_sample

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def _get_data_iter(self):
        if self._sample_idx == 0:
            return iter(self._data)

        # As skipping to the end throws an error in case of map-style dataset, return an empty iterator
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        return iter(self._data.skip(self._sample_idx))

    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]

    def state_dict(self):
        return {"sample_idx": self._sample_idx}


def build_mm_data_loader(
    dataset_name: str,
    tokenizer: Tokenizer,
    batch_size: int,
    world_size,
    rank,
    infinite: bool = True,
):
    mm_ds = MultiModalDataset(dataset_name, tokenizer, world_size, rank, infinite)

    collator = MultiModalCollator(padding_idx=0, pad_max_tiles=4)

    return DPAwareDataLoader(rank, mm_ds, batch_size=batch_size, collator_fn=collator)
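A sketch of how the checkpointing hooks above are intended to be used. This is not part of the PR; it assumes streaming access to OBELICS, uses a placeholder tokenizer path, and assumes no samples are skipped during formatting.

# Sketch of save/resume via the state_dict hooks defined above. The dataset
# tracks how many samples it has yielded in `_sample_idx`, so a saved state
# lets a fresh instance skip ahead to the same position in the stream.
from torchtitan.datasets import build_tokenizer
from torchtitan.datasets.mm_datasets import MultiModalDataset

TOKENIZER_PATH = "/workspace/mm/tokenizer.model"  # placeholder path

tokenizer = build_tokenizer("tiktoken", TOKENIZER_PATH)
ds = MultiModalDataset("OBELICS", tokenizer, world_size=1, rank=0)

it = iter(ds)
for _ in range(10):
    next(it)

state = ds.state_dict()  # {"sample_idx": 10}, assuming no samples were skipped
print(state)

resumed = MultiModalDataset("OBELICS", tokenizer, world_size=1, rank=0)
resumed.load_state_dict(state)  # its next __iter__ call skips the first 10 samples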
21 changes: 21 additions & 0 deletions torchtitan/datasets/multimodal/__init__.py
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.datasets.multimodal.clip import CLIPPreprocess
from torchtitan.datasets.multimodal.collator import MultiModalCollator
from torchtitan.datasets.multimodal.llama3_transform import Llama3VisionTransform
from torchtitan.datasets.multimodal.utils import format_obelics
from torchtitan.datasets.multimodal.vision_attention_mask import (
    VisionCrossAttentionMask,
)

__all__ = [
"CLIPPreprocess",
"MultiModalCollator",
"Llama3VisionTransform",
"format_obelics",
"VisionCrossAttentionMask",
]
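`MultiModalCollator` itself is not shown in this excerpt. For readers following the data flow, a generic padding collator for variable-length token sequences looks roughly like the sketch below; this is purely illustrative and not the PR's collator, which additionally pads images and tiles and builds the cross-attention mask.

# Illustrative only: pads a list of variable-length token tensors into one
# [batch, max_len] tensor plus a per-sample length vector, the same general
# pattern the batches inspected in scripts/check_padding_mm.py rely on.
import torch


def pad_collate(samples, padding_idx=0):
    # samples: list of dicts, each with a 1-D LongTensor under "tokens"
    token_len = torch.tensor([len(s["tokens"]) for s in samples])
    max_len = int(token_len.max())
    tokens = torch.full((len(samples), max_len), padding_idx, dtype=torch.long)
    for i, s in enumerate(samples):
        tokens[i, : len(s["tokens"])] = s["tokens"]
    return {"tokens": tokens, "token_len": token_len}


batch = pad_collate(
    [{"tokens": torch.tensor([1, 2, 3])}, {"tokens": torch.tensor([4, 5])}]
)
print(batch["tokens"])     # tensor([[1, 2, 3], [4, 5, 0]])
print(batch["token_len"])  # tensor([3, 2])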