Skip to content

Commit

Permalink
Fix exception chaining
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta committed Sep 30, 2020
1 parent 5ad5c56 commit 9b19705
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp2_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _resolve_task_idx(self):
self.task_idx = int(os.environ['LOCAL_RANK'])
except Exception as e:
m = 'ddp2 only works in SLURM or via torchelastic with the WORLD_SIZE, LOCAL_RANK, GROUP_RANK flags'
raise MisconfigurationException(m)
raise MisconfigurationException(m) from e

def train(self):
model = self.trainer.model
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,15 @@ def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
"""
try:
return torch.cat(tensors).mean(0)
except (ValueError, TypeError):
except (ValueError, TypeError) as e:
if isinstance(tensors[0], Mapping):
return {k: torch.stack([tensor[k] for tensor in tensors]).mean(0) for k in tensors[0].keys()}
elif isinstance(tensors[0], Sequence) and not isinstance(tensors[0], torch.Tensor):
return tuple([torch.stack(tmp).mean(0) for tmp in zip(*tensors)])
elif isinstance(tensors[0], torch.Tensor):
return torch.stack(tensors).mean(0)
else:
raise TypeError("unknown metric value format to aggregate")
raise TypeError("unknown metric value format to aggregate") from e

@staticmethod
def compute(self, data: Any, output: Any):
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,13 @@ def process_dict_result(self, output, train=False):
if train:
try:
loss = output['loss']
except Exception:
except Exception as e:
if isinstance(output, torch.Tensor):
loss = output
else:
raise RuntimeError(
'No `loss` value in the dictionary returned from `model.training_step()`.'
)
) from e

# when using dp need to reduce the loss
if self.use_dp or self.use_ddp2:
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ class AttributeDict(Dict):
def __getattr__(self, key):
try:
return self[key]
except KeyError:
raise AttributeError(f'Missing attribute "{key}"')
except KeyError as e:
raise AttributeError(f'Missing attribute "{key}"') from e

def __setattr__(self, key, val):
self[key] = val
Expand Down
4 changes: 2 additions & 2 deletions tests/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@

try:
from test_tube import HyperOptArgumentParser
except ImportError:
except ImportError as e:
# TODO: this should be discussed and moved out of this package
raise ImportError('Missing test-tube package.')
raise ImportError('Missing test-tube package.') from e

from pytorch_lightning.core.lightning import LightningModule

Expand Down

0 comments on commit 9b19705

Please sign in to comment.