attempting to remove some speed issues #1482

Merged on Apr 15, 2020 (7 commits)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -63,6 +63,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added speed parity tests (max 1 sec difference per epoch)([#1482](https://github.com/PyTorchLightning/pytorch-lightning/pull/1482))
- Added same step loggers' metrics aggregation ([#1278](https://github.com/PyTorchLightning/pytorch-lightning/pull/1278))
- Added parity test between a vanilla MNIST model and lightning model ([#1284](https://github.com/PyTorchLightning/pytorch-lightning/pull/1284))
- Added parity test between a vanilla RNN model and lightning model ([#1351](https://github.com/PyTorchLightning/pytorch-lightning/pull/1351))
3 changes: 3 additions & 0 deletions benchmarks/test_rnn_parity.py
@@ -6,6 +6,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import tests.base.utils as tutils

from pytorch_lightning import Trainer, LightningModule

@@ -64,6 +65,8 @@ def test_pytorch_parity(tmpdir):
    for pl_out, pt_out in zip(lightning_outs, manual_outs):
        np.testing.assert_almost_equal(pl_out, pt_out, 8)

    tutils.assert_speed_parity(pl_times, pt_times, num_epochs)


def set_seed(seed):
    np.random.seed(seed)
3 changes: 3 additions & 0 deletions benchmarks/test_trainer_parity.py
@@ -8,6 +8,7 @@
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import tests.base.utils as tutils

from pytorch_lightning import Trainer, LightningModule
from tests.base.datasets import TestingMNIST
@@ -64,6 +65,8 @@ def test_pytorch_parity(tmpdir):
    for pl_out, pt_out in zip(lightning_outs, manual_outs):
        np.testing.assert_almost_equal(pl_out, pt_out, 5)

    tutils.assert_speed_parity(pl_times, pt_times, num_epochs)


def set_seed(seed):
    np.random.seed(seed)
4 changes: 0 additions & 4 deletions pytorch_lightning/trainer/logging.py
@@ -112,10 +112,6 @@ def process_output(self, output, train=False):
num_gpus = self.num_gpus
callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus)

for k, v in callback_metrics.items():
    if isinstance(v, torch.Tensor):
        callback_metrics[k] = v.item()

Review comments on the removed loop above:

Member: Why remove this?

Contributor Author: It's done twice. This is the first time, but it's done again at the end of the method.

# ---------------
# EXTRACT PROGRESS BAR KEYS
# ---------------
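The author's point in the review thread, that a second conversion pass has nothing left to do, can be illustrated with a small sketch. `FakeTensor`, `to_scalars`, and the sample metrics are hypothetical stand-ins so the example runs without PyTorch; the real code checks `isinstance(v, torch.Tensor)`. The sketch only shows that the conversion is idempotent; it does not show what the code between the two passes in `process_output` expects.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor, exposing only .item()."""

    def __init__(self, value):
        self.value = value

    def item(self):
        return self.value


def to_scalars(metrics, counter):
    """Replace tensor-like values with Python scalars and count conversions."""
    for k, v in metrics.items():
        if isinstance(v, FakeTensor):
            metrics[k] = v.item()
            counter[0] += 1
    return metrics


metrics = {"loss": FakeTensor(0.25), "acc": 0.9}
calls = [0]

to_scalars(metrics, calls)           # first pass converts the one tensor-like value
first_pass = calls[0]

to_scalars(metrics, calls)           # second pass finds nothing left to convert
second_pass = calls[0] - first_pass
```

With these sample values the first pass performs one conversion and the second performs none, which is why keeping only the pass at the end of the method is enough to produce scalar outputs.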
13 changes: 13 additions & 0 deletions tests/base/utils.py
@@ -20,6 +20,19 @@
ROOT_PATH = os.path.abspath(os.path.dirname(__file__))


def assert_speed_parity(pl_times, pt_times, num_epochs):

    # assert speeds
    max_diff_per_epoch = 0.9
    pl_times = np.asarray(pl_times)
    pt_times = np.asarray(pt_times)
    diffs = pl_times - pt_times
    diffs = diffs / num_epochs

    assert np.alltrue(diffs < max_diff_per_epoch), \
        f"lightning was slower than PT (threshold {max_diff_per_epoch})"


def run_model_test_no_loggers(trainer_options, model, min_acc=0.50):
    # save_dir = trainer_options['default_root_dir']

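To make the threshold in `assert_speed_parity` concrete, here is a minimal, dependency-free sketch of the same arithmetic using plain Python lists in place of NumPy arrays. The names `speed_diffs_per_epoch` and `check_speed_parity`, the sample timings, and the assumption that each entry in `pl_times` and `pt_times` is the total wall-clock time of one run are illustrative and not part of the PR.

```python
def speed_diffs_per_epoch(pl_times, pt_times, num_epochs):
    """Per-epoch slowdown of Lightning relative to plain PyTorch, one value per run."""
    return [(pl - pt) / num_epochs for pl, pt in zip(pl_times, pt_times)]


def check_speed_parity(pl_times, pt_times, num_epochs, max_diff_per_epoch=0.9):
    """Same comparison as the assertion in the diff, returning a bool instead of raising."""
    diffs = speed_diffs_per_epoch(pl_times, pt_times, num_epochs)
    return all(d < max_diff_per_epoch for d in diffs)


# Hypothetical timings in seconds for three runs of 2 epochs each.
pl_times = [10.4, 10.1, 10.6]
pt_times = [9.8, 9.9, 9.7]
within = check_speed_parity(pl_times, pt_times, num_epochs=2)

# A run that is 3 s slower over 2 epochs (1.5 s per epoch) exceeds the threshold.
too_slow = check_speed_parity([13.0], [10.0], num_epochs=2)
```

The check is one-sided: a negative difference, where Lightning is faster, always passes. Note also that the threshold in the code is 0.9 seconds per epoch, slightly stricter than the "max 1 sec" wording in the CHANGELOG entry.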