Merge pull request #96 from grok-ai/feature/hf-integration
Add hf datasets integration
lucmos authored Oct 12, 2023
2 parents 306b461 + df7afb7 commit e22c1a3
Showing 20 changed files with 339 additions and 208 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_suite.yml
@@ -17,6 +17,7 @@ env:
CONDA_ENV_FILE: 'env.yaml'
CONDA_ENV_NAME: 'project-test'
COOKIECUTTER_PROJECT_NAME: 'project-test'
HUGGING_FACE_HUB_TOKEN: ${{secrets.HUGGING_FACE_HUB_TOKEN}}

jobs:
build:
1 change: 1 addition & 0 deletions README.md
@@ -73,6 +73,7 @@ Avoid writing boilerplate code to integrate:

- [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning), lightweight PyTorch wrapper for high-performance AI research.
- [Hydra](https://github.com/facebookresearch/hydra), a framework for elegantly configuring complex applications.
- [Hugging Face Datasets](https://huggingface.co/docs/datasets/index), a library for easily accessing and sharing datasets.
- [Weights and Biases](https://wandb.ai/home), organize and analyze machine learning experiments. *(educational account available)*
- [Streamlit](https://streamlit.io/), turns data scripts into shareable web apps in minutes.
- [MkDocs](https://www.mkdocs.org/) and [Material for MkDocs](https://squidfunk.github.io/mkdocs-material/), a fast, simple and downright gorgeous static site generator.
1 change: 1 addition & 0 deletions docs/index.md
@@ -36,6 +36,7 @@ and to avoid writing boilerplate code to integrate:

- [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning), lightweight PyTorch wrapper for high-performance AI research.
- [Hydra](https://github.com/facebookresearch/hydra), a framework for elegantly configuring complex applications.
- [Hugging Face Datasets](https://huggingface.co/docs/datasets/index), a library for easily accessing and sharing datasets.
- [Weights and Biases](https://wandb.ai/home), organize and analyze machine learning experiments. *(educational account available)*
- [Streamlit](https://streamlit.io/), turns data scripts into shareable web apps in minutes.
- [MkDocs](https://www.mkdocs.org/) and [Material for MkDocs](https://squidfunk.github.io/mkdocs-material/), a fast, simple and downright gorgeous static site generator.
11 changes: 10 additions & 1 deletion {{ cookiecutter.repository_name }}/.env.template
@@ -1 +1,10 @@
# While .env is a local file full of secrets, this can be public and ease the setup of known env variables.
# .env.template is a template for the .env file that can be versioned.

# Set to 1 to show full stack trace on error, 0 to hide it
HYDRA_FULL_ERROR=1

# Configure where huggingface_hub will locally store data.
HF_HOME="~/.cache/huggingface"

# Configure the User Access Token to authenticate to the Hub
# HUGGING_FACE_HUB_TOKEN=
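
For orientation (not part of this diff), a minimal sketch of how these environment variables might be consumed at startup, assuming `python-dotenv` and `huggingface_hub` are installed; the template's own loading mechanism may differ:

```python
# Hypothetical startup snippet (not from this PR): load .env and
# authenticate to the Hugging Face Hub only if a token is provided.
import os

from dotenv import load_dotenv  # assumes python-dotenv is available
from huggingface_hub import login

load_dotenv()  # reads HYDRA_FULL_ERROR, HF_HOME, HUGGING_FACE_HUB_TOKEN, ...

token = os.getenv("HUGGING_FACE_HUB_TOKEN")
if token:
    login(token=token)  # only needed for private or gated datasets
```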
@@ -9,8 +9,9 @@ env:
CACHE_NUMBER: 0 # increase to reset cache manually
CONDA_ENV_FILE: './env.yaml'
CONDA_ENV_NAME: '{{ cookiecutter.conda_env_name }}'

{% raw %}
HUGGING_FACE_HUB_TOKEN: ${{secrets.HUGGING_FACE_HUB_TOKEN}}

jobs:
build:
strategy:
@@ -16,8 +16,9 @@ env:
CACHE_NUMBER: 1 # increase to reset cache manually
CONDA_ENV_FILE: './env.yaml'
CONDA_ENV_NAME: '{{ cookiecutter.conda_env_name }}'

{% raw %}
HUGGING_FACE_HUB_TOKEN: ${{secrets.HUGGING_FACE_HUB_TOKEN}}

jobs:
build:

4 changes: 4 additions & 0 deletions {{ cookiecutter.repository_name }}/conf/default.yaml
@@ -5,6 +5,10 @@ core:
version: 0.0.1
tags: null

conventions:
x_key: 'x'
y_key: 'y'

defaults:
- hydra: default
- nn: default
@@ -0,0 +1,22 @@
# This config defines which dataset to use,
# and how to split it into train/[val]/test.
_target_: {{ cookiecutter.package_name }}.utils.hf_io.load_hf_dataset
name: "mnist"
ref: "mnist"
train_split: train
# val_split: val
val_percentage: 0.1
test_split: test
label_key: label
data_key: image
num_classes: 10
input_shape: [1, 28, 28]
standard_x_key: ${conventions.x_key}
standard_y_key: ${conventions.y_key}
transforms:
_target_: {{ cookiecutter.package_name }}.utils.hf_io.HFTransform
key: ${conventions.x_key}
transform:
_target_: torchvision.transforms.Compose
transforms:
- _target_: torchvision.transforms.ToTensor
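
The `utils/hf_io.py` module referenced by `_target_` above is not visible in this excerpt, so here is a hedged sketch of what `load_hf_dataset` and `HFTransform` plausibly do, inferred only from the keys used in this config; the signatures and behavior are assumptions, not the PR's actual implementation:

```python
# Hypothetical sketch of the helpers referenced above; the real
# utils/hf_io.py added by this PR is not shown in this excerpt.
from typing import Any, Callable, Dict, List, Optional

from datasets import DatasetDict, load_dataset


class HFTransform:
    """Apply a (torchvision-style) transform to one column of each batch dict."""

    def __init__(self, key: str, transform: Callable):
        self.key = key
        self.transform = transform

    def __call__(self, batch: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        batch[self.key] = [self.transform(sample) for sample in batch[self.key]]
        return batch


def load_hf_dataset(
    ref: str,
    train_split: str,
    test_split: str,
    data_key: str,
    label_key: str,
    standard_x_key: str,
    standard_y_key: str,
    val_split: Optional[str] = None,
    val_percentage: Optional[float] = None,
    **kwargs: Any,
) -> DatasetDict:
    """Load a Hub dataset and normalize it to train/val/test with x/y columns."""
    raw = load_dataset(ref)

    splits = DatasetDict()
    if val_split is not None:
        splits["train"] = raw[train_split]
        splits["val"] = raw[val_split]
    else:
        # No upstream validation split: carve one out of the train split.
        held_out = raw[train_split].train_test_split(test_size=val_percentage)
        splits["train"], splits["val"] = held_out["train"], held_out["test"]
    splits["test"] = raw[test_split]

    # Rename columns to the project-wide conventions (x/y).
    return splits.rename_columns({data_key: standard_x_key, label_key: standard_y_key})
```

Under these assumptions, the datamodule further below indexes the resulting `DatasetDict` by `"train"`, `"val"` and `"test"`, and derives `class_vocab` from the renamed label column's `ClassLabel.names`.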
28 changes: 28 additions & 0 deletions {{ cookiecutter.repository_name }}/conf/nn/data/default.yaml
@@ -0,0 +1,28 @@
_target_: {{ cookiecutter.package_name }}.data.datamodule.MyDataModule

val_images_fixed_idxs: [7371, 3963, 2861, 1701, 3172,
1749, 7023, 1606, 6481, 1377,
6003, 3593, 3410, 3399, 7277,
5337, 968, 8206, 288, 1968,
5677, 9156, 8139, 7660, 7089,
1893, 3845, 2084, 1944, 3375,
4848, 8704, 6038, 2183, 7422,
2682, 6878, 6127, 2941, 5823,
9129, 1798, 6477, 9264, 476,
3007, 4992, 1428, 9901, 5388]

accelerator: ${train.trainer.accelerator}

num_workers:
train: 4
val: 2
test: 0

batch_size:
train: 512
val: 128
test: 16

defaults:
- _self_
- dataset: vision/mnist # pick one of the yamls in nn/data/dataset/
54 changes: 15 additions & 39 deletions {{ cookiecutter.repository_name }}/conf/nn/default.yaml
@@ -1,47 +1,23 @@
data:
_target_: {{ cookiecutter.package_name }}.data.datamodule.MyDataModule

datasets:
train:
_target_: {{ cookiecutter.package_name }}.data.dataset.MyDataset

# val:
# - _target_: {{ cookiecutter.package_name }}.data.dataset.MyDataset

test:
- _target_: {{ cookiecutter.package_name }}.data.dataset.MyDataset

accelerator: ${train.trainer.accelerator}

num_workers:
train: 8
val: 4
test: 4

batch_size:
train: 32
val: 16
test: 16

# example
val_percentage: 0.1
data: ???

module:
_target_: {{ cookiecutter.package_name }}.pl_modules.pl_module.MyLightningModule

optimizer:
# Adam-oriented deep learning
_target_: torch.optim.Adam
# These are all default parameters for the Adam optimizer
lr: 0.001
lr: 1e-3
betas: [ 0.9, 0.999 ]
eps: 1e-08
weight_decay: 0

lr_scheduler:
_target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
T_0: 10
T_mult: 2
eta_min: 0 # min value for the lr
last_epoch: -1
verbose: False
# lr_scheduler:
# _target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
# T_0: 20
# T_mult: 1
# eta_min: 0
# last_epoch: -1
# verbose: False


defaults:
- _self_
- data: default
- module: default
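
As a hedged illustration (not part of this diff), the `optimizer` node above, and the commented-out `lr_scheduler` node, are typically consumed inside the LightningModule roughly as follows; the class and argument names here are assumptions, not the template's actual `MyLightningModule`:

```python
# Hypothetical sketch; the template's real configure_optimizers may differ.
import hydra
import pytorch_lightning as pl
import torch


class SketchModule(pl.LightningModule):
    def __init__(self, optimizer_cfg, lr_scheduler_cfg=None):
        super().__init__()
        self.optimizer_cfg = optimizer_cfg
        self.lr_scheduler_cfg = lr_scheduler_cfg
        self.layer = torch.nn.Linear(28 * 28, 10)  # placeholder parameters

    def configure_optimizers(self):
        # Instantiate whatever _target_ points at (torch.optim.Adam by default).
        optimizer = hydra.utils.instantiate(self.optimizer_cfg, params=self.parameters())
        if self.lr_scheduler_cfg is None:
            return optimizer
        scheduler = hydra.utils.instantiate(self.lr_scheduler_cfg, optimizer=optimizer)
        return [optimizer], [scheduler]
```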
@@ -0,0 +1,7 @@
_target_: {{ cookiecutter.package_name }}.pl_modules.pl_module.MyLightningModule
x_key: ${conventions.x_key}
y_key: ${conventions.y_key}

defaults:
- _self_
- model: cnn
@@ -0,0 +1,2 @@
_target_: {{ cookiecutter.package_name }}.modules.module.CNN
input_shape: ${nn.data.dataset.input_shape}
2 changes: 1 addition & 1 deletion {{ cookiecutter.repository_name }}/conf/train/default.yaml
@@ -17,7 +17,7 @@ trainer:

restore:
ckpt_or_run_path: null
mode: continue # null, finetune, hotstart, continue
mode: null # null, finetune, hotstart, continue

monitor:
metric: 'loss/val'
1 change: 1 addition & 0 deletions {{ cookiecutter.repository_name }}/setup.cfg
@@ -16,6 +16,7 @@ package_dir=
packages=find:
install_requires =
nn-template-core==0.3.*
anypy==0.0.*

# Add project specific dependencies
# Stuff easy to break with updates
@@ -1,15 +1,15 @@
import logging
from functools import cached_property, partial
from pathlib import Path
from typing import List, Mapping, Optional, Sequence
from typing import List, Mapping, Optional

import hydra
import omegaconf
import pytorch_lightning as pl
from omegaconf import DictConfig
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
from torchvision import transforms
from tqdm import tqdm

from nn_core.common import PROJECT_ROOT
from nn_core.nn_types import Split
@@ -80,6 +80,10 @@ def load(src_path: Path) -> "MetaData":
class_vocab=class_vocab,
)

def __repr__(self) -> str:
attributes = ",\n ".join([f"{key}={value}" for key, value in self.__dict__.items()])
return f"{self.__class__.__name__}(\n {attributes}\n)"


def collate_fn(samples: List, split: Split, metadata: MetaData):
"""Custom collate function for dataloaders with access to split and metadata.
@@ -98,26 +102,26 @@ def collate_fn(samples: List, split: Split, metadata: MetaData):
class MyDataModule(pl.LightningDataModule):
def __init__(
self,
datasets: DictConfig,
dataset: DictConfig,
num_workers: DictConfig,
batch_size: DictConfig,
accelerator: str,
# example
val_percentage: float,
val_images_fixed_idxs: List[int],
):
super().__init__()
self.datasets = datasets
self.dataset = dataset
self.num_workers = num_workers
self.batch_size = batch_size
# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#gpus
self.pin_memory: bool = accelerator is not None and str(accelerator) == "gpu"

self.train_dataset: Optional[Dataset] = None
self.val_datasets: Optional[Sequence[Dataset]] = None
self.test_datasets: Optional[Sequence[Dataset]] = None
self.val_dataset: Optional[Dataset] = None
self.test_dataset: Optional[Dataset] = None

# example
self.val_percentage: float = val_percentage
self.val_images_fixed_idxs: List[int] = val_images_fixed_idxs

@cached_property
def metadata(self) -> MetaData:
@@ -132,40 +136,25 @@ def metadata(self) -> MetaData:
if self.train_dataset is None:
self.setup(stage="fit")

return MetaData(class_vocab=self.train_dataset.dataset.class_vocab)
return MetaData(class_vocab={i: name for i, name in enumerate(self.train_dataset.features["y"].names)})

def prepare_data(self) -> None:
# download only
pass

def setup(self, stage: Optional[str] = None):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Here you should instantiate your datasets, you may also split the train into train and validation if needed.
if (stage is None or stage == "fit") and (self.train_dataset is None and self.val_datasets is None):
# example
mnist_train = hydra.utils.instantiate(
self.datasets.train,
split="train",
transform=transform,
path=PROJECT_ROOT / "data",
)
train_length = int(len(mnist_train) * (1 - self.val_percentage))
val_length = len(mnist_train) - train_length
self.train_dataset, val_dataset = random_split(mnist_train, [train_length, val_length])

self.val_datasets = [val_dataset]
self.transform = hydra.utils.instantiate(self.dataset.transforms)

self.hf_datasets = hydra.utils.instantiate(self.dataset)
self.hf_datasets.set_transform(self.transform)

# Here you should instantiate your dataset, you may also split the train into train and validation if needed.
if (stage is None or stage == "fit") and (self.train_dataset is None and self.val_dataset is None):
self.train_dataset = self.hf_datasets["train"]
self.val_dataset = self.hf_datasets["val"]

if stage is None or stage == "test":
self.test_datasets = [
hydra.utils.instantiate(
dataset_cfg,
split="test",
path=PROJECT_ROOT / "data",
transform=transform,
)
for dataset_cfg in self.datasets.test
]
self.test_dataset = self.hf_datasets["test"]

def train_dataloader(self) -> DataLoader:
return DataLoader(
@@ -177,34 +166,28 @@ def train_dataloader(self) -> DataLoader:
collate_fn=partial(collate_fn, split="train", metadata=self.metadata),
)

def val_dataloader(self) -> Sequence[DataLoader]:
return [
DataLoader(
dataset,
shuffle=False,
batch_size=self.batch_size.val,
num_workers=self.num_workers.val,
pin_memory=self.pin_memory,
collate_fn=partial(collate_fn, split="val", metadata=self.metadata),
)
for dataset in self.val_datasets
]

def test_dataloader(self) -> Sequence[DataLoader]:
return [
DataLoader(
dataset,
shuffle=False,
batch_size=self.batch_size.test,
num_workers=self.num_workers.test,
pin_memory=self.pin_memory,
collate_fn=partial(collate_fn, split="test", metadata=self.metadata),
)
for dataset in self.test_datasets
]
def val_dataloader(self) -> DataLoader:
return DataLoader(
self.val_dataset,
shuffle=False,
batch_size=self.batch_size.val,
num_workers=self.num_workers.val,
pin_memory=self.pin_memory,
collate_fn=partial(collate_fn, split="val", metadata=self.metadata),
)

def test_dataloader(self) -> DataLoader:
return DataLoader(
self.test_dataset,
shuffle=False,
batch_size=self.batch_size.test,
num_workers=self.num_workers.test,
pin_memory=self.pin_memory,
collate_fn=partial(collate_fn, split="test", metadata=self.metadata),
)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(" f"{self.datasets=}, " f"{self.num_workers=}, " f"{self.batch_size=})"
return f"{self.__class__.__name__}(" f"{self.dataset=}, " f"{self.num_workers=}, " f"{self.batch_size=})"


@hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")
@@ -214,7 +197,12 @@ def main(cfg: omegaconf.DictConfig) -> None:
Args:
cfg: the hydra configuration
"""
_: pl.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False)
m: pl.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False)
m.metadata
m.setup()

for _ in tqdm(m.train_dataloader()):
pass


if __name__ == "__main__":
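
For reference, a small usage sketch (not part of this diff): with this change `val_dataloader()` and `test_dataloader()` return a single `DataLoader` instead of a sequence, so downstream code that looped over multiple loaders would be adapted roughly like this; `cfg` and the `"x"`/`"y"` batch keys are assumptions based on the `conventions` config above:

```python
# Hypothetical usage sketch; `cfg` is assumed to be the composed Hydra config.
import hydra
import pytorch_lightning as pl


def iterate_validation(cfg) -> None:
    datamodule: pl.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False)
    datamodule.setup()

    # Before this PR: val_dataloader() returned a Sequence[DataLoader].
    # for loader in datamodule.val_dataloader():
    #     for batch in loader:
    #         ...

    # After this PR: a single DataLoader per split.
    for batch in datamodule.val_dataloader():
        x, y = batch["x"], batch["y"]  # keys follow the `conventions` node
```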