Skip to content

Commit

Permalink
Merge pull request #99 from grok-ai/develop
Browse files Browse the repository at this point in the history
0.4.0
  • Loading branch information
lucmos authored Oct 12, 2023
2 parents 8471d9a + 3851bfd commit 8ba02bb
Show file tree
Hide file tree
Showing 24 changed files with 354 additions and 223 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ env:
CONDA_ENV_FILE: 'env.yaml'
CONDA_ENV_NAME: 'project-test'
COOKIECUTTER_PROJECT_NAME: 'project-test'
HUGGING_FACE_HUB_TOKEN: ${{secrets.HUGGING_FACE_HUB_TOKEN}}

jobs:
build:
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ Avoid writing boilerplate code to integrate:

- [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning), lightweight PyTorch wrapper for high-performance AI research.
- [Hydra](https://github.com/facebookresearch/hydra), a framework for elegantly configuring complex applications.
- [Hugging Face Datasets](https://huggingface.co/docs/datasets/index),a library for easily accessing and sharing datasets.
- [Weights and Biases](https://wandb.ai/home), organize and analyze machine learning experiments. *(educational account available)*
- [Streamlit](https://streamlit.io/), turns data scripts into shareable web apps in minutes.
- [MkDocs](https://www.mkdocs.org/) and [Material for MkDocs](https://squidfunk.github.io/mkdocs-material/), a fast, simple and downright gorgeous static site generator.
Expand Down
2 changes: 1 addition & 1 deletion cookiecutter.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
"repository_url": "https://github.com/{{ cookiecutter.github_user }}/{{ cookiecutter.project_name.strip().lower().replace(' ', '-') }}",
"conda_env_name": "{{ cookiecutter.project_name.strip().lower().replace(' ', '-') }}",
"python_version": "3.11",
"__version": "0.3.1"
"__version": "0.4.0"
}
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ and to avoid writing boilerplate code to integrate:

- [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning), lightweight PyTorch wrapper for high-performance AI research.
- [Hydra](https://github.com/facebookresearch/hydra), a framework for elegantly configuring complex applications.
- [Hugging Face Datasets](https://huggingface.co/docs/datasets/index),a library for easily accessing and sharing datasets.
- [Weights and Biases](https://wandb.ai/home), organize and analyze machine learning experiments. *(educational account available)*
- [Streamlit](https://streamlit.io/), turns data scripts into shareable web apps in minutes.
- [MkDocs](https://www.mkdocs.org/) and [Material for MkDocs](https://squidfunk.github.io/mkdocs-material/), a fast, simple and downright gorgeous static site generator.
Expand Down
11 changes: 10 additions & 1 deletion {{ cookiecutter.repository_name }}/.env.template
Original file line number Diff line number Diff line change
@@ -1 +1,10 @@
# While .env is a local file full of secrets, this can be public and ease the setup of known env variables.
# .env.template is a template for .env file that can be versioned.

# Set to 1 to show full stack trace on error, 0 to hide it
HYDRA_FULL_ERROR=1

# Configure where huggingface_hub will locally store data.
HF_HOME="~/.cache/huggingface"

# Configure the User Access Token to authenticate to the Hub
# HUGGING_FACE_HUB_TOKEN=
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ env:
CACHE_NUMBER: 0 # increase to reset cache manually
CONDA_ENV_FILE: './env.yaml'
CONDA_ENV_NAME: '{{ cookiecutter.conda_env_name }}'

{% raw %}
HUGGING_FACE_HUB_TOKEN: ${{secrets.HUGGING_FACE_HUB_TOKEN}}

jobs:
build:
strategy:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ env:
CACHE_NUMBER: 1 # increase to reset cache manually
CONDA_ENV_FILE: './env.yaml'
CONDA_ENV_NAME: '{{ cookiecutter.conda_env_name }}'

{% raw %}
HUGGING_FACE_HUB_TOKEN: ${{secrets.HUGGING_FACE_HUB_TOKEN}}

jobs:
build:

Expand Down
4 changes: 4 additions & 0 deletions {{ cookiecutter.repository_name }}/conf/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ core:
version: 0.0.1
tags: null

conventions:
x_key: 'x'
y_key: 'y'

defaults:
- hydra: default
- nn: default
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# This class defines which dataset to use,
# and also how to split in train/[val]/test.
_target_: {{ cookiecutter.package_name }}.utils.hf_io.load_hf_dataset
name: "mnist"
ref: "mnist"
train_split: train
# val_split: val
val_percentage: 0.1
test_split: test
label_key: label
data_key: image
num_classes: 10
input_shape: [1, 28, 28]
standard_x_key: ${conventions.x_key}
standard_y_key: ${conventions.y_key}
transforms:
_target_: {{ cookiecutter.package_name }}.utils.hf_io.HFTransform
key: ${conventions.x_key}
transform:
_target_: torchvision.transforms.Compose
transforms:
- _target_: torchvision.transforms.ToTensor
28 changes: 28 additions & 0 deletions {{ cookiecutter.repository_name }}/conf/nn/data/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
_target_: {{ cookiecutter.package_name }}.data.datamodule.MyDataModule

val_images_fixed_idxs: [7371, 3963, 2861, 1701, 3172,
1749, 7023, 1606, 6481, 1377,
6003, 3593, 3410, 3399, 7277,
5337, 968, 8206, 288, 1968,
5677, 9156, 8139, 7660, 7089,
1893, 3845, 2084, 1944, 3375,
4848, 8704, 6038, 2183, 7422,
2682, 6878, 6127, 2941, 5823,
9129, 1798, 6477, 9264, 476,
3007, 4992, 1428, 9901, 5388]

accelerator: ${train.trainer.accelerator}

num_workers:
train: 4
val: 2
test: 0

batch_size:
train: 512
val: 128
test: 16

defaults:
- _self_
- dataset: vision/mnist # pick one of the yamls in nn/data/
54 changes: 15 additions & 39 deletions {{ cookiecutter.repository_name }}/conf/nn/default.yaml
Original file line number Diff line number Diff line change
@@ -1,47 +1,23 @@
data:
_target_: {{ cookiecutter.package_name }}.data.datamodule.MyDataModule

datasets:
train:
_target_: {{ cookiecutter.package_name }}.data.dataset.MyDataset

# val:
# - _target_: {{ cookiecutter.package_name }}.data.dataset.MyDataset

test:
- _target_: {{ cookiecutter.package_name }}.data.dataset.MyDataset

accelerator: ${train.trainer.accelerator}

num_workers:
train: 8
val: 4
test: 4

batch_size:
train: 32
val: 16
test: 16

# example
val_percentage: 0.1
data: ???

module:
_target_: {{ cookiecutter.package_name }}.pl_modules.pl_module.MyLightningModule

optimizer:
# Adam-oriented deep learning
_target_: torch.optim.Adam
# These are all default parameters for the Adam optimizer
lr: 0.001
lr: 1e-3
betas: [ 0.9, 0.999 ]
eps: 1e-08
weight_decay: 0

lr_scheduler:
_target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
T_0: 10
T_mult: 2
eta_min: 0 # min value for the lr
last_epoch: -1
verbose: False
# lr_scheduler:
# _target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
# T_0: 20
# T_mult: 1
# eta_min: 0
# last_epoch: -1
# verbose: False


defaults:
- _self_
- data: default
- module: default
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
_target_: {{ cookiecutter.package_name }}.pl_modules.pl_module.MyLightningModule
x_key: ${conventions.x_key}
y_key: ${conventions.y_key}

defaults:
- _self_
- model: cnn
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_target_: {{ cookiecutter.package_name }}.modules.module.CNN
input_shape: ${nn.data.dataset.input_shape}
12 changes: 6 additions & 6 deletions {{ cookiecutter.repository_name }}/conf/train/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,30 @@ trainer:

restore:
ckpt_or_run_path: null
mode: continue # null, finetune, hotstart, continue
mode: null # null, finetune, hotstart, continue

monitor:
metric: 'loss/val'
mode: 'min'

callbacks:
- _target_: pytorch_lightning.callbacks.EarlyStopping
- _target_: lightning.pytorch.callbacks.EarlyStopping
patience: 42
verbose: False
monitor: ${train.monitor.metric}
mode: ${train.monitor.mode}

- _target_: pytorch_lightning.callbacks.ModelCheckpoint
- _target_: lightning.pytorch.callbacks.ModelCheckpoint
save_top_k: 1
verbose: False
monitor: ${train.monitor.metric}
mode: ${train.monitor.mode}

- _target_: pytorch_lightning.callbacks.LearningRateMonitor
- _target_: lightning.pytorch.callbacks.LearningRateMonitor
logging_interval: "step"
log_momentum: False

- _target_: pytorch_lightning.callbacks.progress.tqdm_progress.TQDMProgressBar
- _target_: lightning.pytorch.callbacks.progress.tqdm_progress.TQDMProgressBar
refresh_rate: 20

logging:
Expand All @@ -49,7 +49,7 @@ logging:
source: true

logger:
_target_: pytorch_lightning.loggers.WandbLogger
_target_: lightning.pytorch.loggers.WandbLogger

project: ${core.project_name}
entity: null
Expand Down
3 changes: 2 additions & 1 deletion {{ cookiecutter.repository_name }}/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ package_dir=
=src
packages=find:
install_requires =
nn-template-core==0.3.*
nn-template-core==0.4.*
anypy==0.0.*

# Add project specific dependencies
# Stuff easy to break with updates
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# thus the logging configuration defined in the __init__.py must be called before
# the lightning import otherwise it has no effect.
# See https://github.com/PyTorchLightning/pytorch-lightning/issues/1503
lightning_logger = logging.getLogger("pytorch_lightning")
lightning_logger = logging.getLogger("lightning.pytorch")
# Remove all handlers associated with the lightning logger.
for handler in lightning_logger.handlers[:]:
lightning_logger.removeHandler(handler)
Expand Down
Loading

0 comments on commit 8ba02bb

Please sign in to comment.