From 49e143d549d294eb3c4221ae6274c15622346beb Mon Sep 17 00:00:00 2001
From: tchaton
Date: Thu, 5 Nov 2020 11:01:59 +0000
Subject: [PATCH 1/9] wip

---
 .../test_trainer_steps_scalar_return.py | 40 +++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py
index b5eae913ca428..10a7d45a2a42a 100644
--- a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py
+++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py
@@ -14,10 +14,13 @@
 """
 Tests to ensure that the training loop works with a scalar
 """
+import os
 import torch
+import pytest
 
 from pytorch_lightning import Trainer
 from tests.base.deterministic_model import DeterministicModel
+from tests.base import BoringModel
 
 
 def test_training_step_scalar(tmpdir):
@@ -189,3 +192,40 @@ def test_train_step_epoch_end_scalar(tmpdir):
         opt_closure_result = trainer.train_loop.training_step_and_backward(
             batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
         assert opt_closure_result['loss'].item() == 171
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_dpp_reduce_mean_(tmpdir):
+    class ExtentedModel(BoringModel):
+
+        logged = []
+
+        def training_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            return {"loss": torch.ones(1).to(self.device)}
+
+    model = ExtentedModel()
+    model.training_step_end = None
+    model.training_epoch_end = None
+
+    debug = False
+    if debug:
+        distributed_backend = None
+        gpus = 0
+    else:
+        distributed_backend = "ddp"
+        gpus = 2
+
+    trainer = Trainer(
+        max_epochs=1,
+        default_root_dir=os.getcwd(),
+        limit_train_batches=10,
+        limit_test_batches=2,
+        limit_val_batches=2,
+        distributed_backend=distributed_backend,
+        gpus=gpus,
+        precision=32,
+    )
+
+    trainer.fit(model)

From 8a43f356422a5dbd688856bd1d64b7f1cb0b1d31 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Thu, 5 Nov 2020 15:44:21 +0000
Subject: [PATCH 2/9] update

---
 .../test_trainer_steps_scalar_return.py | 14 +++----------
 1 file changed, 3 insertions(+), 11 deletions(-)

diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py
index 10a7d45a2a42a..503a8f28485e0 100644
--- a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py
+++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py
@@ -195,7 +195,7 @@ def test_train_step_epoch_end_scalar(tmpdir):
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-def test_dpp_reduce_mean_(tmpdir):
+def test_dpp_reduce_mean_pbar(tmpdir):
     class ExtentedModel(BoringModel):
 
         logged = []
@@ -209,22 +209,14 @@ def training_step(self, batch, batch_idx):
     model.training_step_end = None
     model.training_epoch_end = None
 
-    debug = False
-    if debug:
-        distributed_backend = None
-        gpus = 0
-    else:
-        distributed_backend = "ddp"
-        gpus = 2
-
     trainer = Trainer(
         max_epochs=1,
         default_root_dir=os.getcwd(),
         limit_train_batches=10,
        limit_test_batches=2,
         limit_val_batches=2,
-        distributed_backend=distributed_backend,
-        gpus=gpus,
+        distributed_backend="ddp",
+        gpus=2,
         precision=32,
     )
From e566f81b6d3648f8ce768eec7d73f2eeab2bdd39 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 9 Nov 2020 15:05:14 +0000
Subject: [PATCH 3/9] normalize loss

---
 .../test_trainer_steps_scalar_return.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py
index 9e4916dc86cd0..624f2e29cd1ec 100644
--- a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py
+++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py
@@ -203,7 +203,8 @@ class ExtentedModel(BoringModel):
         def training_step(self, batch, batch_idx):
             output = self.layer(batch)
             loss = self.loss(batch, output)
-            return {"loss": torch.ones(1).to(self.device)}
+            loss /= loss.clone().detach()
+            return loss
 
     model = ExtentedModel()
     model.training_step_end = None

From ddcbbc2b1a2c7c52953de400997c941fa7f4b862 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Wed, 11 Nov 2020 15:24:31 +0000
Subject: [PATCH 4/9] update test

---
 .../test_trainer_steps_scalar_return.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py
index 624f2e29cd1ec..f3e8717dd9344 100644
--- a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py
+++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py
@@ -204,6 +204,7 @@ def training_step(self, batch, batch_idx):
             output = self.layer(batch)
             loss = self.loss(batch, output)
             loss /= loss.clone().detach()
+            self.log('self_log', loss, prog_bar=True, sync_dist=True)
             return loss
 
     model = ExtentedModel()
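
A side note on the `sync_dist=True` flag that patch 4 switches on: when a DDP process group is running, it reduces the logged value across processes (a mean by default) before the value is recorded. The snippet below is a minimal sketch of what such a mean reduction boils down to with raw torch.distributed; the `mean_reduce` helper is illustrative only and is not Lightning's internal `sync_ddp_if_available` implementation.

    import torch
    import torch.distributed as dist

    def mean_reduce(value: torch.Tensor) -> torch.Tensor:
        # Only reduce when a distributed process group has actually been initialized.
        if dist.is_available() and dist.is_initialized():
            value = value.detach().clone()  # drop autograd history and keep the caller's tensor intact
            dist.all_reduce(value, op=dist.ReduceOp.SUM)
            value = value / dist.get_world_size()  # SUM divided by world size -> the mean across ranks
        return value
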
multi-GPU machine") def test_dpp_reduce_mean_pbar(tmpdir): - class ExtentedModel(BoringModel): - - logged = [] - def training_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - loss /= loss.clone().detach() - self.log('self_log', loss, prog_bar=True, sync_dist=True) - return loss + os.environ['PL_DEV_DEBUG'] = '1' - model = ExtentedModel() + model = DPPReduceMeanPbarModel() model.training_step_end = None model.training_epoch_end = None @@ -219,7 +223,13 @@ def training_step(self, batch, batch_idx): limit_val_batches=2, distributed_backend="ddp", gpus=2, - precision=32, - ) + precision=32) trainer.fit(model) + + is_in = False + for pbar_metrics in trainer.dev_debugger.pbar_added_metrics: + if 'loss_2' in pbar_metrics: + is_in = True + assert pbar_metrics["loss_2"].item() == 1 + assert is_in is True From 8c66647af10b522bc1c75dbed886944b90453026 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 11 Nov 2020 16:28:13 +0000 Subject: [PATCH 6/9] update test and add TODO --- .../test_trainer_steps_scalar_return.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py index 80d7a616c07c4..b85646e1c290f 100644 --- a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py +++ b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py @@ -208,28 +208,33 @@ def training_step(self, batch, batch_idx): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_dpp_reduce_mean_pbar(tmpdir): - os.environ['PL_DEV_DEBUG'] = '1' model = DPPReduceMeanPbarModel() model.training_step_end = None model.training_epoch_end = None + distributed_backend = "ddp_spawn" + trainer = Trainer( max_epochs=1, default_root_dir=os.getcwd(), limit_train_batches=10, limit_test_batches=2, limit_val_batches=2, - distributed_backend="ddp", + distributed_backend=distributed_backend, gpus=2, precision=32) trainer.fit(model) + # TODO: Move this test to DDP. 
From b864a584c613d301be9665581dbbe7dd12adf4d9 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Wed, 11 Nov 2020 17:06:44 +0000
Subject: [PATCH 7/9] make sure it can be sync

---
 pytorch_lightning/core/step_result.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index b76924f049eab..9366a33c4146b 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -129,7 +129,8 @@ def log(
     ):
         # no metrics should be logged with graphs
         if not enable_graph and isinstance(value, torch.Tensor):
-            if sync_dist:
+            is_dist = torch.distributed.is_available() and torch.distributed.is_initialized()
+            if sync_dist and is_dist:
                 # make sure we don't make an in-place operation
                 value = value.clone().detach()
             else:
                 value = value.detach()

From ca3709b7bcb894f52490ff96d1ad07cf06781843 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Wed, 11 Nov 2020 17:36:31 +0000
Subject: [PATCH 8/9] add TODO

---
 pytorch_lightning/core/step_result.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index 9366a33c4146b..a39ba0aea198a 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -129,10 +129,12 @@ def log(
     ):
         # no metrics should be logged with graphs
         if not enable_graph and isinstance(value, torch.Tensor):
+            # TODO: Find a way to make the reduction only once, so we don't need to clone the
             is_dist = torch.distributed.is_available() and torch.distributed.is_initialized()
             if sync_dist and is_dist:
                 # make sure we don't make an in-place operation
-                value = value.clone().detach()
+                # .detach().clone() and not .clone().detach() which will copy the entire graph.
+                value = value.detach().clone()
             else:
                 value = value.detach()
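
The ordering discussed in the comment added by patch 8 can be checked in a few lines of plain PyTorch, independent of Lightning: `detach()` returns a view on the same storage, so an in-place edit on it leaks back into the original tensor, whereas `detach().clone()` copies the data without recording an autograd node; `clone().detach()` instead performs a differentiable clone first. A small sketch with made-up variable names:

    import torch

    x = torch.ones(1, requires_grad=True) * 2  # stand-in for a logged loss that still carries a graph

    alias = x.detach()          # same storage as x, no copy made
    alias.add_(1.0)             # the in-place edit leaks back into x
    print(x.item())             # 3.0 -- the original value changed

    safe = x.detach().clone()   # detach first, then copy the data; no autograd node is created
    safe.add_(10.0)             # the original is no longer affected
    print(x.item())             # still 3.0

    graphy = x.clone()          # clone() on a graph-attached tensor records a CloneBackward node first
    print(graphy.grad_fn)       # CloneBackward -- this copy still participates in the autograd graph
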
From c6483f9395943c05944c64eb1d765019d3923bb7 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Thu, 12 Nov 2020 08:59:32 +0000
Subject: [PATCH 9/9] update sol

---
 pytorch_lightning/core/step_result.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index a39ba0aea198a..12f1b57f836f2 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -129,18 +129,14 @@ def log(
     ):
         # no metrics should be logged with graphs
         if not enable_graph and isinstance(value, torch.Tensor):
-            # TODO: Find a way to make the reduction only once, so we don't need to clone the
-            is_dist = torch.distributed.is_available() and torch.distributed.is_initialized()
-            if sync_dist and is_dist:
-                # make sure we don't make an in-place operation
-                # .detach().clone() and not .clone().detach() which will copy the entire graph.
-                value = value.detach().clone()
-            else:
-                value = value.detach()
+            value = value.detach()
 
         # sync across workers when using distributed training
         sync_fn = sync_fn or sync_ddp_if_available
         if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
+            is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
+            # TODO: Find a way to make the reduction only once, so we don't need to clone.
+            value = value.clone() if is_dist_initialized else value
             value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)
 
         if 'meta' not in self:
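
Taken together, patch 9 leaves the `log()` method in step_result.py detaching the logged tensor up front and cloning it only when a process group is actually initialized, so the in-place reduction cannot mutate the tensor the caller still holds. The condensed sketch below paraphrases that flow; the argument names and the inlined reduction are stand-ins for the real signature and its `sync_ddp_if_available` helper, and the sketch only handles tensors, whereas the real method also accepts plain numbers plus a custom sync_fn, group and reduce op.

    import torch
    import torch.distributed as dist

    def log_value(value, sync_dist=False, enable_graph=False):
        # no metrics should be logged with graphs
        if not enable_graph and isinstance(value, torch.Tensor):
            value = value.detach()

        # sync across workers when using distributed training
        if sync_dist and isinstance(value, torch.Tensor):
            if dist.is_available() and dist.is_initialized():
                # clone only when a reduction will actually run, so the in-place
                # all_reduce below cannot touch the tensor the caller still holds
                value = value.clone()
                dist.all_reduce(value, op=dist.ReduceOp.SUM)
                value = value / dist.get_world_size()
        return value
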