RFC: Unified ParticleDataset for Sample and Trajectory-based Data Loading #82

Open · kks32 opened this issue Jun 26, 2024 · 0 comments
Labels: Priority: High
Milestone: GNS v2.0
Assignee: kks32

Summary

This RFC proposes a unified ParticleDataset class that supports both sample-based and trajectory-based data loading for particle simulation data. The implementation leverages PyTorch's Dataset class and provides a flexible interface for handling different data formats and loading modes, including support for distributed training.

Motivation

Current particle simulation data loading methods often require separate implementations for sample-based and trajectory-based approaches. This leads to code duplication and potential inconsistencies. By unifying these approaches into a single class, we aim to:

  1. Simplify the codebase and reduce duplication.
  2. Provide a consistent interface for different data loading needs.
  3. Improve flexibility in handling various data formats (npz and h5).
  4. Enhance maintainability and extensibility of the data loading pipeline.
  5. Support distributed training scenarios.

The expected outcome is a more robust and versatile data loading system that can easily adapt to different research and production needs in particle simulation projects, including distributed training environments.

Design Detail

The core of this proposal is the ParticleDataset class and associated functions in the particle_data_loader.py file. Here's a detailed breakdown of its design; a consolidated example sketch follows the outline:

  1. Data Loading:
def load_data(path):
    """Load data stored in npz or h5 format."""
    # Implementation for loading npz and h5 files
  2. ParticleDataset Class:
class ParticleDataset(Dataset):
    def __init__(self, file_path, input_sequence_length=6, mode='sample'):
        # Initialize dataset
    
    def _preprocess_data(self):
        # Preprocess data based on mode

    def __len__(self):
        # Return length of dataset

    def __getitem__(self, idx):
        # Get item based on mode (sample or trajectory)

    def _get_sample(self, idx):
        # Get a single sample

    def _get_trajectory(self, idx):
        # Get a full trajectory

    def get_num_features(self):
        # Return the number of features in the dataset
  3. Collate Functions:
def collate_fn_sample(batch):
    # Collate function for sample mode

def collate_fn_trajectory(batch):
    # Collate function for trajectory mode
  4. Data Loader Creation:
def get_data_loader(file_path, mode='sample', input_sequence_length=6, batch_size=32, shuffle=True, is_distributed=False):
    # Create and return appropriate DataLoader based on mode and distributed setting
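
To make the outline concrete, here is a minimal end-to-end sketch of how these pieces could fit together. It is illustrative only: the on-disk layout (one (positions, particle_types) pair per trajectory, with positions shaped (timesteps, particles, dims)), the float32 casts, and the omission of the h5 branch are assumptions, not the confirmed implementation.

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

def load_data(path):
    """Load data stored in npz format (h5 branch omitted from this sketch)."""
    with np.load(path, allow_pickle=True) as data_file:
        return [item for _, item in data_file.items()]

class ParticleDataset(Dataset):
    def __init__(self, file_path, input_sequence_length=6, mode='sample'):
        self.mode = mode
        self.input_sequence_length = input_sequence_length
        # Assumed layout: each trajectory is a (positions, particle_types)
        # pair, with positions of shape (timesteps, particles, dims).
        self.trajectories = load_data(file_path)
        self._preprocess_data()

    def _preprocess_data(self):
        # Sample mode indexes every window of input_sequence_length input
        # steps (plus one target step) across all trajectories.
        if self.mode == 'sample':
            self.windows = [
                (t, s)
                for t, (positions, _) in enumerate(self.trajectories)
                for s in range(positions.shape[0] - self.input_sequence_length)
            ]

    def __len__(self):
        return len(self.windows) if self.mode == 'sample' else len(self.trajectories)

    def __getitem__(self, idx):
        return self._get_sample(idx) if self.mode == 'sample' else self._get_trajectory(idx)

    def _get_sample(self, idx):
        t, s = self.windows[idx]
        positions, particle_types = self.trajectories[t]
        inputs = positions[s:s + self.input_sequence_length]
        target = positions[s + self.input_sequence_length]
        return (torch.as_tensor(inputs, dtype=torch.float32),
                torch.as_tensor(particle_types),
                torch.as_tensor(target, dtype=torch.float32))

    def _get_trajectory(self, idx):
        positions, particle_types = self.trajectories[idx]
        return (torch.as_tensor(positions, dtype=torch.float32),
                torch.as_tensor(particle_types))

    def get_num_features(self):
        # Taken here as the spatial dimensionality of the positions.
        return self.trajectories[0][0].shape[-1]

def collate_fn_sample(batch):
    # Assumes a fixed particle count so the tensors stack cleanly.
    inputs, particle_types, targets = zip(*batch)
    return torch.stack(inputs), torch.stack(particle_types), torch.stack(targets)

def collate_fn_trajectory(batch):
    # Trajectories may differ in length, so keep them as lists.
    positions, particle_types = zip(*batch)
    return list(positions), list(particle_types)

def get_data_loader(file_path, mode='sample', input_sequence_length=6,
                    batch_size=32, shuffle=True, is_distributed=False):
    dataset = ParticleDataset(file_path, input_sequence_length, mode)
    # DistributedSampler assumes torch.distributed has been initialized.
    sampler = DistributedSampler(dataset, shuffle=shuffle) if is_distributed else None
    collate_fn = collate_fn_sample if mode == 'sample' else collate_fn_trajectory
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=(shuffle and sampler is None),
                      sampler=sampler,
                      collate_fn=collate_fn)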

Usage in training script:

# Determine if we're using distributed training
is_distributed = device == torch.device("cuda") and world_size > 1

# Load training data
dl = pdl.get_data_loader(
    file_path=f"{cfg.data.path}train.npz",
    mode='sample',
    input_sequence_length=cfg.data.input_sequence_length,
    batch_size=cfg.data.batch_size,
    is_distributed=is_distributed
)

# Get the number of features
train_dataset = pdl.ParticleDataset(f"{cfg.data.path}train.npz")
n_features = train_dataset.get_num_features()

# Similar process for validation data
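
As a hedged illustration of that "similar process", the validation loader might be built the same way, and under distributed training the sampler's epoch should be advanced before each pass so all ranks reshuffle consistently. The valid.npz filename, trajectory mode, and cfg.training.epochs below are illustrative assumptions, not names confirmed by the implementation.

# Hypothetical validation loader mirroring the training setup
valid_dl = pdl.get_data_loader(
    file_path=f"{cfg.data.path}valid.npz",
    mode='trajectory',
    batch_size=1,
    shuffle=False,
    is_distributed=is_distributed
)

# With a DistributedSampler, set the epoch so shuffling differs per
# epoch but stays consistent across ranks.
for epoch in range(cfg.training.epochs):
    if is_distributed:
        dl.sampler.set_epoch(epoch)
    for inputs, particle_types, targets in dl:
        ...  # forward/backward pass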

Drawbacks

  1. Increased complexity, since a single class must handle multiple modes and distributed scenarios.
  2. Potential for slightly increased memory usage due to storing both sample and trajectory-related attributes.
  3. Users familiar with separate implementations might need to adapt to the new unified interface.

Rationale and Alternatives

This design was chosen because:

  1. It provides a single, consistent interface for different data loading needs, including distributed training.
  2. It leverages PyTorch's existing Dataset and DistributedSampler classes, ensuring compatibility with the PyTorch ecosystem.
  3. It allows for easy switching between sample and trajectory modes without changing the underlying data structure.
  4. It supports both distributed and non-distributed training scenarios with a single interface.

Alternatives considered:

  1. Keeping separate classes for sample and trajectory loading. Rejected due to code duplication and lack of flexibility.
  2. Using a factory pattern to create different dataset types. Rejected as overly complex for the current needs.
  3. Having separate functions for distributed and non-distributed data loading. Rejected in favor of a unified interface with a flag for distributed training.

The impact of not implementing this change would be continued code duplication, potential inconsistencies between sample and trajectory implementations, and reduced flexibility in data loading options, especially in distributed training scenarios.

Prior Art

  1. PyTorch's Dataset and DataLoader classes: Our implementation builds directly on these established foundations.
  2. PyTorch's DistributedSampler: Used for handling data distribution in multi-GPU training.
  3. TensorFlow's tf.data API: Provides similar flexibility in data loading, though with a different approach.
  4. NVIDIA's DALI library: Offers high-performance data loading pipelines, though more complex than our needs.

Unresolved questions

  1. How to handle very large datasets that don't fit in memory?
  2. Should we implement lazy loading of data?
  3. How can we optimize the performance of data loading for large-scale simulations?
  4. Are there any specific optimizations needed for distributed training with very large datasets?
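
On question 2, one possible direction is to let __getitem__ slice directly from an open h5 file, since h5py reads from disk lazily. The sketch below is an assumption-laden illustration: the one-group-per-trajectory layout and the 'positions' dataset name are hypothetical, and the handle is opened on first use so each DataLoader worker process gets its own.

import h5py
import torch
from torch.utils.data import Dataset

class LazyH5ParticleDataset(Dataset):
    """Hypothetical lazy variant: reads windows from disk on demand
    instead of loading every trajectory into memory up front."""

    def __init__(self, file_path, input_sequence_length=6):
        self.file_path = file_path
        self.input_sequence_length = input_sequence_length
        self._file = None
        with h5py.File(file_path, 'r') as f:
            self.keys = list(f.keys())  # one group per trajectory (assumed)

    def _h5(self):
        # Opened lazily so each worker holds its own file handle.
        if self._file is None:
            self._file = h5py.File(self.file_path, 'r')
        return self._file

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        positions = self._h5()[self.keys[idx]]['positions']
        # h5py reads only this slice from disk, not the whole trajectory.
        window = positions[:self.input_sequence_length + 1]
        return torch.as_tensor(window, dtype=torch.float32)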

Changelog

  • Initial draft of the RFC.
  • Updated to reflect the current implementation, including support for distributed training and the unified get_data_loader function.