
Commit ff2bab0

ref: (results 1/n) enable tracking original metric when step and epoch are both true (#3685)

* enable tracking original metric when step and epoch are both true
williamFalcon authored Sep 28, 2020
1 parent 931995b commit ff2bab0
Showing 7 changed files with 48 additions and 25 deletions.
2 changes: 2 additions & 0 deletions pytorch_lightning/core/datamodule.py
@@ -152,6 +152,8 @@ def __init__(
self._has_setup_fit = False
self._has_setup_test = False

self.trainer = None

@property
def train_transforms(self):
"""
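The new `self.trainer = None` default gives every `LightningDataModule` a trainer reference slot from construction, so code can probe `dm.trainer` before the datamodule has been attached to a `Trainer` without hitting an `AttributeError`. A minimal sketch of what the default enables; the `fit`-time attachment is an assumption about trainer behavior, and `BoringDataModule` is a hypothetical subclass:

```python
from pytorch_lightning import LightningDataModule

class BoringDataModule(LightningDataModule):  # hypothetical, no hooks overridden
    pass

dm = BoringDataModule()
# safe to inspect before fit(): the attribute now exists and defaults to None
assert dm.trainer is None
# after trainer.fit(model, dm), the trainer is expected to set dm.trainer = trainer
```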
36 changes: 21 additions & 15 deletions pytorch_lightning/core/step_result.py
@@ -161,21 +161,22 @@ def log(
tbptt_pad_token=tbptt_pad_token,
)
self.__setitem__(epoch_name, value)
else:
self.__set_meta(
name,
value,
prog_bar,
logger,
on_step,
on_epoch,
reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
)

# set the value
self.__setitem__(name, value)
# always log the original metric
self.__set_meta(
name,
value,
prog_bar,
logger,
on_step,
on_epoch,
reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
)

# set the value
self.__setitem__(name, value)

def __set_meta(
self,
@@ -378,12 +379,17 @@ def reduce_on_epoch_end(cls, outputs):
def reduce_across_time(cls, time_outputs):
# auto-reduce across time for tbptt
meta = time_outputs[0]['meta']

# in 1.0 the results have 'extra'. Once we deprecate 0.10.0 we may not need this
if 'extra' in time_outputs[0]:
[x.pop('extra', None) for x in time_outputs]

result = cls()
result = recursive_gather(time_outputs, result)
recursive_stack(result)

for k, value in result.items():
if k == 'meta':
if k in ['meta', 'extra']:
continue

# pick the reduce fx
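This is the heart of the change. Previously the original key was written only in the `else` branch, so a metric logged with `on_step=True` and `on_epoch=True` surfaced only as its `step_<name>` and `epoch_<name>` variants; now `__set_meta` and `__setitem__` run unconditionally, so the raw name is tracked alongside both variants. A sketch of the resulting behavior against the internal `Result` API shown in this file (keyword arguments per the `log` signature above):

```python
import torch
from pytorch_lightning.core.step_result import Result

result = Result()
result.log('acc', torch.tensor(0.9), on_step=True, on_epoch=True)

# Result subclasses dict; the renamed variants were stored before this commit
assert 'step_acc' in result and 'epoch_acc' in result
# ...and the original key is now always stored as well
assert 'acc' in result
```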
3 changes: 2 additions & 1 deletion pytorch_lightning/utilities/model_utils.py
@@ -1,8 +1,9 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.datamodule import LightningDataModule
from typing import Union


def is_overridden(method_name: str, model: LightningModule) -> bool:
def is_overridden(method_name: str, model: Union[LightningModule, LightningDataModule]) -> bool:
# if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super
# TODO - refactor this function to accept model_name, instance, parent so it makes more sense
super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule
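The widened `Union` type matters because the trainer asks the same override question of datamodules, for example whether `train_dataloader` was implemented, before deciding where to pull data from; the `isinstance` dispatch to pick the right base class was already in place, as the body above shows. A hedged usage sketch, where `BoringDM` is hypothetical:

```python
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.utilities.model_utils import is_overridden

class BoringDM(LightningDataModule):  # hypothetical datamodule
    def train_dataloader(self):       # overrides the base hook
        return []

dm = BoringDM()
assert is_overridden('train_dataloader', dm)    # user implemented it
assert not is_overridden('val_dataloader', dm)  # still the base default
```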
6 changes: 6 additions & 0 deletions tests/base/deterministic_model.py
@@ -188,6 +188,12 @@ def training_epoch_end_return_for_log_epoch_and_step(self, result):
# only saw 4 batches
assert isinstance(result, TrainResult)

result.step_epoch_log_acc2 = result.step_step_epoch_log_acc2.prod()
result.step_epoch_pbar_acc3 = result.step_step_epoch_pbar_acc3.prod()
result.step_epoch_log_and_pbar_acc1 = result.step_epoch_log_and_pbar_acc1.prod()
result.minimize = result.minimize.mean()
result.checkpoint_on = result.checkpoint_on.mean()

result.step_step_epoch_log_and_pbar_acc1 = result.step_step_epoch_log_and_pbar_acc1.prod()
result.epoch_step_epoch_log_and_pbar_acc1 = result.epoch_step_epoch_log_and_pbar_acc1.prod()
result.step_step_epoch_log_acc2 = result.step_step_epoch_log_acc2.prod()
8 changes: 8 additions & 0 deletions tests/base/model_train_steps.py
@@ -137,6 +137,7 @@ def eval_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=None
result.log(f'{eval_name}_step_metric', loss_val + 1, on_step=True)

setattr(self, f'{eval_name}_step_called', True)

return result

def eval_step_end_full_loop_result_obj_dp(self, result):
@@ -150,10 +151,14 @@ def eval_step_end_full_loop_result_obj_dp(self, result):
reduced = getattr(result, f'epoch_{eval_name}_step_metric').mean()
setattr(result, f'epoch_{eval_name}_step_metric', reduced)

reduced = getattr(result, f'{eval_name}_step_metric').mean()
setattr(result, f'{eval_name}_step_metric', reduced)

result.checkpoint_on = result.checkpoint_on.mean()
result.early_stop_on = result.early_stop_on.mean()
result.log(f'{eval_name}_step_end_metric', torch.tensor(1).type_as(result.checkpoint_on))
setattr(self, f'{eval_name}_step_end_called', True)

return result

def eval_epoch_end_full_loop_result_obj_dp(self, result):
@@ -176,6 +181,9 @@ def eval_epoch_end_full_loop_result_obj_dp(self, result):
reduced = getattr(result, f'{eval_name}_step_end_metric').mean()
setattr(result, f'{eval_name}_step_end_metric', reduced)

reduced = getattr(result, f'{eval_name}_step_metric').mean()
setattr(result, f'{eval_name}_step_metric', reduced)

return result

def training_step__using_metrics(self, batch, batch_idx, optimizer_idx=None):
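These test additions track the production change: because the raw `<name>` key now survives into the DP step-end and epoch-end hooks, those hooks must mean-reduce it exactly like its `epoch_`-prefixed sibling, otherwise later assertions would see an unreduced per-GPU tensor. The reduction is just a mean over the gathered values; a standalone sketch with made-up numbers:

```python
import torch

# hypothetical per-GPU values gathered by DP for one logged metric
result = {'val_step_metric': torch.tensor([0.91, 0.89])}

# same pattern as the test hooks above: replace the gathered tensor with its mean
result['val_step_metric'] = result['val_step_metric'].mean()
assert result['val_step_metric'].allclose(torch.tensor(0.90))
```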
16 changes: 8 additions & 8 deletions tests/trainer/test_trainer_steps_result_return.py
@@ -201,7 +201,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
assert not model.training_step_end_called
assert not model.training_epoch_end_called

assert len(trainer.logger_connector.callback_metrics) == 8
assert len(trainer.logger_connector.callback_metrics) == 11

# make sure correct metrics are logged (one per batch step as requested)
assert len(trainer.dev_debugger.logged_metrics) == (epochs * batches) + epochs
@@ -227,7 +227,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
assert logged_metrics['step_step_epoch_log_and_pbar_acc1'] == expected_val_1
assert logged_metrics['step_step_epoch_log_acc2'] == expected_val_2
assert 'step_epoch_pbar_acc3' not in logged_metrics
assert len(logged_metrics) == 4
assert len(logged_metrics) == 6

# make sure the metrics for the epoch end are actual means (the default reduce fx) of all the batches
epoch_end_metrics = epoch_outputs[-1]
@@ -236,7 +236,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
assert epoch_end_metrics['epoch_step_epoch_log_and_pbar_acc1'] == eval_1
assert epoch_end_metrics['epoch_step_epoch_log_acc2'] == eval_2
assert 'step_epoch_pbar_acc3' not in epoch_end_metrics
assert len(logged_metrics) == 4
assert len(logged_metrics) == 6

# make sure we are using the correct metrics for callbacks
assert trainer.logger_connector.callback_metrics['checkpoint_on'] == 171
@@ -268,7 +268,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
assert logged_metrics['step_step_epoch_log_and_pbar_acc1'] == expected_val_1
assert logged_metrics['step_step_epoch_pbar_acc3'] == expected_val_2
assert 'step_epoch_log_acc2' not in logged_metrics
assert len(logged_metrics) == 3
assert len(logged_metrics) == 5

# make sure the metrics for the epoch end are actual means (the default reduce fx) of all the batches
epoch_end_metrics = epoch_outputs[-1]
@@ -277,7 +277,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
assert epoch_end_metrics['epoch_step_epoch_log_and_pbar_acc1'] == eval_1
assert epoch_end_metrics['epoch_step_epoch_pbar_acc3'] == eval_2
assert 'step_epoch_log_acc2' not in epoch_end_metrics
assert len(logged_metrics) == 3
assert len(logged_metrics) == 5

# -----------------------------------------
# make sure training outputs what is expected
@@ -287,7 +287,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):

out = trainer.train_loop.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 2
assert len(out.batch_log_metrics) == 4

train_step_out = out.training_step_output_for_epoch_end
assert len(train_step_out) == 1
@@ -328,7 +328,7 @@ def test_training_step_epoch_end_result(tmpdir):
)
trainer.fit(model)

assert len(trainer.logger_connector.callback_metrics) == 11
assert len(trainer.logger_connector.callback_metrics) == 17

# make sure correct steps were called
assert model.training_step_called
@@ -369,7 +369,7 @@ def test_training_step_epoch_end_result(tmpdir):

out = trainer.train_loop.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 2
assert len(out.batch_log_metrics) == 4

train_step_out = out.training_step_output_for_epoch_end
assert len(train_step_out) == 1
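The revised counts are straight bookkeeping for the new behavior: every metric logged with both `on_step=True` and `on_epoch=True` now produces three tracked keys (`step_<name>`, `epoch_<name>`, and the original `<name>`) instead of two, so the three such metrics in this test model lift the callback-metric count from 8 to 11, with the per-batch and per-epoch logged-metric counts growing by the number of logger-bound metrics in each path. The arithmetic as a sketch:

```python
# metrics logged with on_step=True and on_epoch=True in this test model
both_flag_metrics = 3  # acc1 (log+pbar), acc2 (log only), acc3 (pbar only)

keys_before = 2 * both_flag_metrics  # step_<name> and epoch_<name> variants
keys_after = 3 * both_flag_metrics   # plus the original <name> per metric

# matches the 8 -> 11 jump in callback_metrics asserted above
assert keys_after - keys_before == both_flag_metrics == 11 - 8
```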
2 changes: 1 addition & 1 deletion tests/trainer/test_validation_steps_result_return.py
@@ -278,7 +278,7 @@ def test_val_step_epoch_step_metrics(tmpdir):
)
trainer.fit(model)

assert len(trainer.logger_connector.callback_metrics) == 7
assert len(trainer.logger_connector.callback_metrics) == 11

# make sure correct steps were called
assert model.validation_step_called
