139 changes: 124 additions & 15 deletions collections/nemo_nlp/nemo_nlp/data/data_layers.py
@@ -12,28 +12,33 @@
'BertPretrainingDataLayer',
'TranslationDataLayer']

# from abc import abstractmethod
import sys

import torch
from torch.utils import data as pt_data

import nemo
from nemo.backends.pytorch.nm import DataLayerNM
from nemo.core.neural_types import *

from .datasets import *


class TextDataLayer(DataLayerNM):
"""
Generic Text Data Layer NM which wraps PyTorch's dataset


Args:
dataset: a PyTorch dataset to wrap into Neural Module
batch_size: batch size

dataset_type: type of dataset used for this datalayer
dataset_params (dict): all the params for the dataset
"""

def __init__(self, dataset, **kwargs):
DataLayerNM.__init__(self, **kwargs)
self._dataset = dataset
def __init__(self, dataset_type, dataset_params, **kwargs):
super().__init__(**kwargs)
if isinstance(dataset_type, str):
dataset_type = getattr(sys.modules[__name__], dataset_type)
self._dataset = dataset_type(**dataset_params)

def __len__(self):
return len(self._dataset)
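The key change in TextDataLayer is that it now constructs its dataset internally: dataset_type may be either a dataset class or its name as a string, which is resolved against this module and instantiated from dataset_params. A minimal sketch of that lookup pattern in isolation (the toy class and helper below are illustrative, not part of this PR):

import sys


class ToyDataset:
    def __init__(self, input_file, max_seq_length):
        self.input_file = input_file
        self.max_seq_length = max_seq_length


def build_dataset(dataset_type, dataset_params):
    # Mirror TextDataLayer.__init__: accept a class or its name as a string.
    if isinstance(dataset_type, str):
        dataset_type = getattr(sys.modules[__name__], dataset_type)
    return dataset_type(**dataset_params)


ds = build_dataset('ToyDataset',
                   {'input_file': 'train.tsv', 'max_seq_length': 128})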
@@ -80,6 +85,23 @@ def create_ports():
}
return {}, output_ports

def __init__(self,
input_file,
tokenizer,
max_seq_length,
num_samples=-1,
shuffle=False,
batch_size=64,
dataset_type=BertSentenceClassificationDataset,
**kwargs):
kwargs['batch_size'] = batch_size
dataset_params = {'input_file': input_file,
'tokenizer': tokenizer,
'max_seq_length': max_seq_length,
'num_samples': num_samples,
'shuffle': shuffle}
super().__init__(dataset_type, dataset_params, **kwargs)


class BertJointIntentSlotDataLayer(TextDataLayer):
"""
@@ -121,6 +143,27 @@ def create_ports():
}
return {}, output_ports

def __init__(self,
input_file,
slot_file,
pad_label,
tokenizer,
max_seq_length,
num_samples=-1,
shuffle=False,
batch_size=64,
dataset_type=BertJointIntentSlotDataset,
**kwargs):
kwargs['batch_size'] = batch_size
dataset_params = {'input_file': input_file,
'slot_file': slot_file,
'pad_label': pad_label,
'tokenizer': tokenizer,
'max_seq_length': max_seq_length,
'num_samples': num_samples,
'shuffle': shuffle}
super().__init__(dataset_type, dataset_params, **kwargs)


class BertJointIntentSlotInferDataLayer(TextDataLayer):
"""
@@ -155,6 +198,19 @@ def create_ports():
}
return {}, output_ports

def __init__(self,
queries,
tokenizer,
max_seq_length,
batch_size=1,
dataset_type=BertJointIntentSlotInferDataset,
**kwargs):
kwargs['batch_size'] = batch_size
dataset_params = {'queries': queries,
'tokenizer': tokenizer,
'max_seq_length': max_seq_length}
super().__init__(dataset_type, dataset_params, **kwargs)


class LanguageModelingDataLayer(TextDataLayer):
@staticmethod
@@ -180,6 +236,19 @@ def create_ports():

return input_ports, output_ports

def __init__(self,
dataset,
tokenizer,
max_seq_length,
batch_step=128,
dataset_type=LanguageModelingDataset,
**kwargs):
dataset_params = {'dataset': dataset,
'tokenizer': tokenizer,
'max_seq_length': max_seq_length,
'batch_step': batch_step}
super().__init__(dataset_type, dataset_params, **kwargs)


class BertTokenClassificationDataLayer(TextDataLayer):
@staticmethod
@@ -206,6 +275,19 @@ def create_ports():
}
return input_ports, output_ports

def __init__(self,
input_file,
tokenizer,
max_seq_length,
batch_size=64,
dataset_type=BertTokenClassificationDataset,
**kwargs):
kwargs['batch_size'] = batch_size
dataset_params = {'input_file': input_file,
'tokenizer': tokenizer,
'max_seq_length': max_seq_length}
super().__init__(dataset_type, dataset_params, **kwargs)

def eval_preds(self, logits, seq_ids, tag_ids):
return self._dataset.eval_preds(logits, seq_ids, tag_ids)

@@ -240,6 +322,20 @@ def create_ports():

return input_ports, output_ports

def __init__(self,
tokenizer,
dataset,
max_seq_length,
mask_probability,
batch_size=64,
**kwargs):
kwargs['batch_size'] = batch_size
dataset_params = {'tokenizer': tokenizer,
'dataset': dataset,
'max_seq_length': max_seq_length,
'mask_probability': mask_probability}
super().__init__(BertPretrainingDataset, dataset_params, **kwargs)


class TranslationDataLayer(TextDataLayer):
@staticmethod
@@ -273,20 +369,33 @@ def create_ports():

return input_ports, output_ports

def __init__(self, dataset, **kwargs):
TextDataLayer.__init__(self, None, **kwargs)
def __init__(self,
tokenizer_src,
tokenizer_tgt,
dataset_src,
dataset_tgt,
tokens_in_batch=1024,
clean=False,
dataset_type=TranslationDataset,
**kwargs):
dataset_params = {'tokenizer_src': tokenizer_src,
'tokenizer_tgt': tokenizer_tgt,
'dataset_src': dataset_src,
'dataset_tgt': dataset_tgt,
'tokens_in_batch': tokens_in_batch,
'clean': clean}
super().__init__(dataset_type, dataset_params, **kwargs)

if self._placement == nemo.core.DeviceType.AllGpu:
sampler = pt_data.distributed.DistributedSampler(self._dataset)
else:
sampler = None

self._dataloader = pt_data.DataLoader(
dataset=dataset,
batch_size=1,
collate_fn=lambda x: self._collate_fn(x),
shuffle=sampler is None,
sampler=sampler)
self._dataloader = pt_data.DataLoader(dataset=self._dataset,
batch_size=1,
collate_fn=self._collate_fn,
shuffle=sampler is None,
sampler=sampler)

def _collate_fn(self, x):
src_ids, src_mask, tgt_ids, tgt_mask, labels, sent_ids = x[0]
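With this refactoring each concrete data layer builds its dataset from plain parameters, so callers no longer construct a PyTorch dataset themselves. A hedged usage sketch for the token-classification layer above (the import path follows the file location in this diff; the tokenizer object and file path are placeholders, not taken from this PR):

from nemo_nlp.data.data_layers import BertTokenClassificationDataLayer

# `tokenizer` stands for any previously constructed NeMo tokenizer that the
# dataset expects; its creation is omitted here.
data_layer = BertTokenClassificationDataLayer(input_file='train.txt',
                                              tokenizer=tokenizer,
                                              max_seq_length=128,
                                              batch_size=32)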
4 changes: 2 additions & 2 deletions collections/nemo_nlp/nemo_nlp/data/datasets/__init__.py
@@ -2,7 +2,7 @@
from .joint_intent_slot import (BertJointIntentSlotDataset,
BertJointIntentSlotInferDataset)
from .language_modeling import LanguageModelingDataset
from .ner import BertNERDataset
from .sentence_classification import BertSentenceClassificationDataset
from .ner import BertCornellNERDataset
from .token_classification import BertTokenClassificationDataset
from .sentence_classification import BertSentenceClassificationDataset
from .translation import TranslationDataset
50 changes: 27 additions & 23 deletions collections/nemo_nlp/nemo_nlp/data/datasets/bert_pretraining.py
@@ -30,7 +30,7 @@ class BertPretrainingDataset(Dataset):
def __init__(self,
tokenizer,
dataset,
max_length=128,
max_seq_length=128,
mask_probability=0.15,
sentence_idx_file=None):
self.tokenizer = tokenizer
@@ -71,8 +71,8 @@ def find_newlines(contents):
.replace(b"\xa0", b" ")
num_tokens = len(line.split())

# Ensure the line has at least max_length tokens
if num_tokens >= max_length:
# Ensure the line has at least max_seq_length tokens
if num_tokens >= max_seq_length:
yield start - 1
used_tokens += num_tokens

@@ -104,8 +104,7 @@ def find_newlines(contents):
with open(sentence_idx_file, "wb") as f:
pickle.dump(sentence_indices, f)

print("Used {} tokens of total {}".format(used_tokens,
total_tokens))
print(f"Used {used_tokens} of total {total_tokens} tokens")

corpus_size = 0
empty_files = []
@@ -125,22 +124,21 @@ def find_newlines(contents):
self.dataset = dataset
self.filenames = list(sentence_indices.keys())
self.mask_probability = mask_probability
self.max_length = max_length
self.max_seq_length = max_seq_length
self.sentence_indices = sentence_indices
self.vocab_size = self.tokenizer.vocab_size

def __len__(self):
return self.corpus_size

def __getitem__(self, idx):
def __getitem__(self, idx, min_doc_length=16):
# Each sequence has three special tokens, as follows:
# [CLS] <document a> [SEP] <document b> [SEP]
num_special_tokens = 3

# TODO: Make seq_length = 512 for the last 10% of epochs, as specified
# in BERT paper
seq_length = self.max_length - num_special_tokens
min_doc_length = 16
seq_length = self.max_seq_length - num_special_tokens

a_length = random.randrange(min_doc_length,
seq_length - min_doc_length + 1)
@@ -160,11 +158,11 @@ def get_document(filepath, line):

# Read line, remove newline, and decode as UTF8
doc_text = f.readline()[:-1].decode("utf-8", errors="ignore")
document = [self.tokenizer.token_to_id("[CLS]")] \
+ self.tokenizer.text_to_ids(doc_text) \
+ [self.tokenizer.token_to_id("[SEP]")]
document = [self.tokenizer.token_to_id("[CLS]")]
document.extend(self.tokenizer.text_to_ids(doc_text))
document.append(self.tokenizer.token_to_id("[SEP]"))

assert len(document) >= self.max_length
assert len(document) >= self.max_seq_length
return document

a_document = get_document(a_filename, a_line)
@@ -194,20 +192,26 @@ def get_document(filepath, line):
a_ids = a_document[a_start_idx:a_start_idx + a_length]
b_ids = b_document[b_start_idx:b_start_idx + b_length]

output_ids = [self.tokenizer.special_tokens["[CLS]"]] + a_ids + \
[self.tokenizer.special_tokens["[SEP]"]] + b_ids + \
[self.tokenizer.special_tokens["[SEP]"]]
output_ids = [self.tokenizer.special_tokens["[CLS]"]]
output_ids.extend(a_ids)
output_ids.append(self.tokenizer.special_tokens["[SEP]"])
output_ids.extend(b_ids)
output_ids.append(self.tokenizer.special_tokens["[SEP]"])

input_ids, output_mask = self.mask_ids(output_ids)

output_mask = np.array(output_mask, dtype=np.float32)
input_mask = np.ones(self.max_length, dtype=np.float32)
input_mask = np.ones(self.max_seq_length, dtype=np.float32)

input_type_ids = np.zeros(self.max_length, dtype=np.int)
input_type_ids = np.zeros(self.max_seq_length, dtype=np.int)
input_type_ids[a_length + 2:seq_length + 3] = 1

return np.array(input_ids), input_type_ids, input_mask, \
np.array(output_ids), output_mask, label
return (np.array(input_ids),
input_type_ids,
input_mask,
np.array(output_ids),
output_mask,
label)

def mask_ids(self, ids):
"""
@@ -226,16 +230,16 @@
"""
masked_ids = []
output_mask = []
for id in ids:
for i in ids:
if random.random() < self.mask_probability:
output_mask.append(1)
if random.random() < 0.8:
masked_ids.append(self.tokenizer.special_tokens["[MASK]"])
elif random.random() < 0.5:
masked_ids.append(random.randrange(self.vocab_size))
else:
masked_ids.append(id)
masked_ids.append(i)
else:
masked_ids.append(id)
masked_ids.append(i)
output_mask.append(0)
return masked_ids, output_mask
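For a token selected by the first draw, the branch in mask_ids reproduces the BERT masking recipe: the token becomes [MASK] 80% of the time, and the remaining 20% is split roughly in half between a random vocabulary id and keeping the original id, because the second random.random() < 0.5 check only runs when the [MASK] branch is skipped. A small standalone check of those proportions (toy simulation, independent of this class):

import random


def masking_choice():
    # Same branch structure as mask_ids for a token already chosen for masking.
    if random.random() < 0.8:
        return 'mask'
    elif random.random() < 0.5:
        return 'random_id'
    return 'keep'


random.seed(0)
counts = {'mask': 0, 'random_id': 0, 'keep': 0}
for _ in range(100000):
    counts[masking_choice()] += 1
print(counts)  # approximately 80% / 10% / 10%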
collections/nemo_nlp/nemo_nlp/data/datasets/language_modeling.py
@@ -24,10 +24,10 @@ class LanguageModelingDataset(Dataset):
def __init__(self,
tokenizer,
dataset,
max_sequence_length=512,
max_seq_length=512,
batch_step=None):
self.tokenizer = tokenizer
self.max_seq_length = max_sequence_length
self.max_seq_length = max_seq_length
self.batch_step = batch_step or self.max_seq_length
ids = utils.dataset_to_ids(dataset, tokenizer, add_bos_eos=False)
self.ids = np.array([j for i in ids for j in i])
2 changes: 1 addition & 1 deletion collections/nemo_nlp/nemo_nlp/data/datasets/ner.py
@@ -31,7 +31,7 @@
from torch.utils.data import Dataset


class BertNERDataset(Dataset):
class BertCornellNERDataset(Dataset):
def __init__(self, input_file, max_seq_length, tokenizer):
# Read the sentences and group them in sequences up to max_seq_length
with open(input_file, "r") as f:
4 changes: 2 additions & 2 deletions collections/nemo_nlp/nemo_nlp/data/datasets/utils.py
@@ -59,7 +59,7 @@ def process_sst_2(data_dir):
if not os.path.exists(data_dir):
link = 'https://gluebenchmark.com/tasks'
raise ValueError(f'Data not found at {data_dir}. '
'Please download SST-2 from {link}.')
f'Please download SST-2 from {link}.')
logger.info('Keep in mind that SST-2 is only available in lower case.')
return data_dir

@@ -68,7 +68,7 @@ def process_imdb(data_dir, uncased, modes=['train', 'test']):
if not os.path.exists(data_dir):
link = 'www.kaggle.com/iarunava/imdb-movie-reviews-dataset'
raise ValueError(f'Data not found at {data_dir}. '
'Please download IMDB from {link}.')
f'Please download IMDB from {link}.')

outfold = f'{data_dir}/nemo-processed'

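The two utils.py changes fix the same bug: with implicit string concatenation, every piece that contains a placeholder needs its own f prefix, otherwise {link} is left verbatim in the raised message. A minimal illustration (strings shortened for brevity):

link = 'https://gluebenchmark.com/tasks'

broken = (f'Data not found. '
          'Please download SST-2 from {link}.')   # second piece lacks f: {link} stays literal
fixed = (f'Data not found. '
         f'Please download SST-2 from {link}.')

print(broken)  # ... download SST-2 from {link}.
print(fixed)   # ... download SST-2 from https://gluebenchmark.com/tasks.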