
Additional test for logging during validation loop #3907

Merged: 3 commits, Oct 6, 2020
13 changes: 11 additions & 2 deletions pytorch_lightning/core/step_result.py
@@ -219,7 +219,11 @@ def __set_meta(
         _internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch)

     def track_batch_size(self, batch):
-        batch_size = self.unpack_batch_size(batch)
+        try:
+            batch_size = self.unpack_batch_size(batch)
+        except RecursionError as re:
+            batch_size = 1
+
Comment from the contributor (PR author):

@williamFalcon I'm not too sure about the correctness of setting this to 1, or of using len(sample) in unpack_batch_size() when sample is a string. This way, the batch size will always be 1, even if we have a list of strings with batch_size > 1 coming from the dataloader. You can try adding a batch_size to the dataloader in the test, and unpack_batch_size will still report 1 as the batch size with the current patch.


         meta = self['meta']
         meta['_internal']['batch_sizes'].append(batch_size)
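For context: the RecursionError guarded against above arises because a string is an iterable whose first element is again a string, so the generic-iterable case recurses forever. A minimal repro sketch using a simplified stand-in (hypothetical, not the library's actual function):

import torch

def unpack(sample):
    # simplified stand-in for unpack_batch_size, without a str branch
    if isinstance(sample, torch.Tensor):
        return sample.size(0)
    # generic iterable: recurse into the first element
    return unpack(next(iter(sample), 1))

# unpack("abc") recurses as unpack("a") -> unpack("a") -> ... and
# eventually raises RecursionError, which is what the try/except in
# track_batch_size falls back from with batch_size = 1.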

@@ -330,6 +334,8 @@ def unpack_batch_size(self, sample):
         """
         if isinstance(sample, torch.Tensor):
             size = sample.size(0)
+        elif isinstance(sample, str):
+            return len(sample)
         elif isinstance(sample, dict):
             sample = next(iter(sample.values()), 1)
             size = self.unpack_batch_size(sample)
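Regarding the concern in the author's comment above, a hypothetical alternative (a sketch, not what this patch does) would be to treat a collated list of strings as a batch and a bare string as a single example:

import torch

def unpack_batch_size_alt(sample):
    # Hypothetical variant; names and structure are illustrative only.
    if isinstance(sample, torch.Tensor):
        return sample.size(0)
    if isinstance(sample, str):
        return 1  # a bare string is one example, not len(sample) examples
    if isinstance(sample, (list, tuple)) and sample and isinstance(sample[0], str):
        return len(sample)  # a collated batch of strings
    if isinstance(sample, dict):
        return unpack_batch_size_alt(next(iter(sample.values()), 1))
    try:
        return unpack_batch_size_alt(next(iter(sample), 1))
    except TypeError:
        return 1  # non-iterable leaf

# On a collated batch like {"id": ["0", "1", "2"], "x": <3x32 tensor>}
# this reports 3, whereas the patch above reports 1: it recurses into
# the list's first element, the string "0", and returns len("0").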
@@ -406,7 +412,10 @@ def reduce_on_epoch_end(cls, outputs):
                 if option['on_epoch']:
                     fx = option['reduce_fx']
                     if fx == torch.mean:
-                        reduced_val = weighted_mean(result[k], batch_sizes)
+                        try:
+                            reduced_val = weighted_mean(result[k], batch_sizes)
+                        except Exception as e:
+                            reduced_val = torch.mean(result[k])
                     else:
                         reduced_val = fx(result[k])
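For context on the fallback above: a weighted mean averages the per-batch values v_i weighted by batch sizes n_i, i.e. sum(n_i * v_i) / sum(n_i), which is why it needs the batch sizes that string batches fail to provide. A minimal sketch (not the library's exact weighted_mean implementation):

import torch

def weighted_mean_sketch(values, batch_sizes):
    # values: one reduced scalar per batch; batch_sizes: examples per batch
    v = torch.as_tensor(values, dtype=torch.float32)
    w = torch.as_tensor(batch_sizes, dtype=torch.float32)
    return (v * w).sum() / w.sum()

# weighted_mean_sketch([1.0, 2.0], [10, 30]) gives 1.75, while the
# plain torch.mean fallback over the same values would give 1.5.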

12 changes: 12 additions & 0 deletions tests/base/boring_model.py
@@ -17,6 +17,18 @@ def __len__(self):
         return self.len


+class RandomDictStringDataset(Dataset):
+    def __init__(self, size, length):
+        self.len = length
+        self.data = torch.randn(length, size)
+
+    def __getitem__(self, index):
+        return {"id": str(index), "x": self.data[index]}
+
+    def __len__(self):
+        return self.len
+
+
 class RandomDataset(Dataset):
     def __init__(self, size, length):
         self.len = length
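For context: with PyTorch's default collate function, string fields are gathered into plain Python lists rather than tensors, which is exactly the case the step_result.py changes guard against. A sketch of what a batch from this dataset looks like (batch_size=4 is an assumption for illustration, not part of the test suite):

from torch.utils.data import DataLoader

loader = DataLoader(RandomDictStringDataset(32, 64), batch_size=4)
batch = next(iter(loader))
# batch["x"] is a 4x32 float tensor, but batch["id"] is a list like
# ["0", "1", "2", "3"], so the dict's first value carries no .size(0)
# from which to read the batch size.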
34 changes: 33 additions & 1 deletion tests/trainer/logging/test_train_loop_logging_1_0.py
@@ -1,7 +1,7 @@
 """
 Tests to ensure that the training loop works with a dict (1.0)
 """
-from tests.base.boring_model import BoringModel, RandomDictDataset
+from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
 import os
 import torch
 import pytest
@@ -387,3 +387,35 @@ def val_dataloader(self):
     }

     assert generated == expected
+
+
+def test_validation_step_with_string_data_logging():
+    class TestModel(BoringModel):
+        def on_train_epoch_start(self) -> None:
+            print("override any method to prove your bug")
+
+        def training_step(self, batch, batch_idx):
+            output = self.layer(batch["x"])
+            loss = self.loss(batch, output)
+            return {"loss": loss}
+
+        def validation_step(self, batch, batch_idx):
+            output = self.layer(batch["x"])
+            loss = self.loss(batch, output)
+            self.log("x", loss)
+            return {"x": loss}
+
+    # fake data
+    train_data = torch.utils.data.DataLoader(RandomDictStringDataset(32, 64))
+    val_data = torch.utils.data.DataLoader(RandomDictStringDataset(32, 64))
+
+    # model
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=os.getcwd(),
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=1,
+        weights_summary=None,
+    )
+    trainer.fit(model, train_data, val_data)
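A possible follow-up assertion after fit (a sketch, not part of the PR; attribute behavior depends on the Lightning version) would be to check that the value logged in validation_step actually surfaced:

# "x" was logged via self.log in validation_step, so on this era of
# Lightning it should appear in the trainer's aggregated metrics.
assert "x" in trainer.callback_metrics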