
Commit b932b50
Add checkpointing example
1 parent 454a9af commit b932b50

File tree

6 files changed: +646 -0 lines changed
@@ -0,0 +1,327 @@
Checkpointing
=============


**Prerequisites**

Make sure to read the following sections of the documentation before using this
example:

* :ref:`pytorch_setup`
* :ref:`001 - Single GPU Job`

The full source code for this example is available on `the mila-docs GitHub
repository <https://github.com/mila-iqia/mila-docs/tree/master/docs/examples/data/checkpointing>`_.


**job.sh**

.. code:: diff

# distributed/001_single_gpu/job.sh -> data/checkpointing/job.sh
#!/bin/bash
#SBATCH --gpus-per-task=rtx8000:1
#SBATCH --cpus-per-task=4
#SBATCH --ntasks-per-node=1
#SBATCH --mem=16G
#SBATCH --time=00:15:00
+#SBATCH --signal=B:TERM@300 # tells the controller to send SIGTERM to the job 5
+                            # min before its time ends to give it a chance for
+                            # better cleanup. If you cancel the job manually,
+                            # make sure that you specify the signal as TERM like
+                            # so: scancel --signal=TERM <jobid>.
+                            # https://dhruveshp.com/blog/2021/signal-propagation-on-slurm/
+
+# Trap the signal in the main BATCH script here.
+sig_handler()
+{
+    echo "BATCH interrupted"
+    wait # wait for all children, this is important!
+}
+
+trap 'sig_handler' SIGINT SIGTERM SIGCONT


# Echo time and hostname into log
echo "Date: $(date)"
echo "Hostname: $(hostname)"


# Ensure only anaconda/3 module loaded.
module --quiet purge
# This example uses Conda to manage package dependencies.
# See https://docs.mila.quebec/Userguide.html#conda for more information.
module load anaconda/3
module load cuda/11.7

+
# Creating the environment for the first time:
# conda create -y -n pytorch python=3.9 pytorch torchvision torchaudio \
-#     pytorch-cuda=11.7 -c pytorch -c nvidia
+#     pytorch-cuda=11.7 scipy -c pytorch -c nvidia
# Other conda packages:
# conda install -y -n pytorch -c conda-forge rich tqdm

# Activate pre-existing environment.
conda activate pytorch


# Stage dataset into $SLURM_TMPDIR
mkdir -p $SLURM_TMPDIR/data
cp /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
# General-purpose alternatives combining copy and unpack:
#     unzip /network/datasets/some/file.zip -d $SLURM_TMPDIR/data/
#     tar -xf /network/datasets/some/file.tar -C $SLURM_TMPDIR/data/


# Fixes issues with MIG-ed GPUs with versions of PyTorch < 2.0
unset CUDA_VISIBLE_DEVICES

# Execute Python script
python main.py
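
The ``--signal=B:TERM@300`` line and the ``sig_handler`` trap above give the job a
five-minute warning before its time limit, while the Python script shown next only
writes a checkpoint at the end of each epoch. As a purely illustrative sketch (not
part of this commit), the training script could also react to that warning directly,
assuming the signal is actually forwarded to the Python process (the ``sig_handler``
above would need to send it on, e.g. with ``kill``); the names ``interrupted`` and
``_handle_term`` are hypothetical:

.. code-block:: python

    # Hypothetical sketch, not part of this commit: set a flag when SLURM sends
    # SIGTERM so the training loop can save a checkpoint and exit cleanly
    # instead of waiting for the end of the current epoch.
    import signal

    interrupted = False

    def _handle_term(signum, frame):
        global interrupted
        interrupted = True  # checked between batches in the training loop

    signal.signal(signal.SIGTERM, _handle_term)
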
**main.py**

.. code:: diff

# distributed/001_single_gpu/main.py -> data/checkpointing/main.py
"""Single-GPU training example."""
import logging
import os
-from pathlib import Path
+import shutil

import rich.logging
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from tqdm import tqdm


+try:
+    _CHECKPTS_DIR = f"{os.environ['SCRATCH']}/checkpoints"
+except KeyError:
+    _CHECKPTS_DIR = "../checkpoints"
+
+
def main():
    training_epochs = 10
    learning_rate = 5e-4
    weight_decay = 1e-4
    batch_size = 128
+    resume_file = f"{_CHECKPTS_DIR}/resnet18_cifar10/checkpoint.pth.tar"
+    start_epoch = 0
+    best_acc = 0

    # Check that the GPU is available
    assert torch.cuda.is_available() and torch.cuda.device_count() > 0
    device = torch.device("cuda", 0)

    # Setup logging (optional, but much better than using print statements)
    logging.basicConfig(
        level=logging.INFO,
        handlers=[rich.logging.RichHandler(markup=True)],  # Very pretty, uses the `rich` package.
    )

    logger = logging.getLogger(__name__)

-    # Create a model and move it to the GPU.
+    # Create a model.
    model = resnet18(num_classes=10)
+
+    # Resume from a checkpoint
+    checkpoint = None
+    if os.path.isfile(resume_file):
+        logger.debug(f"=> loading checkpoint '{resume_file}'")
+        # Map the checkpoint to be loaded to the GPU.
+        checkpoint = torch.load(resume_file, map_location="cuda:0")
+        start_epoch = checkpoint["epoch"]
+        best_acc = checkpoint["best_acc"]
+        # best_acc may be from a checkpoint from a different GPU
+        best_acc = best_acc.to(device)
+        model.load_state_dict(checkpoint["state_dict"])
+        logger.debug(f"=> loaded checkpoint '{resume_file}' (epoch {checkpoint['epoch']})")
+    else:
+        logger.debug(f"=> no checkpoint found at '{resume_file}'")
+
+    # Move the model to the GPU.
    model.to(device=device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
+    if checkpoint is not None:
+        # The optimizer state can only be restored once the optimizer exists.
+        optimizer.load_state_dict(checkpoint["optimizer"])

    # Setup CIFAR10
    num_workers = get_num_workers()
-    dataset_path = Path(os.environ.get("SLURM_TMPDIR", ".")) / "data"
-    train_dataset, valid_dataset, test_dataset = make_datasets(str(dataset_path))
+    try:
+        dataset_path = f"{os.environ['SLURM_TMPDIR']}/data"
+    except KeyError:
+        dataset_path = "../dataset"
+    train_dataset, valid_dataset, test_dataset = make_datasets(dataset_path)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )
    test_dataloader = DataLoader(  # NOTE: Not used in this example.
        test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )

    # Checkout the "checkpointing and preemption" example for more info!
-    logger.debug("Starting training from scratch.")
+    logger.debug(f"Starting training at epoch {start_epoch}.")

-    for epoch in range(training_epochs):
+    for epoch in range(start_epoch, training_epochs):
        logger.debug(f"Starting epoch {epoch}/{training_epochs}")

-        # Set the model in training mode (important for e.g. BatchNorm and Dropout layers)
+        # Set the model in training mode (this is important for e.g. BatchNorm and Dropout layers)
        model.train()

        # NOTE: using a progress bar from tqdm because it's nicer than using `print`.
        progress_bar = tqdm(
            total=len(train_dataloader),
            desc=f"Train epoch {epoch}",
        )

        # Training loop
        for batch in train_dataloader:
            # Move the batch to the GPU before we pass it to the model
            batch = tuple(item.to(device) for item in batch)
            x, y = batch

            # Forward pass
            logits: Tensor = model(x)

            loss = F.cross_entropy(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Calculate some metrics:
            n_correct_predictions = logits.detach().argmax(-1).eq(y).sum()
            n_samples = y.shape[0]
            accuracy = n_correct_predictions / n_samples

            logger.debug(f"Accuracy: {accuracy.item():.2%}")
            logger.debug(f"Average Loss: {loss.item()}")

            # Advance the progress bar one step and update its postfix with the current loss and accuracy.
            progress_bar.update(1)
            progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
        progress_bar.close()

        val_loss, val_accuracy = validation_loop(model, valid_dataloader, device)
        logger.info(f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}")

+        # Remember the best accuracy seen so far and save a checkpoint.
+        is_best = val_accuracy > best_acc
+        best_acc = max(val_accuracy, best_acc)
+
+        save_checkpoint({
+            "epoch": epoch + 1,
+            "arch": "resnet18",
+            "state_dict": model.state_dict(),
+            "best_acc": best_acc,
+            "optimizer": optimizer.state_dict(),
+        }, is_best, filename=resume_file)
+
    print("Done!")


@torch.no_grad()
def validation_loop(model: nn.Module, dataloader: DataLoader, device: torch.device):
    model.eval()

    total_loss = 0.0
    n_samples = 0
    correct_predictions = 0

    for batch in dataloader:
        batch = tuple(item.to(device) for item in batch)
        x, y = batch

        logits: Tensor = model(x)
        loss = F.cross_entropy(logits, y)

        batch_n_samples = x.shape[0]
        batch_correct_predictions = logits.argmax(-1).eq(y).sum()

        total_loss += loss.item()
        n_samples += batch_n_samples
        correct_predictions += batch_correct_predictions

    accuracy = correct_predictions / n_samples
    return total_loss, accuracy


def make_datasets(
    dataset_path: str,
    val_split: float = 0.1,
    val_split_seed: int = 42,
):
    """Returns the training, validation, and test splits for CIFAR10.

    NOTE: We don't use image transforms here for simplicity.
    Having different transformations for train and validation would complicate things a bit.
    Later examples will show how to do the train/val/test split properly when using transforms.
    """
    train_dataset = CIFAR10(
        root=dataset_path, transform=transforms.ToTensor(), download=True, train=True
    )
    test_dataset = CIFAR10(
        root=dataset_path, transform=transforms.ToTensor(), download=True, train=False
    )
    # Split the training dataset into a training and validation set.
-    n_samples = len(train_dataset)
-    n_valid = int(val_split * n_samples)
-    n_train = n_samples - n_valid
    train_dataset, valid_dataset = random_split(
-        train_dataset, (n_train, n_valid), torch.Generator().manual_seed(val_split_seed)
+        train_dataset, ((1 - val_split), val_split), torch.Generator().manual_seed(val_split_seed)
    )
    return train_dataset, valid_dataset, test_dataset


def get_num_workers() -> int:
"""Gets the optimal number of DatLoader workers to use in the current job."""
    if "SLURM_CPUS_PER_TASK" in os.environ:
        return int(os.environ["SLURM_CPUS_PER_TASK"])
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    return torch.multiprocessing.cpu_count()

+def save_checkpoint(state, is_best, filename=f"{_CHECKPTS_DIR}/checkpoint.pth.tar"):
+    # Make sure the checkpoint directory exists before writing to it.
+    os.makedirs(os.path.dirname(filename), exist_ok=True)
+    torch.save(state, filename)
+    if is_best:
+        _dir = os.path.dirname(filename)
+        shutil.copyfile(filename, f"{_dir}/model_best.pth.tar")
+
+
if __name__ == "__main__":
    main()

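
``save_checkpoint`` above writes ``checkpoint.pth.tar`` after every epoch and copies
the best one to ``model_best.pth.tar`` in the same directory. As a usage sketch (not
part of this commit), the saved dictionary can be loaded back later for evaluation;
the path below assumes the ``$SCRATCH/checkpoints/resnet18_cifar10`` layout used for
``resume_file``, and relies only on the keys that ``save_checkpoint`` writes:

.. code-block:: python

    # Hypothetical usage sketch, not part of this commit: reload the best
    # checkpoint written by save_checkpoint() for offline evaluation.
    import os

    import torch
    from torchvision.models import resnet18

    ckpt_path = f"{os.environ['SCRATCH']}/checkpoints/resnet18_cifar10/model_best.pth.tar"
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    model = resnet18(num_classes=10)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    print(f"Best model from epoch {checkpoint['epoch']}, best accuracy {checkpoint['best_acc']}")
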
**Running this example**

.. code-block:: bash

    $ sbatch job.sh
@@ -0,0 +1,34 @@
Checkpointing
=============


**Prerequisites**

Make sure to read the following sections of the documentation before using this
example:

* :ref:`pytorch_setup`
* :ref:`001 - Single GPU Job`

The full source code for this example is available on `the mila-docs GitHub
repository <https://github.com/mila-iqia/mila-docs/tree/master/docs/examples/data/checkpointing>`_.


**job.sh**

.. literalinclude:: examples/data/checkpointing/job.sh.diff
   :language: diff


**main.py**

.. literalinclude:: examples/data/checkpointing/main.py.diff
   :language: diff


**Running this example**

.. code-block:: bash

    $ sbatch job.sh
