7 changes: 6 additions & 1 deletion tests/unit_tests/test_dataset_flux.py
@@ -45,9 +45,10 @@ def test_load_dataset(self):
for world_size in [2]:
for rank in range(world_size):
dataset_name = "cc12m-test-iterable"
batch_size = 1

batch_size = 1
num_steps = 15
num_workers = 4

# TODO: if num_steps * batch_size * world_size is larger than the number of samples
# in the dataset, then the test will fail, due to huggingface's
@@ -64,6 +65,8 @@ def test_load_dataset(self):
dataset_name,
"--training.local_batch_size",
str(batch_size),
"--training.dataloader.num_workers",
str(num_workers),
"--training.classifier_free_guidance_prob",
"0.447",
"--training.test_mode",
@@ -82,6 +85,8 @@ def test_load_dataset(self):
infinite=True,
)

assert dl.num_workers == num_workers

it = iter(dl)

for i in range(0, num_steps):
23 changes: 21 additions & 2 deletions torchtitan/components/dataloader.py
@@ -53,7 +53,14 @@ class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader):
dp_rank: Data parallelism rank for this dataloader.
dp_world_size: The world size of the data parallelism.
batch_size: The batch size to use for each iteration.
collate_fn: Optional function to collate samples in a batch.
collate_fn (Callable, optional): A function that takes a list of samples from the
dataset and collates them into a batch. Defaults to ``None``.
num_workers: Number of worker processes for data loading. Defaults to 0.
persistent_workers: If True, keep workers alive between dataset iterations.
Only applicable when num_workers > 0. Defaults to False.
prefetch_factor: Number of batches to prefetch per worker. Only applicable
when num_workers > 0. Defaults to None (uses PyTorch default of 2).
pin_memory: If True, copy tensors to CUDA pinned memory. Defaults to False.
"""

dp_rank: int
@@ -67,11 +74,23 @@ def __init__(
dp_world_size: int,
Contributor:

could you help change this: let's keep at most one positional arg (dataset) and others to be kwargs.
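For illustration, a minimal sketch of that suggestion (hypothetical, not part of this diff): keep dataset as the only positional parameter and make everything after it keyword-only with a bare *.

from typing import Any, Callable

# Hypothetical sketch of the reviewer's suggestion, not the code in this PR:
# `dataset` stays positional; the bare `*` forces every later argument to be
# passed by keyword.
class KeywordOnlyDataloaderSketch:
    def __init__(
        self,
        dataset: Any,
        *,
        dp_rank: int,
        dp_world_size: int,
        batch_size: int,
        collate_fn: Callable | None = None,
        num_workers: int = 0,
        persistent_workers: bool = False,
        prefetch_factor: int | None = None,
        pin_memory: bool = False,
    ) -> None:
        self.dp_rank = dp_rank
        self.dp_world_size = dp_world_size
        self.batch_size = batch_size

# Callers would then have to name every argument after `dataset`, e.g.
# KeywordOnlyDataloaderSketch(ds, dp_rank=0, dp_world_size=2, batch_size=8).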

batch_size: int,
collate_fn: Callable | None = None,
num_workers: int = 0,
persistent_workers: bool = False,
prefetch_factor: int | None = None,
pin_memory: bool = False,
):
self.dp_world_size = dp_world_size
self.dp_rank = dp_rank
self.batch_size = batch_size
super().__init__(dataset, batch_size, collate_fn=collate_fn)
super().__init__(
dataset,
batch_size,
collate_fn=collate_fn,
num_workers=num_workers,
persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
pin_memory=pin_memory,
Contributor (on lines +89 to +92):

I was thinking about introducing a catch-all kwargs to make it easier to specify args but that can easily complicate things (validation checks, duplication, existing defined named args in function definitions etc).

These are valid concerns. For now I'm leaning towards keeping things simple by passing **kwargs around.

Does it make sense if we only make these args explicit when sending to the actual init of StatefulDataLoader and not passing in all **kwargs from the input of ParallelAwareDataloader? The point is to not accidentally hit an error inside StatefulDataLoader.
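For illustration only (function names invented, not part of this diff), the two approaches under discussion differ in where a bad option surfaces: a catch-all **kwargs fails deep inside StatefulDataLoader, while explicit keyword arguments, as this PR uses, are validated at the wrapper's boundary.

from typing import Any

def forward_with_catch_all(dataset: Any, **kwargs: Any) -> dict[str, Any]:
    # Catch-all style: whatever the caller passes is forwarded untouched, so a typo
    # like prefetch_facto=4 only surfaces later, inside StatefulDataLoader.__init__.
    return {"dataset": dataset, **kwargs}

def forward_with_explicit_args(
    dataset: Any,
    num_workers: int = 0,
    persistent_workers: bool = False,
    prefetch_factor: int | None = None,
    pin_memory: bool = False,
) -> dict[str, Any]:
    # Explicit style: only a fixed, known set of options can be forwarded, so an
    # unsupported keyword is rejected at this call boundary instead.
    return {
        "dataset": dataset,
        "num_workers": num_workers,
        "persistent_workers": persistent_workers,
        "prefetch_factor": prefetch_factor,
        "pin_memory": pin_memory,
    }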

)
self._rank_id = f"dp_rank_{dp_rank}"

def state_dict(self) -> dict[str, Any]:
31 changes: 31 additions & 0 deletions torchtitan/config/job_config.py
@@ -198,6 +198,31 @@ class LRScheduler:
"""


@dataclass
class DataLoader:
"""
Configuration for PyTorch DataLoader settings.
"""

num_workers: int = 0
"""Number of worker processes for data loading. 0 means data will be loaded in the main process."""

persistent_workers: bool = False
"""
If True, the data loader will not shut down the worker processes after a dataset has been consumed once.
This keeps the worker Dataset instances alive. Only applicable when num_workers > 0.
"""

prefetch_factor: int | None = None
"""
Number of batches loaded in advance by each worker. If None, the default value (2) is used.
Only applicable when num_workers > 0.
"""

pin_memory: bool = False
"""If True, the data loader will copy Tensors into CUDA pinned memory before returning them."""


@dataclass
class Training:
dataset: str = "c4_test"
@@ -263,6 +288,9 @@ class Training:
many temporary files.
"""

dataloader: DataLoader = field(default_factory=DataLoader)
"""DataLoader configuration"""


@dataclass
class Parallelism:
@@ -908,6 +936,9 @@ class Validation:
WARNING: When setting to -1 there could be hangs due to mismatch among ranks
"""

dataloader: DataLoader = field(default_factory=DataLoader)
"""DataLoader configuration"""

def __post_init__(self):
assert (
self.steps > 0 or self.steps == -1
16 changes: 10 additions & 6 deletions torchtitan/experiments/vlm/datasets/mm_datasets.py
@@ -381,14 +381,14 @@ def build_mm_dataloader(
"""Build a data loader for multimodal datasets.

Args:
dp_world_size: Data parallel world size
dp_rank: Data parallel rank
tokenizer: Tokenizer for text processing
job_config: Job configuration
infinite: Whether to loop infinitely
dp_world_size: Data parallel world size.
dp_rank: Data parallel rank.
tokenizer: Tokenizer for text processing.
job_config: Job configuration containing dataset and DataLoader settings.
infinite: Whether to loop infinitely.

Returns:
DataLoader with appropriate parallelism handling
DataLoader with appropriate parallelism handling.
"""
dataset_path = job_config.training.dataset_path
batch_size = job_config.training.local_batch_size
@@ -435,6 +435,10 @@ def build_mm_dataloader(
dp_world_size=dp_world_size,
batch_size=batch_size,
collate_fn=collate_fn,
num_workers=job_config.training.dataloader.num_workers,
persistent_workers=job_config.training.dataloader.persistent_workers,
prefetch_factor=job_config.training.dataloader.prefetch_factor,
pin_memory=job_config.training.dataloader.pin_memory,
)

return base_dataloader
28 changes: 26 additions & 2 deletions torchtitan/hf_datasets/text_datasets.py
@@ -172,7 +172,15 @@ def build_text_dataloader(
job_config: JobConfig,
infinite: bool = True,
) -> ParallelAwareDataloader:
"""Build a data loader for HuggingFace datasets."""
"""Build a data loader for HuggingFace datasets.

Args:
dp_world_size: Data parallelism world size.
dp_rank: Data parallelism rank.
tokenizer: Tokenizer to use for encoding text.
job_config: Job configuration containing dataset and DataLoader settings.
infinite: Whether to loop the dataset infinitely.
"""
dataset_name = job_config.training.dataset
dataset_path = job_config.training.dataset_path
batch_size = job_config.training.local_batch_size
@@ -193,6 +201,10 @@ def build_text_dataloader(
dp_rank=dp_rank,
dp_world_size=dp_world_size,
batch_size=batch_size,
num_workers=job_config.training.dataloader.num_workers,
persistent_workers=job_config.training.dataloader.persistent_workers,
prefetch_factor=job_config.training.dataloader.prefetch_factor,
pin_memory=job_config.training.dataloader.pin_memory,
)


@@ -203,7 +215,15 @@ def build_text_validation_dataloader(
job_config: JobConfig,
infinite: bool = False,
) -> ParallelAwareDataloader:
"""Build a validation data loader for HuggingFace datasets."""
"""Build a validation data loader for HuggingFace datasets.

Args:
dp_world_size: Data parallelism world size.
dp_rank: Data parallelism rank.
tokenizer: Tokenizer to use for encoding text.
job_config: Job configuration containing dataset and DataLoader settings.
infinite: Whether to loop the dataset infinitely.
"""
dataset_name = job_config.validation.dataset
dataset_path = job_config.validation.dataset_path
batch_size = job_config.validation.local_batch_size
@@ -224,4 +244,8 @@ def build_text_validation_dataloader(
dp_rank=dp_rank,
dp_world_size=dp_world_size,
batch_size=batch_size,
num_workers=job_config.validation.dataloader.num_workers,
persistent_workers=job_config.validation.dataloader.persistent_workers,
prefetch_factor=job_config.validation.dataloader.prefetch_factor,
pin_memory=job_config.validation.dataloader.pin_memory,
)
29 changes: 27 additions & 2 deletions torchtitan/models/flux/flux_datasets.py
@@ -314,7 +314,15 @@ def build_flux_dataloader(
tokenizer: FluxTokenizer | None,
infinite: bool = True,
) -> ParallelAwareDataloader:
"""Build a data loader for HuggingFace datasets."""
"""Build a data loader for HuggingFace datasets.

Args:
dp_world_size: Data parallelism world size.
dp_rank: Data parallelism rank.
job_config: Job configuration containing dataset and DataLoader settings.
tokenizer: Tokenizer (kept for compatibility, not used).
infinite: Whether to loop the dataset infinitely.
"""
dataset_name = job_config.training.dataset
dataset_path = job_config.training.dataset_path
batch_size = job_config.training.local_batch_size
@@ -337,6 +345,10 @@ def build_flux_dataloader(
dp_rank=dp_rank,
dp_world_size=dp_world_size,
batch_size=batch_size,
num_workers=job_config.training.dataloader.num_workers,
persistent_workers=job_config.training.dataloader.persistent_workers,
prefetch_factor=job_config.training.dataloader.prefetch_factor,
pin_memory=job_config.training.dataloader.pin_memory,
)


@@ -400,7 +412,16 @@ def build_flux_validation_dataloader(
generate_timestamps: bool = True,
infinite: bool = False,
) -> ParallelAwareDataloader:
"""Build a data loader for HuggingFace datasets."""
"""Build a validation data loader for HuggingFace datasets.

Args:
dp_world_size: Data parallelism world size.
dp_rank: Data parallelism rank.
job_config: Job configuration containing dataset and DataLoader settings.
tokenizer: Tokenizer (kept for compatibility, not used).
generate_timestamps: Whether to generate timesteps for validation.
infinite: Whether to loop the dataset infinitely.
"""
dataset_name = job_config.validation.dataset
dataset_path = job_config.validation.dataset_path
batch_size = job_config.validation.local_batch_size
@@ -424,4 +445,8 @@ def build_flux_validation_dataloader(
dp_rank=dp_rank,
dp_world_size=dp_world_size,
batch_size=batch_size,
num_workers=job_config.validation.dataloader.num_workers,
persistent_workers=job_config.validation.dataloader.persistent_workers,
prefetch_factor=job_config.validation.dataloader.prefetch_factor,
pin_memory=job_config.validation.dataloader.pin_memory,
)