Bugfix/torchtext include lengths #2689

Merged

Changes from 9 commits (44 commits in total)
924e4c1
Test using torchtext.data.Field with include_lengths=True/False
Jul 24, 2020
1fcbe36
Fix issue that Tensors in a Batch generated by torchtext with torchte…
Jul 24, 2020
59e97b2
Add description for fix of issue #2688
Jul 24, 2020
3e3fbbe
changes to accommodate CodeFactor issues
Jul 24, 2020
fe9816d
Another attempt to make the last CodeFactor issue pass (it's a false alarm)
Jul 24, 2020
957ee89
temporarily disable test of test_grad_tracking to check if testing wil…
Jul 24, 2020
7971e7d
reenable test in test_grad_norm
Jul 25, 2020
4d0a849
Update CHANGELOG.md
thschaaf Jul 26, 2020
c994e88
Renamed get_torchtext_data_iterator to _get_torchtext_data_iterator a…
Jul 26, 2020
f60613c
Update pytorch_lightning/utilities/apply_func.py
thschaaf Jul 26, 2020
c9fdf50
adding tests more specific to batch_move_data_to_device with torchtext…
Jul 27, 2020
5e568ea
added check that Tensors were moved to target device
Jul 27, 2020
a6b96b0
removed tests using RNN models to be moved into a separate PR
Jul 27, 2020
0eabe91
Merge branch 'master' into bugfix/torchtext-include_lengths
thschaaf Jul 27, 2020
398ab54
fixing FLAKE8 errors that showed up after merge from master branch
Jul 27, 2020
8d56dc8
Merge branch 'master' into bugfix/torchtext-include_lengths
thschaaf Jul 27, 2020
a99fc7d
parameterized test to reduce code duplication
Jul 28, 2020
61e692f
Added check only if length tensor exists. Removed left-over comments.
Jul 28, 2020
0c25f43
rearranged device parameterization and added pytest.param
Jul 28, 2020
f08dd78
Try to figure out why only one device is tested on Linux machines
Jul 28, 2020
d2c4598
Testing on CPU and GPU devices (GPU test is skipped if no cuda device is…
Jul 28, 2020
9bd3854
added test for TPU device (experimental)
Jul 28, 2020
d04c288
Adding test parameterization for TPU test (experimental)
Jul 28, 2020
cca6ff3
change import statement to limit what is imported for a TPU environment
Jul 28, 2020
5f3680d
made test work with TPU
Jul 28, 2020
08ebb6d
Change to trigger CI
Jul 28, 2020
fa6b2f9
Change to trigger CI
Jul 28, 2020
f9d9887
Merge branch 'bugfix/torchtext-include_lengths' of https://github.com…
Jul 28, 2020
940c34d
uncommented TPU test to check CI
Jul 28, 2020
584328a
reenabling TPU test
Jul 29, 2020
ae71b14
small change to trigger CI build
Jul 29, 2020
34201bc
small change to trigger CI build
Jul 29, 2020
a53a469
small change to trigger CI build
Jul 29, 2020
647e44b
adding tests/utilities/test_apply_func_torchtext.py to CI TPU test
Jul 29, 2020
ff080da
try to make test not skipped on CI with TPU
Jul 29, 2020
43a5ea9
remove testing on TPU
Jul 29, 2020
73583c1
undo an accidental change to test_tpu.py (file should not have been t…
Jul 29, 2020
b929711
small change to trigger CI build
Jul 29, 2020
68e2152
small change to trigger CI build
Jul 29, 2020
c97cd69
Update tests/utilities/test_apply_func_torchtext.py
awaelchli Jul 29, 2020
1685077
Revert to previous version
Jul 29, 2020
8a7d68b
Apply suggestions from code review
Borda Jul 29, 2020
72f64ad
Merge branch 'master' into bugfix/torchtext-include_lengths
thschaaf Jul 29, 2020
3c04090
Change to trigger CI
Jul 30, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -29,6 +29,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657))

- Fixed data transfer to device when using `torchtext.data.Field` with `include_lengths=True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))

## [0.8.5] - 2020-07-09

### Added
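The fix recorded in the CHANGELOG above addresses the following failure mode: with `include_lengths=True`, torchtext's `Field.process(...)` returns a tuple `(padded_sequences, lengths)` instead of a single tensor, so calling `.to(device)` on such a batch field raises an `AttributeError`. A minimal, self-contained sketch of that behaviour (the variable names are illustrative, not taken from this PR):

    import torch

    # what a torchtext Batch field looks like with include_lengths=True:
    # a tuple of (padded token ids, per-example lengths)
    text_field = (torch.randint(0, 10, (7, 3)), torch.tensor([7, 5, 4]))

    try:
        text_field.to(torch.device("cpu"))  # what the previous per-field transfer effectively did
    except AttributeError as err:
        print(err)  # 'tuple' object has no attribute 'to'

    # the fix instead moves each tensor inside the tuple individually
    moved = tuple(t.to(torch.device("cpu"), non_blocking=True) for t in text_field)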
28 changes: 23 additions & 5 deletions pytorch_lightning/utilities/apply_func.py
Expand Up @@ -6,6 +6,7 @@
import torch

import importlib

TORCHTEXT_AVAILABLE = importlib.util.find_spec("torchtext") is not None
if TORCHTEXT_AVAILABLE:
from torchtext.data import Batch
@@ -92,18 +93,35 @@ def move_data_to_device(batch: Any, device: torch.device):
- :meth:`torch.Tensor.to`
- :class:`torch.device`
"""

    def batch_to(data):
        # try to move torchtext data first
        if TORCHTEXT_AVAILABLE and isinstance(data, Batch):

            # Shallow copy because each Batch has a reference to Dataset which contains all examples
            device_data = copy(data)
            for field in data.fields:
                # Batch contains the output of Field.process(...)
                if isinstance(getattr(data, field), torch.Tensor):
                    # standard case: the field is a single tensor, so .to(...) exists
                    device_field = getattr(data, field).to(device, non_blocking=True)
                    setattr(device_data, field, device_field)
                elif isinstance(getattr(data, field), tuple):
                    # with include_lengths=True torchtext produces a tuple of two tensors (sequences, lengths);
                    # a generator expression moves each tensor to the device individually
                    device_field = tuple(elem.to(device, non_blocking=True) for elem in getattr(data, field))
                    setattr(device_data, field, device_field)
                elif isinstance(getattr(data, field), list):
                    # handle a list of tensors the same way, for completeness
                    device_field = list(elem.to(device, non_blocking=True) for elem in getattr(data, field))
                    setattr(device_data, field, device_field)
                else:
                    # fallback: assume the object implements .to(...); it will fail if it does not
                    device_field = getattr(data, field).to(device, non_blocking=True)
                    setattr(device_data, field, device_field)

            return device_data
        else:
            return data.to(device, non_blocking=True)

    return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)
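With the tuple and list branches above, a torchtext `Batch` produced with `include_lengths=True` can be passed through `move_data_to_device` directly. A brief usage sketch, assuming the legacy `torchtext.data` API (pre-0.9) is installed and reusing the `_get_torchtext_data_iterator` helper defined in the test file below:

    import torch
    from pytorch_lightning.utilities.apply_func import move_data_to_device

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_iterator, _ = _get_torchtext_data_iterator(include_lengths=True)
    batch = next(iter(data_iterator))

    moved = move_data_to_device(batch, device)
    text, lengths = moved.text  # the field stays a tuple, with both tensors on the target device
    assert text.device == device and lengths.device == device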
170 changes: 170 additions & 0 deletions tests/utilities/test_apply_func_torchtext.py
@@ -0,0 +1,170 @@
import torch
import torchtext

import pytorch_lightning as pl


def _get_torchtext_data_iterator(include_lengths=False):
    text_field = torchtext.data.Field(sequential=True, pad_first=False,  # nosec
                                      init_token="<s>", eos_token="</s>", include_lengths=include_lengths)  # nosec

    example1 = torchtext.data.example.Example.fromdict({"text": "a b c a c"}, {"text": ("text", text_field)})
    example2 = torchtext.data.example.Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)})
    example3 = torchtext.data.example.Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)})

    dataset = torchtext.data.Dataset([example1, example2, example3], {"text": text_field})
    text_field.build_vocab(dataset)

    iterator = torchtext.data.Iterator(dataset, batch_size=3,
                                       sort_key=None, device=None, batch_size_fn=None,
                                       train=True, repeat=False, shuffle=None, sort=None, sort_within_batch=None)
    return iterator, text_field


def test_move_data_to_device_torchtext_include_length_true():
    """Test that batches created by torchtext with include_lengths=True do not break data transfer to the device."""

    class DebugModel(pl.LightningModule):

        def __init__(self):
            super(DebugModel, self).__init__()

            # set up a data loader generating batches with fields consisting of tuples of tensors
            self.debug_data_loader, self.text_field = _get_torchtext_data_iterator(include_lengths=True)

            self.learning_rate = 0.001

            pad_idx = self.text_field.vocab.stoi['<pad>']
            self.criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)

            self.INPUT_DIM = len(self.text_field.vocab)
            self.ENC_EMB_DIM = 4  # keep it small for debugging
            self.embedding = torch.nn.Embedding(self.INPUT_DIM, self.ENC_EMB_DIM)

            self.hid_dim = 4
            self.rnn = torch.nn.GRU(self.ENC_EMB_DIM, self.hid_dim, 1, bidirectional=False)
            self.out = torch.nn.Linear(self.hid_dim, self.embedding.num_embeddings)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

        def forward(self, input_seq, length):
            embedded = self.embedding(input_seq)
            packed_embedded = torch.nn.utils.rnn.pack_padded_sequence(embedded, length, batch_first=False,
                                                                      enforce_sorted=False)
            packed_outputs, hidden = self.rnn(packed_embedded)
            outputs, length = torch.nn.utils.rnn.pad_packed_sequence(packed_outputs)

            output = outputs.squeeze(0)
            prediction = self.out(output)

            return prediction

        @staticmethod
        def _parse_batch(batch):
            source = batch.text[0]
            source_length = batch.text[1]

            return source, source_length

        def training_step(self, batch, batch_nb):
            """Needed for testing data transfer."""
            x = self._parse_batch(batch)
            target, target_length = x

            output = self.forward(target, target_length)
            loss = self.criterion(output[:-1].view(-1, output.shape[2]), target[1:].view(-1))

            prefix = 'train'
            tensorboard_logs = {f'{prefix}_loss': loss.item()}

            result = {'loss': loss, 'log': tensorboard_logs}
            return result

        def train_dataloader(self):
            return self.debug_data_loader

    model = DebugModel()

    cuda_device_cnt = torch.cuda.device_count()
    if cuda_device_cnt > 0:
        use_num_cuda_devices = 1
    else:
        use_num_cuda_devices = None

    trainer = pl.Trainer(fast_dev_run=True, max_steps=None,
                         gradient_clip_val=10,
                         weights_summary=None, gpus=use_num_cuda_devices,
                         show_progress_bar=True)

    result = trainer.fit(model)
    # verify training completed
    assert result == 1


def test_move_data_to_device_torchtext_include_length_false():
    """Test that batches created by torchtext with include_lengths=False do not break data transfer to the device."""

    class DebugModel(pl.LightningModule):

        def __init__(self):
            super(DebugModel, self).__init__()

            # set up a data loader generating batches with fields consisting of tensors
            self.debug_data_loader, self.text_field = _get_torchtext_data_iterator(include_lengths=False)

            self.learning_rate = 0.001

            pad_idx = self.text_field.vocab.stoi['<pad>']
            self.criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)

            self.INPUT_DIM = len(self.text_field.vocab)
            self.ENC_EMB_DIM = 4  # keep it small for debugging
            self.embedding = torch.nn.Embedding(self.INPUT_DIM, self.ENC_EMB_DIM)

            self.hid_dim = 4
            self.rnn = torch.nn.GRU(self.ENC_EMB_DIM, self.hid_dim, 1, bidirectional=False)
            self.out = torch.nn.Linear(self.hid_dim, self.embedding.num_embeddings)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

        def forward(self, input_seq):
            embedded = self.embedding(input_seq)
            outputs, hidden = self.rnn(embedded)
            output = outputs.squeeze(0)
            prediction = self.out(output)
            return prediction

        def training_step(self, batch, batch_nb):
            """Needed for testing data transfer."""

            target = batch.text
            output = self.forward(target)
            loss = self.criterion(output[:-1].view(-1, output.shape[2]), target[1:].view(-1))

            prefix = 'train'
            tensorboard_logs = {f'{prefix}_loss': loss.item()}

            result = {'loss': loss, 'log': tensorboard_logs}
            return result

        def train_dataloader(self):
            return self.debug_data_loader

    model = DebugModel()

    cuda_device_cnt = torch.cuda.device_count()
    if cuda_device_cnt > 0:
        use_num_cuda_devices = 1
    else:
        use_num_cuda_devices = None

    trainer = pl.Trainer(fast_dev_run=True, max_steps=None,
                         gradient_clip_val=10,
                         weights_summary=None, gpus=use_num_cuda_devices,
                         show_progress_bar=True)

    result = trainer.fit(model)
    # verify training completed
    assert result == 1
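Usage note: these tests can be run locally with pytest, for example `python -m pytest tests/utilities/test_apply_func_torchtext.py`, provided torchtext is installed in the environment.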