From 9b19705b6746446ab500850b42a7b24c1729ed94 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 1 Oct 2020 02:51:03 +0900 Subject: [PATCH] Fix exception chaining --- pytorch_lightning/accelerators/ddp2_backend.py | 2 +- pytorch_lightning/metrics/metric.py | 4 ++-- pytorch_lightning/trainer/logging.py | 4 ++-- pytorch_lightning/utilities/parsing.py | 4 ++-- tests/base/models.py | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/accelerators/ddp2_backend.py b/pytorch_lightning/accelerators/ddp2_backend.py index 1b55a983c894e5..4d7b340ee380ca 100644 --- a/pytorch_lightning/accelerators/ddp2_backend.py +++ b/pytorch_lightning/accelerators/ddp2_backend.py @@ -49,7 +49,7 @@ def _resolve_task_idx(self): self.task_idx = int(os.environ['LOCAL_RANK']) except Exception as e: m = 'ddp2 only works in SLURM or via torchelastic with the WORLD_SIZE, LOCAL_RANK, GROUP_RANK flags' - raise MisconfigurationException(m) + raise MisconfigurationException(m) from e def train(self): model = self.trainer.model diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 45c50b084956f6..4ec58185a88610 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -165,7 +165,7 @@ def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor: """ try: return torch.cat(tensors).mean(0) - except (ValueError, TypeError): + except (ValueError, TypeError) as e: if isinstance(tensors[0], Mapping): return {k: torch.stack([tensor[k] for tensor in tensors]).mean(0) for k in tensors[0].keys()} elif isinstance(tensors[0], Sequence) and not isinstance(tensors[0], torch.Tensor): @@ -173,7 +173,7 @@ def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor: elif isinstance(tensors[0], torch.Tensor): return torch.stack(tensors).mean(0) else: - raise TypeError("unknown metric value format to aggregate") + raise TypeError("unknown metric value format to aggregate") from e @staticmethod def compute(self, data: Any, output: Any): diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index ff8aab3743759d..4185e1ac35cc93 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -141,13 +141,13 @@ def process_dict_result(self, output, train=False): if train: try: loss = output['loss'] - except Exception: + except Exception as e: if isinstance(output, torch.Tensor): loss = output else: raise RuntimeError( 'No `loss` value in the dictionary returned from `model.training_step()`.' - ) + ) from e # when using dp need to reduce the loss if self.use_dp or self.use_ddp2: diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index dab1127579b878..82fa6e865f242f 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -158,8 +158,8 @@ class AttributeDict(Dict): def __getattr__(self, key): try: return self[key] - except KeyError: - raise AttributeError(f'Missing attribute "{key}"') + except KeyError as e: + raise AttributeError(f'Missing attribute "{key}"') from e def __setattr__(self, key, val): self[key] = val diff --git a/tests/base/models.py b/tests/base/models.py index 9c319add4aca10..60dcf1777f1b22 100644 --- a/tests/base/models.py +++ b/tests/base/models.py @@ -10,9 +10,9 @@ try: from test_tube import HyperOptArgumentParser -except ImportError: +except ImportError as e: # TODO: this should be discussed and moved out of this package - raise ImportError('Missing test-tube package.') + raise ImportError('Missing test-tube package.') from e from pytorch_lightning.core.lightning import LightningModule