
Commit

train, [val and test] with multi-dataloaders work
tchaton committed Oct 19, 2020
1 parent 9bd19d3 commit 7bf4fcc
Showing 2 changed files with 28 additions and 28 deletions.
8 changes: 7 additions & 1 deletion pytorch_lightning/core/step_result.py
@@ -109,7 +109,10 @@ def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], add
m += f' {additional_err}'
assert x.grad_fn is not None, m

def add_dl_idx(self, name, dl_idx):
def add_dl_idx(self, name: str, dl_idx: Union[None, int]) -> str:
"""
This function adds the dataloader index to the logged key automatically when multiple dataloaders are used
"""
if dl_idx is not None:
name += f"/dataloader_idx_{dl_idx}"
return name
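
For context, the helper's behaviour can be reproduced standalone. Below is a minimal sketch of the key-suffixing logic, mirroring the diff above; the standalone function is only an illustration, not the library API:

from typing import Union

def add_dl_idx(name: str, dl_idx: Union[None, int]) -> str:
    # Append the dataloader index only when one is set, i.e. only when
    # multiple dataloaders are being evaluated.
    if dl_idx is not None:
        name += f"/dataloader_idx_{dl_idx}"
    return name

assert add_dl_idx("val_loss", None) == "val_loss"
assert add_dl_idx("val_loss", 0) == "val_loss/dataloader_idx_0"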
@@ -149,6 +152,7 @@ def log(
was_forked = True

# set step version
# possibly add the dataloader_idx suffix
step_name = self.add_dl_idx(f'{name}_step', current_dataloader_idx)
self.__set_meta(
step_name,
@@ -165,6 +169,7 @@ def log(
self.__setitem__(step_name, value)

# set epoch version
# possibly add the dataloader_idx suffix
epoch_name = self.add_dl_idx(f'{name}_epoch', current_dataloader_idx)
self.__set_meta(
epoch_name,
@@ -180,6 +185,7 @@ def log(
)
self.__setitem__(epoch_name, value)

# possibly add the dataloader_idx suffix
name = self.add_dl_idx(name, current_dataloader_idx)

# always log the original metric
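
In effect, one self.log(name, value, on_step=True, on_epoch=True) call made while a given dataloader is active now fans out into three suffixed keys instead of overwriting entries from other dataloaders. A hedged illustration of the resulting names, inferred from the suffix logic above (not an exhaustive list of what the Result object stores):

# Illustrative key fan-out for name='val_loss' logged while dataloader 1 is active
name, dl_idx = "val_loss", 1
keys = [
    f"{name}_step/dataloader_idx_{dl_idx}",   # step version
    f"{name}_epoch/dataloader_idx_{dl_idx}",  # epoch version
    f"{name}/dataloader_idx_{dl_idx}",        # original (forked) metric
]
assert keys == [
    "val_loss_step/dataloader_idx_1",
    "val_loss_epoch/dataloader_idx_1",
    "val_loss/dataloader_idx_1",
]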
48 changes: 21 additions & 27 deletions tests/trainer/logging/test_eval_loop_logging_1_0.py
@@ -479,9 +479,6 @@ def validation_step(self, batch, batch_idx):
loss = self.loss(batch, output)
self.log('val_loss', loss)

def val_dataloader(self):
return [torch.utils.data.DataLoader(RandomDataset(32, 64)), torch.utils.data.DataLoader(RandomDataset(32, 64))]

max_epochs = 1
model = TestModel()
model.validation_epoch_end = None
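
For reference, the override removed above is what gave this test two validation dataloaders. A self-contained sketch of that setup, assuming BoringModel and RandomDataset are the usual Lightning test helpers (the class name below is hypothetical, used only for illustration):

from torch.utils.data import DataLoader

class TwoValDataloadersModel(BoringModel):  # hypothetical illustration only
    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        # with two dataloaders this surfaces as
        # val_loss/dataloader_idx_0 and val_loss/dataloader_idx_1
        self.log('val_loss', loss)

    def val_dataloader(self):
        return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))]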
@@ -559,26 +556,27 @@ class TestCallback(callbacks.Callback):
funcs_attr = {}

def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx, on_steps=[], on_epochs=[], prob_bars=[]):
original_func_name = func_name
original_func_name = func_name[:]
for idx, t in enumerate(list(itertools.product(*[on_steps, on_epochs, prob_bars]))):
# run logging
func_name = original_func_name
func_name = original_func_name[:]
on_step, on_epoch, prog_bar = t
custom_func_name = f"{func_idx}_{idx}_{func_name}"

pl_module.log(custom_func_name, self.count * func_idx, on_step=on_step, on_epoch=on_epoch, prog_bar=prog_bar)


num_dl_ext = ''
if pl_module._current_dataloader_idx is not None:
dl_idx = pl_module._current_dataloader_idx
custom_func_name += f"/dataloader_idx_{dl_idx}"
func_name += f"/dataloader_idx_{dl_idx}"
num_dl_ext = f"/dataloader_idx_{dl_idx}"
func_name += num_dl_ext

# capture information for later verification
self.callback_funcs_called[func_name].append([self.count * func_idx])
self.funcs_attr[custom_func_name] = {"on_step":on_step, "on_epoch":on_epoch, "prog_bar":prog_bar, "is_created":False, "func_name":func_name}
self.funcs_attr[custom_func_name + num_dl_ext] = {"on_step":on_step, "on_epoch":on_epoch, "prog_bar":prog_bar, "is_created":False, "func_name":func_name}
if on_step and on_epoch:
self.funcs_attr[f"{custom_func_name}_step"] = {"on_step":True, "on_epoch":False, "prog_bar":prog_bar, "is_created":True, "func_name":func_name}
self.funcs_attr[f"{custom_func_name}_epoch"] = {"on_step":False, "on_epoch":True, "prog_bar":prog_bar, "is_created":True, "func_name":func_name}
self.funcs_attr[f"{custom_func_name}_step" + num_dl_ext] = {"on_step":True, "on_epoch":False, "prog_bar":prog_bar, "is_created":True, "func_name":func_name}
self.funcs_attr[f"{custom_func_name}_epoch" + num_dl_ext] = {"on_step":False, "on_epoch":True, "prog_bar":prog_bar, "is_created":True, "func_name":func_name}

def on_test_start(self, trainer, pl_module):
self.make_logging(pl_module, 'on_test_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices)
@@ -638,10 +636,9 @@ def test_dataloader(self):
trainer.fit(model)
trainer.test()


# Make sure the func_name exists within callback_metrics. If not, we missed some
callback_metrics_keys = [*trainer.callback_metrics.keys()]

for func_name in test_callback.callback_funcs_called.keys():
is_in = False
for callback_metrics_key in callback_metrics_keys:
@@ -670,22 +667,19 @@ def get_expected_output(func_attr, original_values):
for func_name, output_value in trainer.callback_metrics.items():
if torch.is_tensor(output_value):
output_value = output_value.item()
# get creation attr
if func_name in test_callback.funcs_attr:
func_attr = test_callback.funcs_attr[func_name]

# retrieve the original logged values
original_values = test_callback.callback_funcs_called[func_attr["func_name"]]

# compute expected output and compare to actual one
expected_output = get_expected_output(func_attr, original_values)
try:
assert float(output_value) == float(expected_output)
except:
print(func_name, func_attr, original_values, output_value, expected_output)
# get func attr
func_attr = test_callback.funcs_attr[func_name]

# retrieve the original logged values
original_values = test_callback.callback_funcs_called[func_attr["func_name"]]

else:
print(func_name, output_value)
# compute expected output and compare to actual one
expected_output = get_expected_output(func_attr, original_values)
try:
assert float(output_value) == float(expected_output)
except:
print(func_name, func_attr, original_values, output_value, expected_output)

for func_name, func_attr in test_callback.funcs_attr.items():
if func_attr["prog_bar"] and (func_attr["on_step"] or func_attr["on_epoch"]):
