Adds tests to make sure logging doesn't happen multiple times (#3899)
* Makes sure logging doesn't ever happen from non-root zero
* Added bug report model
* Fixed local model
1 parent e4a56fa · commit 2cf17a3 · Showing 2 changed files with 61 additions and 1 deletion.
@@ -0,0 +1,59 @@
import pytest
import torch
from tests.base import BoringModel
import platform
from distutils.version import LooseVersion
from pytorch_lightning import Trainer, Callback
from unittest import mock


class TestModel(BoringModel):

    def on_pretrain_routine_end(self) -> None:
        with mock.patch('pytorch_lightning.loggers.base.LightningLoggerBase.agg_and_log_metrics') as m:
            self.trainer.logger_connector.log_metrics({'a': 2}, {})
            logged_times = m.call_count
            expected = 1 if self.global_rank == 0 else 0
            assert logged_times == expected, 'actual logger called from non-global zero'


@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif((platform.system() == "Darwin" and
                     LooseVersion(torch.__version__) < LooseVersion("1.3.0")),
                    reason="Distributed training is not supported on MacOS before Torch 1.3.0")
def test_global_zero_only_logging_ddp_cpu(tmpdir):
    """
    Makes sure logging only happens from global rank zero
    """
    model = TestModel()
    model.training_epoch_end = None
    trainer = Trainer(
        distributed_backend='ddp_cpu',
        num_processes=2,
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=1,
        weights_summary=None,
    )
    trainer.fit(model)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_global_zero_only_logging_ddp_spawn(tmpdir):
    """
    Makes sure logging only happens from global rank zero
    """
    model = TestModel()
    model.training_epoch_end = None
    trainer = Trainer(
        distributed_backend='ddp_spawn',
        gpus=2,
        default_root_dir=tmpdir,
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=1,
        weights_summary=None,
    )
    trainer.fit(model)
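
For context (not part of this commit's diff): the behaviour these tests pin down is the standard rank-zero guard, where only the process with global rank 0 is allowed to reach the logger. A minimal sketch of that pattern, assuming the rank_zero_only decorator exported from pytorch_lightning.utilities; the log_once helper itself is hypothetical:

from pytorch_lightning.utilities import rank_zero_only

@rank_zero_only
def log_once(metrics: dict) -> None:
    # Hypothetical helper: the decorator turns this call into a no-op on
    # every process except global rank 0, so metrics are emitted exactly once.
    print(metrics)

TestModel.on_pretrain_routine_end asserts the same property from the other side: agg_and_log_metrics is reached exactly once on global rank 0 and never on any other rank.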