diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index acf860b706eb58..363d79d39c95cc 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -109,7 +109,10 @@ def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], add m += f' {additional_err}' assert x.grad_fn is not None, m - def add_dl_idx(self, name, dl_idx): + def add_dl_idx(self, name: str, dl_idx: Union[None, int]) -> str: + """ + This function add dl_idx logic to logged key automatically if we have multiple dataloders + """ if dl_idx is not None: name += f"/dataloader_idx_{dl_idx}" return name @@ -149,6 +152,7 @@ def log( was_forked = True # set step version + # add possibly dataloader_idx step_name = self.add_dl_idx(f'{name}_step', current_dataloader_idx) self.__set_meta( step_name, @@ -165,6 +169,7 @@ def log( self.__setitem__(step_name, value) # set epoch version + # add possibly dataloader_idx epoch_name = self.add_dl_idx(f'{name}_epoch', current_dataloader_idx) self.__set_meta( epoch_name, @@ -180,6 +185,7 @@ def log( ) self.__setitem__(epoch_name, value) + # add possibly dataloader_idx name = self.add_dl_idx(name, current_dataloader_idx) # always log the original metric diff --git a/tests/trainer/logging/test_eval_loop_logging_1_0.py b/tests/trainer/logging/test_eval_loop_logging_1_0.py index d70dfc00440e63..69e804177e8cbd 100644 --- a/tests/trainer/logging/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging/test_eval_loop_logging_1_0.py @@ -479,9 +479,6 @@ def validation_step(self, batch, batch_idx): loss = self.loss(batch, output) self.log('val_loss', loss) - def val_dataloader(self): - return [torch.utils.data.DataLoader(RandomDataset(32, 64)), torch.utils.data.DataLoader(RandomDataset(32, 64))] - max_epochs = 1 model = TestModel() model.validation_epoch_end = None @@ -559,26 +556,27 @@ class TestCallback(callbacks.Callback): funcs_attr = {} def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[]): - original_func_name = func_name + original_func_name = func_name[:] for idx, t in enumerate(list(itertools.product(*[on_steps, on_epochs, prob_bars]))): # run logging - func_name = original_func_name + func_name = original_func_name[:] on_step, on_epoch, prog_bar = t custom_func_name = f"{func_idx}_{idx}_{func_name}" pl_module.log(custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar) - + + num_dl_ext = '' if pl_module._current_dataloader_idx is not None: dl_idx = pl_module._current_dataloader_idx - custom_func_name += f"/dataloader_idx_{dl_idx}" - func_name += f"/dataloader_idx_{dl_idx}" + num_dl_ext = f"/dataloader_idx_{dl_idx}" + func_name += num_dl_ext # catch information for verification self.callback_funcs_called[func_name].append([self.count * func_idx]) - self.funcs_attr[custom_func_name] = {"on_step":on_step, "on_epoch":on_epoch, "prog_bar":prog_bar, "is_created":False, "func_name":func_name} + self.funcs_attr[custom_func_name + num_dl_ext] = {"on_step":on_step, "on_epoch":on_epoch, "prog_bar":prog_bar, "is_created":False, "func_name":func_name} if on_step and on_epoch: - self.funcs_attr[f"{custom_func_name}_step"] = {"on_step":True, "on_epoch":False, "prog_bar":prog_bar, "is_created":True, "func_name":func_name} - self.funcs_attr[f"{custom_func_name}_epoch"] = {"on_step":False, "on_epoch":True, "prog_bar":prog_bar, "is_created":True, "func_name":func_name} + self.funcs_attr[f"{custom_func_name}_step" + num_dl_ext] = {"on_step":True, "on_epoch":False, "prog_bar":prog_bar, "is_created":True, "func_name":func_name} + self.funcs_attr[f"{custom_func_name}_epoch" + num_dl_ext] = {"on_step":False, "on_epoch":True, "prog_bar":prog_bar, "is_created":True, "func_name":func_name} def on_test_start(self, trainer, pl_module): self.make_logging(pl_module, 'on_test_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) @@ -638,10 +636,9 @@ def test_dataloader(self): trainer.fit(model) trainer.test() - - breakpoint() # Make sure the func_name exists within callback_metrics. If not, we missed some callback_metrics_keys = [*trainer.callback_metrics.keys()] + for func_name in test_callback.callback_funcs_called.keys(): is_in = False for callback_metrics_key in callback_metrics_keys: @@ -670,22 +667,19 @@ def get_expected_output(func_attr, original_values): for func_name, output_value in trainer.callback_metrics.items(): if torch.is_tensor(output_value): output_value = output_value.item() - # get creation attr - if func_name in test_callback.funcs_attr: - func_attr = test_callback.funcs_attr[func_name] - # retrived orginal logged values - original_values = test_callback.callback_funcs_called[func_attr["func_name"]] - - # compute expected output and compare to actual one - expected_output = get_expected_output(func_attr, original_values) - try: - assert float(output_value) == float(expected_output) - except: - print(func_name, func_attr, original_values, output_value, expected_output) + # get func attr + func_attr = test_callback.funcs_attr[func_name] + + # retrived orginal logged values + original_values = test_callback.callback_funcs_called[func_attr["func_name"]] - else: - print(func_name, output_value) + # compute expected output and compare to actual one + expected_output = get_expected_output(func_attr, original_values) + try: + assert float(output_value) == float(expected_output) + except: + print(func_name, func_attr, original_values, output_value, expected_output) for func_name, func_attr in test_callback.funcs_attr.items(): if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]):