add dataloader_idx logic directly in trainer
tchaton committed Oct 19, 2020
1 parent 4a946d0 commit 9bd19d3
Showing 6 changed files with 36 additions and 28 deletions.
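
At a glance, the commit moves the per-dataloader key handling out of the logger connector (the now-removed __rename_keys_by_dataloader_idx) and into Result.log: when more than one evaluation dataloader is used, the module's _current_dataloader_idx is appended to every logged key, and callback_metrics are filtered per dataloader when building eval_loop_results. A hypothetical illustration of the resulting key scheme (keys and values are invented for illustration, not taken from the repository):

# With two test dataloaders, a metric logged as "test_loss" from a batch-level
# hook ends up stored once per dataloader, with the index appended to the key.
expected_keys = [
    "test_loss/dataloader_idx_0",
    "test_loss/dataloader_idx_1",
]
print(expected_keys)
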
7 changes: 6 additions & 1 deletion pytorch_lightning/core/lightning.py
@@ -251,6 +251,10 @@ def log(
self.trainer.callback_connector.validate_callback_logging_arguments(self._current_hook_fx_name,
on_step=on_step,
on_epoch=on_epoch)

# make sure user doesn't introduce logic for multi-dataloaders
if "/dataloader_idx_" in name:
raise MisconfigurationException(f"Logged key: {name} should not contain information about dataloader_idx.")

self._results.log(
name,
@@ -265,7 +269,8 @@ def log(
enable_graph,
sync_dist,
sync_dist_op,
sync_dist_group
sync_dist_group,
self._current_dataloader_idx,
)

def log_dict(
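
A minimal standalone sketch of the guard added in this hunk, assuming only that user-provided keys must not already carry the suffix the trainer now appends itself; MisconfigurationException is stubbed locally here rather than imported from pytorch_lightning:

class MisconfigurationException(Exception):
    """Local stand-in for pytorch_lightning's MisconfigurationException."""


def check_user_key(name: str) -> None:
    # mirrors the check added to LightningModule.log in this commit
    if "/dataloader_idx_" in name:
        raise MisconfigurationException(
            f"Logged key: {name} should not contain information about dataloader_idx."
        )


check_user_key("val_acc")  # accepted: the suffix is appended internally later
try:
    check_user_key("val_acc/dataloader_idx_0")  # now rejected
except MisconfigurationException as err:
    print(err)
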
12 changes: 10 additions & 2 deletions pytorch_lightning/core/step_result.py
@@ -109,6 +109,11 @@ def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], add
m += f' {additional_err}'
assert x.grad_fn is not None, m

def add_dl_idx(self, name, dl_idx):
if dl_idx is not None:
name += f"/dataloader_idx_{dl_idx}"
return name

def log(
self,
name: str,
@@ -124,6 +129,7 @@ def log(
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_group: Optional[Any] = None,
current_dataloader_idx: Optional[int] = None,
):
# no metrics should be logged with graphs
if not enable_graph and isinstance(value, torch.Tensor):
@@ -143,7 +149,7 @@ def log(
was_forked = True

# set step version
step_name = f'{name}_step'
step_name = self.add_dl_idx(f'{name}_step', current_dataloader_idx)
self.__set_meta(
step_name,
value,
@@ -159,7 +165,7 @@ def log(
self.__setitem__(step_name, value)

# set epoch version
epoch_name = f'{name}_epoch'
epoch_name = self.add_dl_idx(f'{name}_epoch', current_dataloader_idx)
self.__set_meta(
epoch_name,
value,
@@ -174,6 +180,8 @@ def log(
)
self.__setitem__(epoch_name, value)

name = self.add_dl_idx(name, current_dataloader_idx)

# always log the original metric
self.__set_meta(
name,
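
The suffixing itself is a one-liner; below is the add_dl_idx helper from this hunk, pulled out of the Result class purely for illustration (same body, written as a module-level function):

from typing import Optional


def add_dl_idx(name: str, dl_idx: Optional[int]) -> str:
    # single dataloader (dl_idx is None): leave the key untouched;
    # multiple dataloaders: append the index to the key
    if dl_idx is not None:
        name += f"/dataloader_idx_{dl_idx}"
    return name


print(add_dl_idx("valid_loss_step", None))  # valid_loss_step
print(add_dl_idx("valid_loss_step", 1))     # valid_loss_step/dataloader_idx_1
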
25 changes: 12 additions & 13 deletions pytorch_lightning/trainer/connectors/logger_connector.py
@@ -184,11 +184,6 @@ def _track_callback_metrics_1_0(self, logs, metrics_to_log=[], reduce_on_epoch=F
pbar_metrics = reduced_epoch_metrics.get_epoch_pbar_metrics()
forked_metrics = reduced_epoch_metrics.get_forked_metrics()

# make the keys 'k/dl'
logger_metrics = self.__rename_keys_by_dataloader_idx(logger_metrics, dl_idx, num_loaders)
pbar_metrics = self.__rename_keys_by_dataloader_idx(pbar_metrics, dl_idx, num_loaders)
forked_metrics = self.__rename_keys_by_dataloader_idx(forked_metrics, dl_idx, num_loaders)

# track the metrics
self.logged_metrics.update(logger_metrics)
self.add_progress_bar_metrics(pbar_metrics)
@@ -201,24 +196,28 @@ def _track_callback_metrics_1_0(self, logs, metrics_to_log=[], reduce_on_epoch=F
self.callback_metrics.update(forked_metrics)

# track the final results for the dataloader
self.eval_loop_results.append(deepcopy(self.callback_metrics))
self.add_to_eval_loop_results(dl_idx)

# actually log
if len(logger_metrics) > 0:
metrics_to_log.append(logger_metrics)

def add_to_eval_loop_results(self, dl_idx):
callback_metrics = deepcopy(self.callback_metrics)
for key in list(callback_metrics.keys()):
if "/dataloader_idx_" in key:
dl_idx_in_key = int(key.split("_")[-1])
# remove dl_idx from self.callback_metrics not belonging to this dataset.
if dl_idx_in_key != dl_idx:
del callback_metrics[key]
self.eval_loop_results.append(callback_metrics)


def log_on_evaluation_end(self, metrics_to_log):
metrics_to_log = dict(ChainMap(*metrics_to_log))
if len(metrics_to_log) > 0:
self.log_metrics(metrics_to_log, {}, step=self.trainer.global_step)

def __rename_keys_by_dataloader_idx(self, metrics, dataloader_idx, num_loaders):
if num_loaders == 1:
return metrics

result = {f'{k}/dataloader_idx_{dataloader_idx}': v for k, v in metrics.items()}
return result

def _track_callback_metrics(self, eval_results, using_eval_result):
if (
len(eval_results) > 0 and
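
add_to_eval_loop_results keeps, for each dataloader, the metrics that carry no suffix plus the ones whose suffix matches that dataloader. A standalone sketch of the same filtering, with made-up metric names and values:

from copy import deepcopy


def filter_for_dataloader(callback_metrics: dict, dl_idx: int) -> dict:
    result = deepcopy(callback_metrics)
    for key in list(result.keys()):
        if "/dataloader_idx_" in key:
            # same parsing as the diff: the index is the last "_"-separated token
            if int(key.split("_")[-1]) != dl_idx:
                del result[key]
    return result


metrics = {"test_acc/dataloader_idx_0": 0.9, "test_acc/dataloader_idx_1": 0.8, "hp_metric": 0.5}
print(filter_for_dataloader(metrics, 0))  # {'test_acc/dataloader_idx_0': 0.9, 'hp_metric': 0.5}
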
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
@@ -347,7 +347,7 @@ def _unset_dataloader_idx(self):
def set_dataloader_idx(self, dl_idx):
# reset the result of the PL module
model = self.trainer.get_model()
model._current_dataloader_idx = dl_idx if self.num_dataloaders > 1 else None
model._current_dataloader_idx = dl_idx if self.num_dataloaders > 1 else None

def on_evaluation_batch_start(self, *args, **kwargs):
# reset the result of the PL module
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -560,7 +560,7 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):
dataloader = self.accelerator_backend.process_dataloader(dataloader)
dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

# set dataloader idx inside model, so we can properly log
# set dataloader idx in pl_model, so we can handle multi dataloaders logging.
self.evaluation_loop.set_dataloader_idx(dataloader_idx)

for batch_idx, batch in enumerate(dataloader):
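
The ordering the updated comment refers to, sketched outside the Trainer with a stub model (not library code): the dataloader index is pinned on the module once per dataloader, before its batch loop, so every log call inside that loop picks up the same suffix, and it stays None when there is only one dataloader.

class ModelStub:
    _current_dataloader_idx = None


def run_evaluation_sketch(model, dataloaders):
    num_dataloaders = len(dataloaders)
    for dataloader_idx, dataloader in enumerate(dataloaders):
        # counterpart of self.evaluation_loop.set_dataloader_idx(dataloader_idx)
        model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None
        for batch_idx, batch in enumerate(dataloader):
            pass  # evaluation_step(...) would run and log here


model = ModelStub()
run_evaluation_sketch(model, [[1, 2], [3, 4]])
print(model._current_dataloader_idx)  # 1: the index of the last dataloader visited
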
16 changes: 6 additions & 10 deletions tests/trainer/logging/test_eval_loop_logging_1_0.py
@@ -589,27 +589,21 @@ def on_epoch_start(self, trainer, pl_module):
def on_test_epoch_start(self, trainer, pl_module):
self.make_logging(pl_module, 'on_test_epoch_start', 3, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices)

def on_batch_start(self, trainer, pl_module):
self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices)

def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
self.make_logging(pl_module, 'on_test_batch_start', 5, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices)

def on_batch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices)
self.make_logging(pl_module, 'on_test_batch_start', 4, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices)

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self.make_logging(pl_module, 'on_test_batch_end', 7, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices)
self.make_logging(pl_module, 'on_test_batch_end', 5, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices)

# used to make sure aggregation works fine.
# we should obtain func[value * c for c in range(1, max_epochs * limit_test_batches)]) with func = np.mean if on_epoch else func = np.max
self.count += 1

def on_epoch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices)
self.make_logging(pl_module, 'on_epoch_end', 6, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices)

def on_test_epoch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_test_epoch_end', 9, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices)
self.make_logging(pl_module, 'on_test_epoch_end', 7, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices)

max_epochs = 1
num_dataloaders = 2
@@ -644,6 +638,8 @@ def test_dataloader(self):
trainer.fit(model)
trainer.test()


# Make sure the func_name exists within callback_metrics. If not, we missed some
callback_metrics_keys = [*trainer.callback_metrics.keys()]
for func_name in test_callback.callback_funcs_called.keys():
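
For context, the assertion above walks trainer.callback_metrics and checks that every logged hook name shows up there. With two test dataloaders, batch-level hook metrics should appear once per dataloader; the snippet below is a hypothetical illustration of that shape (keys and values invented, no Trainer run):

callback_metrics = {
    "on_test_batch_end/dataloader_idx_0": 0.1,
    "on_test_batch_end/dataloader_idx_1": 0.2,
    "on_test_epoch_end": 0.3,
}

for func_name in ["on_test_batch_end", "on_test_epoch_end"]:
    # same spirit as the test: the hook name must appear in some callback_metrics key
    assert any(func_name in key for key in callback_metrics), f"{func_name} is missing"
print("all hook names found in callback_metrics")
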
