[WIP] Adding OBELICS DataLoader #663
base: main
Changes from all commits
@@ -6,3 +6,4 @@ sentencepiece
tiktoken
blobfile
tabulate
+torchvision
This file was deleted.
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.datasets import build_mm_data_loader, build_tokenizer

PATH_TO_TOKENIZER = "/workspace/mm/tokenizer.model"
BATCH_SIZE = 16
BATCH_NUMBER = 4


def main():
    tokenizer = build_tokenizer("tiktoken", PATH_TO_TOKENIZER)
    dl = build_mm_data_loader("OBELICS", tokenizer, BATCH_SIZE, 2, 0)
    dl_iter = iter(dl)
    for _ in range(BATCH_NUMBER):
        batch = next(dl_iter)

        # Analyze Batch
        ## input_ids
        total_input_ids = sum(batch["token_len"].tolist())
        input_ids_pad_length = max(batch["token_len"].tolist())
        total_tokens_in_batch = input_ids_pad_length * BATCH_SIZE
        total_input_ids_padded_tokens = sum(
            (batch["token_len"] - input_ids_pad_length) * -1
        )
        print(
            f"Unpadded tokens: {total_input_ids}, Total tokens in batch: {total_tokens_in_batch}"
        )
        print(
            f"Padded text tokens: {total_input_ids_padded_tokens}, {total_input_ids_padded_tokens/total_tokens_in_batch*100:.2f}%"
        )
        print(40 * "#")
        ## image_ids
        total_images = sum(batch["image_len"].tolist())
        image_pad_length = max(batch["image_len"].tolist())
        total_images_in_batch = image_pad_length * BATCH_SIZE
        total_images_padded_tokens = sum((batch["image_len"] - image_pad_length) * -1)
        print(
            f"Unpadded images: {total_images}, Total images in batch: {total_images_in_batch}"
        )
        print(
            f'Padded images: {total_images_padded_tokens}, {total_images_padded_tokens/total_images_in_batch*100:.2f}% (Each image with shape {list(batch["encoder_input"]["images"][0,0].shape)})'
        )
        print(40 * "#")
        # Tiles
        total_number_of_tiles = sum([sum(sample) for sample in batch["tile_len"]])
        print(
            f"Unpadded number of tiles: {total_number_of_tiles}, Total number of tiles: {total_images_in_batch*4}"
        )
        print(
            f'Padded tiles: {total_images_in_batch*4-total_number_of_tiles}, {(total_images_in_batch*4-total_number_of_tiles)/(total_images_in_batch*4)*100:.2f}% (Each with shape {list(batch["encoder_input"]["images"][0,0,0].shape)})'
        )
print(40 * "#") | ||
# CrossAttentionMask | ||
original_cross_attention_mask_elements = ( | ||
total_number_of_tiles * 1025 * total_input_ids | ||
) # NOTE(tj.solergibert) We have 1024+1 image tokens per tile | ||
print( | ||
f"Unpadded cross attention mask elements: {original_cross_attention_mask_elements}, Total cross attention mask elements: {total_images_in_batch*4*1025*total_tokens_in_batch}" | ||
) # TODO(tj.solergibert) Each element is a `bool` | ||
print( | ||
f"Padded cross attention mask elements: {total_images_in_batch*4*1025*total_tokens_in_batch-original_cross_attention_mask_elements}, {100*((total_images_in_batch*4*1025*total_tokens_in_batch-original_cross_attention_mask_elements)/(total_images_in_batch*4*1025*total_tokens_in_batch)):.2f}%" | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
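For reference, the 1025 image tokens per tile used in the script above follow from the transform configuration added later in this PR (tile_size=448, patch_size=14: 32 × 32 patches plus one extra token). A back-of-the-envelope sketch of that arithmetic, not part of the PR, with the per-sample padded lengths as made-up placeholders:

# Sketch only, not part of the PR. Assumes tile_size=448, patch_size=14 and
# max_num_tiles=4 (as configured in Llama3VisionTransform below) and that each
# cross-attention mask element is stored as a 1-byte bool. The padded per-sample
# counts are hypothetical placeholders.
TILE_SIZE = 448
PATCH_SIZE = 14
MAX_TILES = 4
IMAGE_TOKENS_PER_TILE = (TILE_SIZE // PATCH_SIZE) ** 2 + 1  # 32 * 32 patches + 1 = 1025

BATCH_SIZE = 16
padded_images_per_sample = 8     # hypothetical
padded_tokens_per_sample = 2048  # hypothetical

mask_elements = (
    BATCH_SIZE
    * padded_images_per_sample
    * MAX_TILES
    * IMAGE_TOKENS_PER_TILE
    * padded_tokens_per_sample
)
print(f"Image tokens per tile: {IMAGE_TOKENS_PER_TILE}")
print(f"Mask elements: {mask_elements:,} (~{mask_elements / 2**20:.0f} MiB as bool)")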
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import pickle
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional

import torch
from torch.distributed.checkpoint.stateful import Stateful

@@ -156,8 +156,14 @@ class DPAwareDataLoader(StatefulDataLoader, Stateful):
    A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
    """

-    def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int):
-        super().__init__(hf_ds, batch_size)
+    def __init__(
+        self,
+        dp_rank: int,
+        hf_ds: IterableDataset,
+        batch_size: int,
+        collator_fn: Callable,
+    ):
+        super().__init__(dataset=hf_ds, batch_size=batch_size, collate_fn=collator_fn)

Review comment: where is this?

        self._dp_rank = dp_rank
        self._rank_id = f"dp_rank_{dp_rank}"
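For context on the new collator_fn parameter: it is supplied by build_mm_data_loader further down in this PR, which builds a MultiModalCollator and forwards it into the wrapped StatefulDataLoader as collate_fn. A minimal sketch of that wiring, assuming `mm_ds` stands in for the MultiModalDataset constructed there:

# Sketch only: how collator_fn is wired by this PR (see build_mm_data_loader below).
from torchtitan.datasets import DPAwareDataLoader
from torchtitan.datasets.multimodal import MultiModalCollator


def make_loader(mm_ds, dp_rank: int, batch_size: int) -> DPAwareDataLoader:
    # `mm_ds` is the MultiModalDataset built elsewhere in this PR
    collator = MultiModalCollator(padding_idx=0, pad_max_tiles=4)
    # collator_fn ends up as StatefulDataLoader's collate_fn
    return DPAwareDataLoader(dp_rank, mm_ds, batch_size=batch_size, collator_fn=collator)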
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset

from torchtitan.datasets import DPAwareDataLoader
from torchtitan.datasets.multimodal import (
    format_obelics,
    Llama3VisionTransform,
    MultiModalCollator,
)
from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging import logger

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node

# map from dataset name to a dataset repository on the HF hub
_supported_datasets = {
    "OBELICS": "HuggingFaceM4/OBELICS",
}


class MultiModalDataset(IterableDataset, Stateful):
    """PyTorch MultiModal Dataset.

    Args:
        dataset_name (str): name of the dataset to load
        tokenizer (Tokenizer):
            Tokenizer used to encode data. The tokenizer must implement an `encode` and a `decode` method.
        world_size (int): number of data parallel processes participating in training
        rank (int): rank of the current data parallel process
        infinite (bool): whether to loop infinitely over the dataset

    We currently ONLY support the OBELICS dataset.

    Example use:
    >>> ds = MultiModalDataset(dataset_name="OBELICS", tokenizer=tokenizer)
    >>> for batch in DataLoader(ds, batch_size=8):
            print(f"Batch size: {len(batch)}")
    Batch size: 8
    """

    def __init__(
        self,
        dataset_name: str,
        tokenizer: Tokenizer,
        world_size: int = 1,
        rank: int = 0,
        infinite: bool = False,
    ) -> None:
        # Do NOT allow the user to pass any dataset other than OBELICS
        if dataset_name not in _supported_datasets:
            raise ValueError(
                f"Dataset {dataset_name} is not supported. "
                f"Supported datasets are: {list(_supported_datasets.keys())}"
            )

        dataset_path = _supported_datasets[dataset_name]
        logger.info(f"Preparing {dataset_name} dataset from {dataset_path}")
        ds = load_dataset(dataset_path, split="train", streaming=True)

        # TODO: support shuffling
        self.dataset_name = dataset_name
        self._data = split_dataset_by_node(ds, rank, world_size)
        self._tokenizer = tokenizer
        self.infinite = infinite
        self.image_token = "<|image|>"  # TODO(tj.solergibert) Hardcoded!

Review comment: I think this hardcoding should be fine?

        self.format = format_obelics
        self.transform = Llama3VisionTransform(
            tokenizer=tokenizer,
            tile_size=448,
            patch_size=14,
            max_num_tiles=4,
            image_mean=(0.48145466, 0.4578275, 0.40821073),
            image_std=(0.26862954, 0.26130258, 0.27577711),
        )
        # NOTE(tj.solergibert) 560 for Instruct models, 448 for pretrain
        # https://github.com/pytorch/torchtune/blob/0cc1b1f6a2a9c54ca640c4eb0a4d0b94ba94bb04/torchtune/models/llama3_2_vision/_model_builders.py#L92
        # https://huggingface.co/meta-llama/Llama-3.2-11B-Vision/blob/3f2e93603aaa5dd142f27d34b06dfa2b6e97b8be/preprocessor_config.json#L22

        # variables for checkpointing
        self._sample_idx = 0

    def __iter__(self):

        while True:
            for sample in self._get_data_iter():
                # Format the sample into the `Llama3VisionTransform` input format
                try:
                    processed_sample = self.format(sample, image_token=self.image_token)
                except Exception:
                    continue
                assert len(processed_sample["images"]) == processed_sample[
                    "text"
                ].count(self.image_token)
                self._sample_idx += 1
                # Transform the sample
                processed_sample = self.transform(processed_sample)

Review comment: I know you get this from TorchTune, but can we pick a name that doesn't look like "transformer"? Maybe `preproc`? We will have a transform or transformer in later stages of the model as well, and this part is not trainable, so a different name would better differentiate the two.

                yield processed_sample

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def _get_data_iter(self):
        if self._sample_idx == 0:
            return iter(self._data)

        # Skipping to the end throws an error for a map-style dataset, so return an empty iterator
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        return iter(self._data.skip(self._sample_idx))

    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]

    def state_dict(self):
        return {"sample_idx": self._sample_idx}


def build_mm_data_loader(
    dataset_name: str,
    tokenizer: Tokenizer,
    batch_size: int,
    world_size,
    rank,
    infinite: bool = True,
):
    mm_ds = MultiModalDataset(dataset_name, tokenizer, world_size, rank, infinite)

    collator = MultiModalCollator(padding_idx=0, pad_max_tiles=4)

    return DPAwareDataLoader(rank, mm_ds, batch_size=batch_size, collator_fn=collator)
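Since DPAwareDataLoader extends StatefulDataLoader and MultiModalDataset implements Stateful, resuming from a checkpoint should amount to the usual state_dict round-trip. A hedged sketch of that flow; the tokenizer path and the number of consumed batches are placeholders:

# Sketch only: saving and restoring the OBELICS loader state.
from torchtitan.datasets import build_mm_data_loader, build_tokenizer

tokenizer = build_tokenizer("tiktoken", "/path/to/tokenizer.model")  # placeholder path
dl = build_mm_data_loader("OBELICS", tokenizer, batch_size=8, world_size=1, rank=0)

dl_iter = iter(dl)
for _ in range(10):  # consume a few batches
    next(dl_iter)

state = dl.state_dict()  # DPAwareDataLoader keys this per DP rank ("dp_rank_0", ...)

# ... after a restart, rebuild the loader and resume ...
resumed = build_mm_data_loader("OBELICS", tokenizer, batch_size=8, world_size=1, rank=0)
resumed.load_state_dict(state)  # MultiModalDataset picks up from its saved sample_idx
next(iter(resumed))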
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.datasets.multimodal.clip import CLIPPreprocess
from torchtitan.datasets.multimodal.collator import MultiModalCollator
from torchtitan.datasets.multimodal.llama3_transform import Llama3VisionTransform
from torchtitan.datasets.multimodal.utils import format_obelics
from torchtitan.datasets.multimodal.vision_attention_mask import (
    VisionCrossAttentionMask,
)

__all__ = [
    "CLIPPreprocess",
    "MultiModalCollator",
    "Llama3VisionTransform",
    "format_obelics",
    "VisionCrossAttentionMask",
]
Review comment: maybe we can make this a unit test? WDYT?

Review comment: I would add, as a unit test, some checks of shapes and types along the DP axis, rather than this script that only checks the amount of padding in each batch.
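A possible shape/type-oriented test along those lines, sketched under assumptions: the batch keys and the 448-pixel tile size are taken from the analysis script and the Llama3VisionTransform configuration in this PR, the tokenizer path is a placeholder, and the dtype expectation is a guess.

# Sketch of the suggested unit test: per-DP-rank checks of batch shapes and types
# instead of padding statistics. Key names follow the analysis script in this PR.
import pytest
import torch

from torchtitan.datasets import build_mm_data_loader, build_tokenizer

TOKENIZER_PATH = "/workspace/mm/tokenizer.model"  # placeholder
BATCH_SIZE = 4
WORLD_SIZE = 2


@pytest.mark.parametrize("rank", range(WORLD_SIZE))
def test_obelics_batch_shapes(rank):
    tokenizer = build_tokenizer("tiktoken", TOKENIZER_PATH)
    dl = build_mm_data_loader("OBELICS", tokenizer, BATCH_SIZE, WORLD_SIZE, rank)
    batch = next(iter(dl))

    # One length entry per sample in the batch
    assert batch["token_len"].shape[0] == BATCH_SIZE
    assert batch["image_len"].shape[0] == BATCH_SIZE

    # Images are padded into tiles of tile_size=448 (Llama3VisionTransform config)
    images = batch["encoder_input"]["images"]
    assert images.shape[0] == BATCH_SIZE
    assert images.shape[-2:] == (448, 448)
    assert images.dtype.is_floating_point  # assumed: CLIP preprocessing yields floats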