
Commit e755f2b

Add big dataset examples
1 parent 6c67c64 commit e755f2b

File tree

9 files changed (+691, -1 lines)

docs/Minimal_examples.rst

+1
@@ -5,3 +5,4 @@
 
 .. include:: examples/frameworks/index.rst
 .. include:: examples/distributed/index.rst
+.. include:: examples/data/index.rst

docs/examples/data/index.rst

+6
@@ -0,0 +1,6 @@
+*****************************
+Data Handling during Training
+*****************************
+
+
+.. include:: examples/data/torchvision/_index.rst

docs/examples/data/torchvision/_index.rst

+354

@@ -0,0 +1,354 @@
Torchvision
===========


**Prerequisites**

Make sure to read the following sections of the documentation before using this
example:

* :ref:`pytorch_setup`
* :ref:`001 - Single GPU Job`

The full source code for this example is available on `the mila-docs GitHub
repository <https://github.com/mila-iqia/mila-docs/tree/master/docs/examples/data/torchvision>`_.

**job.sh**

.. code:: diff

    # distributed/001_single_gpu/job.sh -> data/torchvision/job.sh
    #!/bin/bash
    #SBATCH --gpus-per-task=rtx8000:1
    #SBATCH --cpus-per-task=4
    #SBATCH --ntasks-per-node=1
    #SBATCH --mem=16G
   -#SBATCH --time=00:15:00
   +#SBATCH --time=01:30:00
   +set -o errexit


    # Echo time and hostname into log
    echo "Date: $(date)"
    echo "Hostname: $(hostname)"


    # Ensure only anaconda/3 module loaded.
    module --quiet purge
    # This example uses Conda to manage package dependencies.
    # See https://docs.mila.quebec/Userguide.html#conda for more information.
    module load anaconda/3
    module load cuda/11.7

    # Creating the environment for the first time:
    # conda create -y -n pytorch python=3.9 pytorch torchvision torchaudio \
   -#     pytorch-cuda=11.7 -c pytorch -c nvidia
   +#     pytorch-cuda=11.7 scipy -c pytorch -c nvidia
    # Other conda packages:
    # conda install -y -n pytorch -c conda-forge rich tqdm

    # Activate pre-existing environment.
    conda activate pytorch


   -# Stage dataset into $SLURM_TMPDIR
   -mkdir -p $SLURM_TMPDIR/data
   -cp /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
   -# General-purpose alternatives combining copy and unpack:
   -# unzip /network/datasets/some/file.zip -d $SLURM_TMPDIR/data/
   -# tar -xf /network/datasets/some/file.tar -C $SLURM_TMPDIR/data/
   +# Prepare data for training
   +mkdir -p "$SLURM_TMPDIR/data"
   +
   +if [[ -z "${_DATA_PREP_WORKERS}" ]]
   +then
   +    _DATA_PREP_WORKERS=${SLURM_JOB_CPUS_PER_NODE}
   +fi
   +if [[ -z "${_DATA_PREP_WORKERS}" ]]
   +then
   +    _DATA_PREP_WORKERS=16
   +fi
   +
   +# Copy the dataset to $SLURM_TMPDIR so it is close to the GPUs for
   +# faster training
   +srun --ntasks=$SLURM_JOB_NUM_NODES --ntasks-per-node=1 \
   +    time -p bash data.sh "/network/datasets/inat" "$SLURM_TMPDIR/data" ${_DATA_PREP_WORKERS}


    # Fixes issues with MIG-ed GPUs with versions of PyTorch < 2.0
    unset CUDA_VISIBLE_DEVICES

    # Execute Python script
    python main.py

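The number of workers used to stage the dataset defaults to
``$SLURM_JOB_CPUS_PER_NODE`` (falling back to 16 when that is unset), and it
can be overridden through the ``_DATA_PREP_WORKERS`` environment variable. A
minimal sketch, assuming the default ``sbatch`` behaviour of exporting the
submission environment (``--export=ALL``):

.. code-block:: bash

   # Stage the dataset with 8 workers instead of the auto-detected count
   _DATA_PREP_WORKERS=8 sbatch job.sh
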
**main.py**

.. code:: diff

    # distributed/001_single_gpu/main.py -> data/torchvision/main.py
   -"""Single-GPU training example."""
   +"""Torchvision training example."""
    import logging
    import os
   -from pathlib import Path

    import rich.logging
    import torch
    from torch import Tensor, nn
    from torch.nn import functional as F
    from torch.utils.data import DataLoader, random_split
    from torchvision import transforms
   -from torchvision.datasets import CIFAR10
   +from torchvision.datasets import INaturalist
    from torchvision.models import resnet18
    from tqdm import tqdm


    def main():
   -    training_epochs = 10
   +    training_epochs = 1
        learning_rate = 5e-4
        weight_decay = 1e-4
   -    batch_size = 128
   +    batch_size = 256

        # Check that the GPU is available
        assert torch.cuda.is_available() and torch.cuda.device_count() > 0
        device = torch.device("cuda", 0)

        # Setup logging (optional, but much better than using print statements)
        logging.basicConfig(
            level=logging.INFO,
            handlers=[rich.logging.RichHandler(markup=True)],  # Very pretty, uses the `rich` package.
        )

        logger = logging.getLogger(__name__)

        # Create a model and move it to the GPU.
   -    model = resnet18(num_classes=10)
   +    model = resnet18(num_classes=10000)
        model.to(device=device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

   -    # Setup CIFAR10
   +    # Setup iNaturalist
        num_workers = get_num_workers()
   -    dataset_path = Path(os.environ.get("SLURM_TMPDIR", ".")) / "data"
   -    train_dataset, valid_dataset, test_dataset = make_datasets(str(dataset_path))
   +    try:
   +        dataset_path = f"{os.environ['SLURM_TMPDIR']}/data"
   +    except KeyError:
   +        dataset_path = "../dataset"
   +    train_dataset, valid_dataset, test_dataset = make_datasets(dataset_path)
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=True,
        )
        valid_dataloader = DataLoader(
            valid_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False,
        )
        test_dataloader = DataLoader(  # NOTE: Not used in this example.
            test_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False,
        )

        # Check out the "checkpointing and preemption" example for more info!
        logger.debug("Starting training from scratch.")

        for epoch in range(training_epochs):
            logger.debug(f"Starting epoch {epoch}/{training_epochs}")

   -        # Set the model in training mode (important for e.g. BatchNorm and Dropout layers)
   +        # Set the model in training mode (this is important for e.g. BatchNorm and Dropout layers)
            model.train()

            # NOTE: using a progress bar from tqdm because it's nicer than using `print`.
            progress_bar = tqdm(
                total=len(train_dataloader),
                desc=f"Train epoch {epoch}",
            )

            # Training loop
            for batch in train_dataloader:
                # Move the batch to the GPU before we pass it to the model
                batch = tuple(item.to(device) for item in batch)
                x, y = batch

                # Forward pass
                logits: Tensor = model(x)

                loss = F.cross_entropy(logits, y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Calculate some metrics:
                n_correct_predictions = logits.detach().argmax(-1).eq(y).sum()
                n_samples = y.shape[0]
                accuracy = n_correct_predictions / n_samples

                logger.debug(f"Accuracy: {accuracy.item():.2%}")
                logger.debug(f"Average Loss: {loss.item()}")

                # Advance the progress bar one step and update its "postfix" text (nicer than plain prints).
                progress_bar.update(1)
                progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
            progress_bar.close()

            val_loss, val_accuracy = validation_loop(model, valid_dataloader, device)
            logger.info(f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}")

        print("Done!")


    @torch.no_grad()
    def validation_loop(model: nn.Module, dataloader: DataLoader, device: torch.device):
        model.eval()

        total_loss = 0.0
        n_samples = 0
        correct_predictions = 0

        for batch in dataloader:
            batch = tuple(item.to(device) for item in batch)
            x, y = batch

            logits: Tensor = model(x)
            loss = F.cross_entropy(logits, y)

            batch_n_samples = x.shape[0]
            batch_correct_predictions = logits.argmax(-1).eq(y).sum()

            total_loss += loss.item()
            n_samples += batch_n_samples
            correct_predictions += batch_correct_predictions

        accuracy = correct_predictions / n_samples
        return total_loss, accuracy


    def make_datasets(
        dataset_path: str,
        val_split: float = 0.1,
        val_split_seed: int = 42,
    ):
   -    """Returns the training, validation, and test splits for CIFAR10.
   +    """Returns the training, validation, and test splits for iNat.

        NOTE: We don't use image transforms here for simplicity.
        Having different transformations for train and validation would complicate things a bit.
        Later examples will show how to do the train/val/test split properly when using transforms.
        """
   -    train_dataset = CIFAR10(
   -        root=dataset_path, transform=transforms.ToTensor(), download=True, train=True
   +    train_dataset = INaturalist(
   +        root=dataset_path,
   +        transform=transforms.Compose([
   +            transforms.Resize(256),
   +            transforms.CenterCrop(224),
   +            transforms.ToTensor(),
   +        ]),
   +        version="2021_train"
        )
   -    test_dataset = CIFAR10(
   -        root=dataset_path, transform=transforms.ToTensor(), download=True, train=False
   +    test_dataset = INaturalist(
   +        root=dataset_path,
   +        transform=transforms.Compose([
   +            transforms.Resize(256),
   +            transforms.CenterCrop(224),
   +            transforms.ToTensor(),
   +        ]),
   +        version="2021_valid"
        )
        # Split the training dataset into a training and validation set.
   -    n_samples = len(train_dataset)
   -    n_valid = int(val_split * n_samples)
   -    n_train = n_samples - n_valid
        train_dataset, valid_dataset = random_split(
   -        train_dataset, (n_train, n_valid), torch.Generator().manual_seed(val_split_seed)
   +        train_dataset, ((1 - val_split), val_split), torch.Generator().manual_seed(val_split_seed)
        )
        return train_dataset, valid_dataset, test_dataset


    def get_num_workers() -> int:
        """Gets the optimal number of DataLoader workers to use in the current job."""
        if "SLURM_CPUS_PER_TASK" in os.environ:
            return int(os.environ["SLURM_CPUS_PER_TASK"])
        if hasattr(os, "sched_getaffinity"):
            return len(os.sched_getaffinity(0))
        return torch.multiprocessing.cpu_count()


    if __name__ == "__main__":
        main()

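Note that ``main.py`` falls back to ``../dataset`` when ``SLURM_TMPDIR`` is not
set, so the script can also be smoke-tested outside of a Slurm job. A minimal
sketch, assuming a local GPU and an already-prepared copy of the dataset under
``../dataset``:

.. code-block:: bash

   # Outside of a job, SLURM_TMPDIR is unset and ../dataset is used instead
   python main.py
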
**data.sh**

.. code:: bash

   #!/bin/bash
   set -o errexit

   _SRC=$1
   _DEST=$2
   _WORKERS=$3

   # Clone the dataset structure locally and reorganise the raw files if needed
   (cd "${_SRC}" && find -L * -type f) | while read f
   do
       mkdir --parents "${_DEST}/$(dirname "$f")"
       # echo source first so it is matched to the ln's '-T' argument
       readlink --canonicalize "${_SRC}/$f"
       # echo output last so ln understands it's the output file
       echo "${_DEST}/$f"
   done | xargs -n2 -P${_WORKERS} ln --symbolic --force -T

   (
       cd "${_DEST}"
       # Torchvision expects these names
       mv train.tar.gz 2021_train.tgz
       mv val.tar.gz 2021_valid.tgz
   )

   # Extract and prepare the data
   python3 data.py "${_DEST}"

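``data.sh`` takes the dataset source directory, the destination directory and
the number of parallel workers as positional arguments, so it can also be run
by hand. A minimal sketch, assuming an interactive allocation where
``$SLURM_TMPDIR`` is set:

.. code-block:: bash

   # Symlink the raw files, rename the archives, then extract them
   bash data.sh "/network/datasets/inat" "$SLURM_TMPDIR/data" 4
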
**data.py**

.. code:: python

   """Make sure the data is available"""
   import sys
   import time

   from torchvision.datasets import INaturalist


   t = -time.time()
   INaturalist(root=sys.argv[1], version="2021_train", download=True)
   INaturalist(root=sys.argv[1], version="2021_valid", download=True)
   t += time.time()
   print(f"Prepared data in {t/60:.2f}m")

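``data.py`` is normally invoked by ``data.sh``, but it can also be called
directly; it expects a single argument, the directory holding the renamed
``2021_train.tgz`` and ``2021_valid.tgz`` archives:

.. code-block:: bash

   python3 data.py "$SLURM_TMPDIR/data"
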
**Running this example**

.. code-block:: bash

   $ sbatch job.sh
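
Once submitted, the job can be monitored with the usual Slurm tools. A minimal
sketch, assuming the default output file name pattern (``slurm-<jobid>.out``):

.. code-block:: bash

   # Check the state of the job in the queue
   squeue -u $USER

   # Follow the training log as it is written
   tail -f slurm-<jobid>.out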
