Skip to content

Commit

Permalink
Additional test for logging during validation loop (#3907)
Browse files Browse the repository at this point in the history
* Added test for logging in validation step when using dict dataset with string value

* fix recursive issue

* fix recursive issue

Co-authored-by: Nathan Painchaud <[email protected]>
Co-authored-by: William Falcon <[email protected]>
  • Loading branch information
3 people authored Oct 6, 2020
1 parent 064ae53 commit c510a7f
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 3 deletions.
13 changes: 11 additions & 2 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,11 @@ def __set_meta(
_internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch)

def track_batch_size(self, batch):
batch_size = self.unpack_batch_size(batch)
try:
batch_size = self.unpack_batch_size(batch)
except RecursionError as re:
batch_size = 1

meta = self['meta']
meta['_internal']['batch_sizes'].append(batch_size)

Expand Down Expand Up @@ -330,6 +334,8 @@ def unpack_batch_size(self, sample):
"""
if isinstance(sample, torch.Tensor):
size = sample.size(0)
elif isinstance(sample, str):
return len(sample)
elif isinstance(sample, dict):
sample = next(iter(sample.values()), 1)
size = self.unpack_batch_size(sample)
Expand Down Expand Up @@ -406,7 +412,10 @@ def reduce_on_epoch_end(cls, outputs):
if option['on_epoch']:
fx = option['reduce_fx']
if fx == torch.mean:
reduced_val = weighted_mean(result[k], batch_sizes)
try:
reduced_val = weighted_mean(result[k], batch_sizes)
except Exception as e:
reduced_val = torch.mean(result[k])
else:
reduced_val = fx(result[k])

Expand Down
12 changes: 12 additions & 0 deletions tests/base/boring_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ def __len__(self):
return self.len


class RandomDictStringDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)

def __getitem__(self, index):
return {"id": str(index), "x": self.data[index]}

def __len__(self):
return self.len


class RandomDataset(Dataset):
def __init__(self, size, length):
self.len = length
Expand Down
34 changes: 33 additions & 1 deletion tests/trainer/logging/test_train_loop_logging_1_0.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Tests to ensure that the training loop works with a dict (1.0)
"""
from tests.base.boring_model import BoringModel, RandomDictDataset
from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
import os
import torch
import pytest
Expand Down Expand Up @@ -387,3 +387,35 @@ def val_dataloader(self):
}

assert generated == expected


def test_validation_step_with_string_data_logging():
class TestModel(BoringModel):
def on_train_epoch_start(self) -> None:
print("override any method to prove your bug")

def training_step(self, batch, batch_idx):
output = self.layer(batch["x"])
loss = self.loss(batch, output)
return {"loss": loss}

def validation_step(self, batch, batch_idx):
output = self.layer(batch["x"])
loss = self.loss(batch, output)
self.log("x", loss)
return {"x": loss}

# fake data
train_data = torch.utils.data.DataLoader(RandomDictStringDataset(32, 64))
val_data = torch.utils.data.DataLoader(RandomDictStringDataset(32, 64))

# model
model = TestModel()
trainer = Trainer(
default_root_dir=os.getcwd(),
limit_train_batches=1,
limit_val_batches=1,
max_epochs=1,
weights_summary=None,
)
trainer.fit(model, train_data, val_data)

0 comments on commit c510a7f

Please sign in to comment.