RFC: Unified ParticleDataset for Sample and Trajectory-based Data Loading
Summary
This RFC proposes a unified ParticleDataset class that supports both sample-based and trajectory-based data loading for particle simulation data. The implementation leverages PyTorch's Dataset class and provides a flexible interface for handling different data formats and loading modes, including support for distributed training.
Motivation
Current particle simulation data loading methods often require separate implementations for sample-based and trajectory-based approaches. This leads to code duplication and potential inconsistencies. By unifying these approaches into a single class, we aim to:
Simplify the codebase and reduce duplication.
Provide a consistent interface for different data loading needs.
Improve flexibility in handling various data formats (npz and h5).
Enhance maintainability and extensibility of the data loading pipeline.
Support distributed training scenarios.
The expected outcome is a more robust and versatile data loading system that can easily adapt to different research and production needs in particle simulation projects, including distributed training environments.
Design Detail
The core of this proposal is the ParticleDataset class and associated functions in the particle_data_loader.py file. Here's a detailed breakdown of its design:
Data Loading:
```python
def load_data(path):
    """Load data stored in npz or h5 format."""
    # Implementation for loading npz and h5 files
    ...
```
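A minimal sketch of what `load_data` could look like. The dict-of-arrays return shape and the treatment of h5 files are assumptions for illustration, not the actual file schema:

```python
import numpy as np

def load_data(path):
    """Load an npz or h5 file into a dict mapping array names to arrays.

    Sketch only: the real files may use a different schema.
    """
    if path.endswith(".npz"):
        with np.load(path) as f:
            return {key: f[key] for key in f.files}
    elif path.endswith(".h5"):
        import h5py  # optional dependency, only needed for h5 files
        with h5py.File(path, "r") as f:
            return {key: f[key][()] for key in f.keys()}
    raise ValueError(f"Unsupported file format: {path}")
```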
ParticleDataset Class:
```python
class ParticleDataset(Dataset):
    def __init__(self, file_path, input_sequence_length=6, mode='sample'):
        # Initialize dataset
        ...

    def _preprocess_data(self):
        # Preprocess data based on mode
        ...

    def __len__(self):
        # Return length of dataset
        ...

    def __getitem__(self, idx):
        # Get item based on mode (sample or trajectory)
        ...

    def _get_sample(self, idx):
        # Get a single sample
        ...

    def _get_trajectory(self, idx):
        # Get a full trajectory
        ...

    def get_num_features(self):
        # Return the number of features in the dataset
        ...
```
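To make the mode dispatch concrete, here is an illustrative stand-in. It deliberately omits the `torch.utils.data.Dataset` base class so it runs without PyTorch, and it assumes a placeholder data layout of `(n_trajectories, n_steps, n_particles, dim)`; the names and layout are ours, not the real schema:

```python
import numpy as np

class ParticleDatasetSketch:
    """Illustrative stand-in for ParticleDataset.

    Assumes `positions` is shaped (n_trajectories, n_steps,
    n_particles, dim) -- a placeholder layout for illustration.
    """

    def __init__(self, positions, input_sequence_length=6, mode='sample'):
        self.positions = positions
        self.seq_len = input_sequence_length
        self.mode = mode
        # In 'sample' mode, every window of seq_len input steps plus
        # one target step in every trajectory is one training example.
        n_traj, n_steps = positions.shape[:2]
        self.windows_per_traj = n_steps - self.seq_len
        self.n_samples = n_traj * self.windows_per_traj

    def __len__(self):
        return self.n_samples if self.mode == 'sample' else len(self.positions)

    def __getitem__(self, idx):
        if self.mode == 'sample':
            return self._get_sample(idx)
        return self._get_trajectory(idx)

    def _get_sample(self, idx):
        # Map the flat index to (trajectory, window start).
        traj, start = divmod(idx, self.windows_per_traj)
        inputs = self.positions[traj, start:start + self.seq_len]
        target = self.positions[traj, start + self.seq_len]
        return inputs, target

    def _get_trajectory(self, idx):
        return self.positions[idx]

    def get_num_features(self):
        # Feature dimensionality of a single particle state.
        return self.positions.shape[-1]
```

Note how switching `mode` changes only the indexing logic, not the stored data, which is the property the unified design relies on.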
Collate Functions:
```python
def collate_fn_sample(batch):
    # Collate function for sample mode
    ...

def collate_fn_trajectory(batch):
    # Collate function for trajectory mode
    ...
```
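A sketch of what the two collate functions might do, using numpy stacking for illustration (the real functions would presumably return torch tensors, e.g. via `torch.stack` or PyTorch's default collation):

```python
import numpy as np

def collate_fn_sample(batch):
    """Stack a list of (inputs, target) pairs into batched arrays."""
    inputs, targets = zip(*batch)
    return np.stack(inputs), np.stack(targets)

def collate_fn_trajectory(batch):
    """Stack whole trajectories along a new leading batch dimension.

    Assumes all trajectories in a batch have the same length;
    variable-length trajectories would need padding instead.
    """
    return np.stack(batch)
```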
Data Loader Creation:
```python
def get_data_loader(file_path, mode='sample', input_sequence_length=6,
                    batch_size=32, shuffle=True, is_distributed=False):
    # Create and return appropriate DataLoader based on mode
    # and distributed setting
    ...
```
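The mode and distributed dispatch inside `get_data_loader` could look roughly like this sketch, which takes an already-constructed dataset for brevity (the real function takes a `file_path` and builds the `ParticleDataset` itself, and its collate functions may differ). One detail worth noting: `DistributedSampler` does its own shuffling, so `shuffle` goes to the sampler rather than the `DataLoader` in the distributed branch:

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def _collate_sample(batch):
    # Stack (inputs, target) pairs into batched tensors.
    inputs, targets = zip(*batch)
    return torch.stack(inputs), torch.stack(targets)

def _collate_trajectory(batch):
    # Stack whole trajectories along a new batch dimension.
    return torch.stack(batch)

def get_data_loader(dataset, mode='sample', batch_size=32,
                    shuffle=True, is_distributed=False):
    collate = _collate_sample if mode == 'sample' else _collate_trajectory
    if is_distributed:
        # DistributedSampler shards the dataset across ranks and handles
        # shuffling itself; the DataLoader must not shuffle on top of it.
        sampler = DistributedSampler(dataset, shuffle=shuffle)
        return DataLoader(dataset, batch_size=batch_size,
                          sampler=sampler, collate_fn=collate)
    return DataLoader(dataset, batch_size=batch_size,
                      shuffle=shuffle, collate_fn=collate)
```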
Usage in training script:
```python
# Determine if we're using distributed training
is_distributed = device == torch.device("cuda") and world_size > 1

# Load training data
dl = pdl.get_data_loader(
    file_path=f"{cfg.data.path}train.npz",
    mode='sample',
    input_sequence_length=cfg.data.input_sequence_length,
    batch_size=cfg.data.batch_size,
    is_distributed=is_distributed
)

# Get the number of features
train_dataset = pdl.ParticleDataset(f"{cfg.data.path}train.npz")
n_features = train_dataset.get_num_features()

# Similar process for validation data
```
Drawbacks
Increased complexity of a single class handling multiple modes and distributed scenarios.
Potential for slightly increased memory usage due to storing both sample and trajectory-related attributes.
Users familiar with separate implementations might need to adapt to the new unified interface.
Rationale and Alternatives
We believe this design is the best fit because:
It provides a single, consistent interface for different data loading needs, including distributed training.
It leverages PyTorch's existing Dataset and DistributedSampler classes, ensuring compatibility with the PyTorch ecosystem.
It allows for easy switching between sample and trajectory modes without changing the underlying data structure.
It supports both distributed and non-distributed training scenarios with a single interface.
Alternatives considered:
Keeping separate classes for sample and trajectory loading. Rejected due to code duplication and lack of flexibility.
Using a factory pattern to create different dataset types. Rejected as overly complex for the current needs.
Having separate functions for distributed and non-distributed data loading. Rejected in favor of a unified interface with a flag for distributed training.
The impact of not implementing this change would be continued code duplication, potential inconsistencies between sample and trajectory implementations, and reduced flexibility in data loading options, especially in distributed training scenarios.
Prior Art
PyTorch's Dataset and DataLoader classes: Our implementation builds directly on these established foundations.
PyTorch's DistributedSampler: Used for handling data distribution in multi-GPU training.
TensorFlow's tf.data API: Provides similar flexibility in data loading, though with a different approach.
NVIDIA's DALI library: Offers high-performance data loading pipelines, though more complex than our needs.
Unresolved questions
How to handle very large datasets that don't fit in memory?
Should we implement lazy loading of data?
How can we optimize the performance of data loading for large-scale simulations?
Are there any specific optimizations needed for distributed training with very large datasets?
Changelog
Initial draft of the RFC.
Updated to reflect the current implementation, including support for distributed training and the unified get_data_loader function.