forked from Lightning-AI/pytorch-lightning
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* adding test * adding test * added base parity model * added base parity model * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * move parity to benchmark * formatting * fixed gradient acc sched * move parity to benchmark * formatting * fixed gradient acc sched * skip for CPU * call last Co-authored-by: J. Borovec <[email protected]>
- Loading branch information
1 parent
097f8fe
commit 3d65e2e
Showing
4 changed files
with
153 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import os | ||
import time | ||
|
||
import numpy as np | ||
import pytest | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.utils.data import DataLoader | ||
from torchvision import transforms | ||
from torchvision.datasets import MNIST | ||
|
||
from pytorch_lightning import Trainer, LightningModule | ||
|
||
|
||
class ParityMNIST(LightningModule): | ||
|
||
def __init__(self): | ||
super(ParityMNIST, self).__init__() | ||
self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128) | ||
self.c_d1_bn = nn.BatchNorm1d(128) | ||
self.c_d1_drop = nn.Dropout(0.3) | ||
self.c_d2 = nn.Linear(in_features=128, out_features=10) | ||
|
||
def forward(self, x): | ||
x = x.view(x.size(0), -1) | ||
x = self.c_d1(x) | ||
x = torch.tanh(x) | ||
x = self.c_d1_bn(x) | ||
x = self.c_d1_drop(x) | ||
x = self.c_d2(x) | ||
return x | ||
|
||
def training_step(self, batch, batch_nb): | ||
x, y = batch | ||
y_hat = self(x) | ||
loss = F.cross_entropy(y_hat, y) | ||
return {'loss': loss} | ||
|
||
def configure_optimizers(self): | ||
return torch.optim.Adam(self.parameters(), lr=0.02) | ||
|
||
def train_dataloader(self): | ||
return DataLoader(MNIST(os.getcwd(), train=True, download=True, | ||
transform=transforms.ToTensor()), batch_size=32) | ||
|
||
|
||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") | ||
def test_pytorch_parity(tmpdir): | ||
""" | ||
Verify that the same pytorch and lightning models achieve the same results | ||
:param tmpdir: | ||
:return: | ||
""" | ||
num_epochs = 2 | ||
num_rums = 3 | ||
lightning_outs, pl_times = lightning_loop(ParityMNIST, num_rums, num_epochs) | ||
manual_outs, pt_times = vanilla_loop(ParityMNIST, num_rums, num_epochs) | ||
|
||
# make sure the losses match exactly to 5 decimal places | ||
for pl_out, pt_out in zip(lightning_outs, manual_outs): | ||
np.testing.assert_almost_equal(pl_out, pt_out, 5) | ||
|
||
|
||
def set_seed(seed): | ||
np.random.seed(seed) | ||
torch.manual_seed(seed) | ||
if torch.cuda.is_available(): | ||
torch.cuda.manual_seed(seed) | ||
|
||
|
||
def vanilla_loop(MODEL, num_runs=10, num_epochs=10): | ||
""" | ||
Returns an array with the last loss from each epoch for each run | ||
""" | ||
device = torch.device('cuda' if torch.cuda.is_available() else "cpu") | ||
errors = [] | ||
times = [] | ||
|
||
for i in range(num_runs): | ||
time_start = time.perf_counter() | ||
|
||
# set seed | ||
seed = i | ||
set_seed(seed) | ||
|
||
# init model parts | ||
model = MODEL() | ||
dl = model.train_dataloader() | ||
optimizer = model.configure_optimizers() | ||
|
||
# model to GPU | ||
model = model.to(device) | ||
|
||
epoch_losses = [] | ||
for epoch in range(num_epochs): | ||
|
||
# run through full training set | ||
for j, batch in enumerate(dl): | ||
x, y = batch | ||
x = x.cuda(0) | ||
y = y.cuda(0) | ||
batch = (x, y) | ||
|
||
loss_dict = model.training_step(batch, j) | ||
loss = loss_dict['loss'] | ||
loss.backward() | ||
optimizer.step() | ||
optimizer.zero_grad() | ||
|
||
# track last epoch loss | ||
epoch_losses.append(loss.item()) | ||
|
||
time_end = time.perf_counter() | ||
times.append(time_end - time_start) | ||
|
||
errors.append(epoch_losses[-1]) | ||
|
||
return errors, times | ||
|
||
|
||
def lightning_loop(MODEL, num_runs=10, num_epochs=10): | ||
errors = [] | ||
times = [] | ||
|
||
for i in range(num_runs): | ||
time_start = time.perf_counter() | ||
|
||
# set seed | ||
seed = i | ||
set_seed(seed) | ||
|
||
# init model parts | ||
model = MODEL() | ||
trainer = Trainer( | ||
max_epochs=num_epochs, | ||
show_progress_bar=False, | ||
weights_summary=None, | ||
gpus=1, | ||
early_stop_callback=False, | ||
checkpoint_callback=False | ||
) | ||
trainer.fit(model) | ||
|
||
final_loss = trainer.running_loss.last().item() | ||
errors.append(final_loss) | ||
|
||
time_end = time.perf_counter() | ||
times.append(time_end - time_start) | ||
|
||
return errors, times |