Skip to content

Commit

Permalink
Document speed comparison (#2072)
Browse files Browse the repository at this point in the history
* docs

* script

* dump

* desc

* import

* import

* if

* norm

* t

* finished

* isort

* typing

Co-authored-by: Nicki Skafte <[email protected]>

* xlabel

* pandas

* time

Co-authored-by: Nicki Skafte <[email protected]>
  • Loading branch information
Borda and SkafteNicki authored Dec 17, 2020
1 parent 1b599ff commit b16441f
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 23 deletions.
17 changes: 17 additions & 0 deletions benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

# directory that contains this `__init__.py`, i.e. the benchmarks folder itself
BENCHMARK_ROOT = os.path.dirname(__file__)
# parent directory of the benchmarks folder (presumably the repository root -- confirm)
PROJECT_ROOT = os.path.dirname(BENCHMARK_ROOT)
60 changes: 60 additions & 0 deletions benchmarks/generate_comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import matplotlib.pylab as plt
import pandas as pd

from benchmarks.test_basic_parity import lightning_loop, vanilla_loop
from tests.base.models import ParityModuleMNIST, ParityModuleRNN

# number of epochs per run, forwarded to both training loops as `num_epochs`
NUM_EPOCHS = 20
# number of repeated runs per model, forwarded as `num_runs`; also used to scale the recorded times
NUM_RUNS = 50
# model classes to benchmark; one subplot is drawn for each of them
MODEL_CLASSES = (ParityModuleRNN, ParityModuleMNIST)
# directory of this script; the CSV dumps and the final figure are written here
PATH_HERE = os.path.dirname(__file__)
# suffix (including the dot) appended to the saved figure's file name
FIGURE_EXTENSION = '.png'


def _main():
    """
    Collect run times of the vanilla PyTorch and Lightning loops and plot them as histograms.

    For every class in ``MODEL_CLASSES`` the timings are loaded from
    ``dump-times_<ClassName>.csv`` next to this script when that file exists.
    Otherwise ``vanilla_loop`` and ``lightning_loop`` are executed and the resulting
    table is written to that CSV. Each model gets its own subplot and the
    figure is saved as ``figure-parity-times<FIGURE_EXTENSION>`` in the same folder.
    """
    # `squeeze=False` always returns a 2D array of axes, so indexing keeps working
    # when `MODEL_CLASSES` holds a single model (the default would return a bare Axes)
    fig, axarr = plt.subplots(nrows=len(MODEL_CLASSES), squeeze=False)

    for ax, cls_model in zip(axarr[:, 0], MODEL_CLASSES):
        path_csv = os.path.join(PATH_HERE, f'dump-times_{cls_model.__name__}.csv')
        if os.path.isfile(path_csv):
            # reuse cached measurements instead of re-running the benchmark
            df_time = pd.read_csv(path_csv, index_col=0)
        else:
            vanilla = vanilla_loop(cls_model, num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)
            lightning = lightning_loop(cls_model, num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)

            # the first run of each loop is dropped from the statistics
            df_time = pd.DataFrame({
                'vanilla PT': vanilla['durations'][1:],
                'PT Lightning': lightning['durations'][1:],
            })
            # NOTE(review): every duration already covers one whole run, so dividing by
            # NUM_RUNS may have been meant as NUM_EPOCHS (time per epoch) -- confirm
            df_time /= NUM_RUNS
            df_time.to_csv(path_csv)
        # todo: add also relative X-axis ticks to see both: relative and absolute time differences
        df_time.plot.hist(
            ax=ax,
            bins=20,
            alpha=0.5,
            title=cls_model.__name__,
            legend=True,
            grid=True,
        )
        ax.set(xlabel='time [seconds]')

    path_fig = os.path.join(PATH_HERE, f'figure-parity-times{FIGURE_EXTENSION}')
    fig.tight_layout()
    fig.savefig(path_fig)

# generate the comparison only when executed as a script, not when imported
if __name__ == '__main__':
    _main()
62 changes: 41 additions & 21 deletions benchmarks/test_parity.py → benchmarks/test_basic_parity.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import numpy as np
import pytest
import torch
from tqdm import tqdm

from pytorch_lightning import seed_everything, Trainer
import tests.base.develop_utils as tutils
Expand All @@ -15,34 +30,33 @@
(ParityModuleMNIST, 0.25), # todo: lower this thr
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_pytorch_parity(tmpdir, cls_model, max_diff):
def test_pytorch_parity(tmpdir, cls_model, max_diff: float, num_epochs: int = 4, num_runs: int = 3):
"""
Verify that the same pytorch and lightning models achieve the same results
"""
num_epochs = 4
num_rums = 3
lightning_outs, pl_times = lightning_loop(cls_model, num_rums, num_epochs)
manual_outs, pt_times = vanilla_loop(cls_model, num_rums, num_epochs)
lightning = lightning_loop(cls_model, num_runs, num_epochs)
vanilla = vanilla_loop(cls_model, num_runs, num_epochs)

# make sure the losses match exactly to 5 decimal places
for pl_out, pt_out in zip(lightning_outs, manual_outs):
for pl_out, pt_out in zip(lightning['losses'], vanilla['losses']):
np.testing.assert_almost_equal(pl_out, pt_out, 5)

# the first run initializes the dataset (download & filter)
tutils.assert_speed_parity_absolute(pl_times[1:], pt_times[1:],
nb_epochs=num_epochs, max_diff=max_diff)
tutils.assert_speed_parity_absolute(
lightning['durations'][1:], vanilla['durations'][1:], nb_epochs=num_epochs, max_diff=max_diff
)


def vanilla_loop(cls_model, num_runs=10, num_epochs=10):
"""
Returns an array with the last loss from each epoch for each run
"""
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
errors = []
times = []
hist_losses = []
hist_durations = []

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.deterministic = True
for i in range(num_runs):
for i in tqdm(range(num_runs), desc=f'Vanilla PT with {cls_model.__name__}'):
time_start = time.perf_counter()

# set seed
Expand Down Expand Up @@ -74,18 +88,21 @@ def vanilla_loop(cls_model, num_runs=10, num_epochs=10):
epoch_losses.append(loss.item())

time_end = time.perf_counter()
times.append(time_end - time_start)
hist_durations.append(time_end - time_start)

errors.append(epoch_losses[-1])
hist_losses.append(epoch_losses[-1])

return errors, times
return {
'losses': hist_losses,
'durations': hist_durations,
}


def lightning_loop(cls_model, num_runs=10, num_epochs=10):
errors = []
times = []
hist_losses = []
hist_durations = []

for i in range(num_runs):
for i in tqdm(range(num_runs), desc=f'PT Lightning with {cls_model.__name__}'):
time_start = time.perf_counter()

# set seed
Expand All @@ -108,9 +125,12 @@ def lightning_loop(cls_model, num_runs=10, num_epochs=10):
trainer.fit(model)

final_loss = trainer.train_loop.running_loss.last().item()
errors.append(final_loss)
hist_losses.append(final_loss)

time_end = time.perf_counter()
times.append(time_end - time_start)
hist_durations.append(time_end - time_start)

return errors, times
return {
'losses': hist_losses,
'durations': hist_durations,
}
14 changes: 14 additions & 0 deletions benchmarks/test_sharded_parity.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import platform
import time
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
14 changes: 14 additions & 0 deletions docs/source/benchmarking.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
Benchmark with vanilla PyTorch
==============================

In this section we set the grounds for comparison between vanilla PyTorch and PT Lightning for the most common scenarios.

Time comparison
---------------

We have set up regular benchmarking against the vanilla PyTorch training loop with an RNN and a simple MNIST classifier as part of our CI.
On average, for the simple MNIST CNN classifier we are only about 0.06s slower per epoch; see the detailed chart below.

.. figure:: _images/benchmarks/figure-parity-times.png
:alt: Speed parity to vanilla PT, created on 2020-12-16
:width: 500
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ PyTorch Lightning Documentation
style_guide
performance
Lightning project template<https://github.com/PyTorchLightning/pytorch-lightning-conference-seed>
benchmarking


.. toctree::
Expand Down
1 change: 1 addition & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ pre-commit>=1.0

cloudpickle>=1.3
nltk>=3.3
pandas # needed in benchmarks
9 changes: 7 additions & 2 deletions tests/base/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,13 @@ class MNIST(Dataset):
TEST_FILE_NAME = 'test.pt'
cache_folder_name = 'complete'

def __init__(self, root: str = PATH_DATASETS, train: bool = True,
normalize: tuple = (0.5, 1.0), download: bool = True):
def __init__(
self,
root: str = PATH_DATASETS,
train: bool = True,
normalize: tuple = (0.5, 1.0),
download: bool = True,
):
super().__init__()
self.root = root
self.train = train # training set or test set
Expand Down

0 comments on commit b16441f

Please sign in to comment.