diff --git a/README.md b/README.md
index 0804d67893a..779569eee5e 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,7 @@ python pretrain_bert.py \
        --tokenizer-model-type bert-large-uncased \
        --vocab-size 30522 \
        --train-data wikipedia \
+       --presplit-sentences \
        --loose-json \
        --text-key text \
        --split 1000,1,1 \
@@ -79,7 +80,7 @@ This script runs BERT pretraining with a `sentencepiece` tokenizer. If no senten
 
 # Collecting Wikipedia Training Data
 We recommend following the wikipedia data extraction process specified by google research: "the recommended pre-processing is to download [the latest dump](https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2), extract the text with [WikiExtractor.py](https://github.com/attardi/wikiextractor), and then apply any necessary cleanup to convert it into plain text."
 
-We recommend using the `--json` argument when using WikiExtractor, which will dump the wikipedia data into loose json format (one json per line), making it more manageable and readily consumable by our codebase.
+We recommend using the `--json` argument when using WikiExtractor, which will dump the wikipedia data into loose json format (one json per line), making it more manageable and readily consumable by our codebase. We recommend further preprocessing this json dataset with nltk punctuation standardization and by presplitting each document into newline-separated sentences. This can be done with the provided script `./scripts/presplit_sentences_json.py` and allows for faster data processing at training time. Pretraining with presplit data should be run with the `--presplit-sentences` flag as shown above.
 
 Once the json dataset is ready make sure to set the path in line 27 of `data_utils/corpora.py`.
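
The preprocessing described in the README hunk above amounts to reading each loose-json line, sentence-splitting its text field, and writing the sentences back newline-separated. The sketch below only illustrates that output format with nltk; it is not the provided `./scripts/presplit_sentences_json.py` (which should be preferred), and the file paths are placeholders.

    # Minimal sketch of presplitting loose-json documents into newline-separated
    # sentences, so that --presplit-sentences can skip sentence splitting at
    # training time. Paths and the 'text' key (cf. --text-key) are illustrative.
    import json
    import nltk

    nltk.download('punkt')

    with open('wikipedia.loose.json') as reader, open('wikipedia.presplit.json', 'w') as writer:
        for row in reader:
            document = json.loads(row)
            sentences = nltk.tokenize.sent_tokenize(document['text'])
            document['text'] = '\n'.join(sentences)
            writer.write(json.dumps(document) + '\n')
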
diff --git a/arguments.py b/arguments.py
index d7d554e6331..cfe3a85c108 100644
--- a/arguments.py
+++ b/arguments.py
@@ -184,6 +184,9 @@ def add_data_args(parser):
 
     group = parser.add_argument_group('data', 'data configurations')
 
+    group.add_argument('--shuffle', action='store_true',
+                       help='Shuffle data. Shuffling is deterministic '
+                            'based on seed and current epoch.')
     group.add_argument('--train-data', nargs='+', required=True,
                        help='Filename (or whitespace separated filenames) '
                        'for training.')
@@ -208,6 +211,9 @@ def add_data_args(parser):
                        help='Use loose json (one json-formatted string per '
                        'newline), instead of tight json (data file is one '
                        'json string)')
+    group.add_argument('--presplit-sentences', action='store_true',
+                       help='Dataset content consists of documents where '
+                            'each document consists of newline separated sentences')
     group.add_argument('--num-workers', type=int, default=2,
                        help="""Number of workers to use for dataloading""")
     group.add_argument('--tokenizer-model-type', type=str,
diff --git a/configure_data.py b/configure_data.py
index fa1dd92baae..0c2ea7e1a60 100644
--- a/configure_data.py
+++ b/configure_data.py
@@ -46,7 +46,7 @@ def make_data_loader(dataset, batch_size, args):
 
     shuffle = args.shuffle
     if shuffle:
-        sampler = torch.utils.data.RandomSampler(dataset)
+        sampler = data_utils.samplers.RandomSampler(dataset, replacement=True, num_samples=batch_size*args.train_iters)
     else:
         sampler = torch.utils.data.SequentialSampler(dataset)
     world_size = args.world_size
@@ -81,8 +81,10 @@ def make_tfrecord_loaders(args):
                      'max_seq_len': args.seq_length,
                      'max_preds_per_seq': args.max_preds_per_seq,
                      'train': True,
-                     'num_workers': args.num_workers,
-                     'seed': args.seed+args.rank+1}
+                     'num_workers': max(args.num_workers, 1),
+                     'seed': args.seed + args.rank + 1,
+                     'threaded_dl': args.num_workers > 0
+                     }
     train = data_utils.tf_dl.TFRecordDataLoader(args.train_data,
                                                 **data_set_args)
     data_set_args['train'] = False
@@ -140,7 +142,8 @@ def make_loaders(args):
                      'vocab_size': args.vocab_size,
                      'model_type': args.tokenizer_model_type,
                      'cache_dir': args.cache_dir,
-                     'max_preds_per_seq': args.max_preds_per_seq}
+                     'max_preds_per_seq': args.max_preds_per_seq,
+                     'presplit_sentences': args.presplit_sentences}
     eval_set_args = copy.copy(data_set_args)
     eval_set_args['split'] = [1.]
 
@@ -218,7 +221,6 @@ def configure_data():
                 'rank': -1,
                 'persist_state': 0,
                 'lazy': False,
-                'shuffle': False,
                 'transpose': False,
                 'data_set_type': 'supervised',
                 'seq_length': 256,
diff --git a/data_utils/__init__.py b/data_utils/__init__.py
index 7a60f97b32d..d58622c58a1 100644
--- a/data_utils/__init__.py
+++ b/data_utils/__init__.py
@@ -46,7 +46,7 @@ def get_dataset(path, **kwargs):
     if supported_corpus(path):
         return corpora.NAMED_CORPORA[path](**kwargs)
     ext = get_ext(path)
-    if ext =='.json':
+    if '.json' in ext:
         text = json_dataset(path, **kwargs)
     elif ext in ['.csv', '.tsv']:
         text = csv_dataset(path, **kwargs)
@@ -108,8 +108,10 @@ def get_dataset_from_path(path_):
     if should_split(split):
         ds = split_ds(ds, split)
         if ds_type.lower() == 'bert':
-            ds = [bert_sentencepair_dataset(d, max_seq_len=seq_length) for d in ds]
+            presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
+            ds = [bert_sentencepair_dataset(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) for d in ds]
     else:
         if ds_type.lower() == 'bert':
-            ds = bert_sentencepair_dataset(ds, max_seq_len=seq_length)
+            presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
+            ds = bert_sentencepair_dataset(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
     return ds, tokenizer
diff --git a/data_utils/datasets.py b/data_utils/datasets.py
index 88c2a1ccd5d..7eaa2bbfb66 100644
--- a/data_utils/datasets.py
+++ b/data_utils/datasets.py
@@ -449,7 +449,7 @@ class bert_sentencepair_dataset(data.Dataset):
         dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)
 
     """
-    def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, **kwargs):
+    def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, presplit_sentences=False, **kwargs):
         self.ds = ds
         self.ds_len = len(self.ds)
         self.tokenizer = self.ds.GetTokenizer()
@@ -464,6 +464,7 @@ def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None
         self.dataset_size = dataset_size
         if self.dataset_size is None:
             self.dataset_size = self.ds_len * (self.ds_len-1)
+        self.presplit_sentences = presplit_sentences
 
     def __len__(self):
         return self.dataset_size
@@ -494,7 +495,14 @@ def __getitem__(self, idx):
 
     def sentence_split(self, document):
         """split document into sentences"""
-        return tokenize.sent_tokenize(document)
+        lines = document.split('\n')
+        if self.presplit_sentences:
+            return [line for line in lines if line]
+        rtn = []
+        for line in lines:
+            if line != '':
+                rtn.extend(tokenize.sent_tokenize(line))
+        return rtn
 
     def sentence_tokenize(self, sent, sentence_num=0, beginning=False, ending=False):
         """tokenize sentence and get token types"""
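
To make the two `sentence_split` code paths above concrete, here is the same logic rendered as a standalone function (mirroring the method body, not imported from the repo). It assumes nltk's punkt models are installed.

    # Standalone illustration of sentence_split: presplit input is trusted as-is,
    # otherwise each line is run through nltk sentence splitting.
    from nltk import tokenize

    def sentence_split(document, presplit_sentences=False):
        lines = document.split('\n')
        if presplit_sentences:
            # presplit data: every non-empty line is treated as one sentence
            return [line for line in lines if line]
        rtn = []
        for line in lines:
            if line != '':
                rtn.extend(tokenize.sent_tokenize(line))
        return rtn

    doc = "First sentence. Second sentence.\nThird sentence."
    print(sentence_split(doc))                           # 3 sentences via nltk
    print(sentence_split(doc, presplit_sentences=True))  # 2 items, one per line
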
diff --git a/data_utils/samplers.py b/data_utils/samplers.py
index 4e086905dcc..2e34ff9fe85 100644
--- a/data_utils/samplers.py
+++ b/data_utils/samplers.py
@@ -21,6 +21,57 @@
 from torch.utils import data
 import numpy as np
 
+class RandomSampler(data.sampler.Sampler):
+    r"""
+    Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler,
+    but this class lets the user set an epoch like DistributedSampler.
+    Samples elements randomly. If without replacement, then sample from a shuffled dataset.
+    If with replacement, then user can specify ``num_samples`` to draw.
+    Arguments:
+        data_source (Dataset): dataset to sample from
+        num_samples (int): number of samples to draw, default=len(dataset)
+        replacement (bool): samples are drawn with replacement if ``True``, default=False
+    """
+
+    def __init__(self, data_source, replacement=False, num_samples=None):
+        self.data_source = data_source
+        self.replacement = replacement
+        self._num_samples = num_samples
+        self.epoch = -1
+
+        if self._num_samples is not None and replacement is False:
+            raise ValueError("With replacement=False, num_samples should not be specified, "
+                             "since a random permute will be performed.")
+
+        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
+            raise ValueError("num_samples should be a positive integer "
+                             "value, but got num_samples={}".format(self.num_samples))
+        if not isinstance(self.replacement, bool):
+            raise ValueError("replacement should be a boolean value, but got "
+                             "replacement={}".format(self.replacement))
+
+    @property
+    def num_samples(self):
+        # dataset size might change at runtime
+        if self._num_samples is None:
+            return len(self.data_source)
+        return self._num_samples
+
+    def __iter__(self):
+        n = len(self.data_source)
+        g = torch.Generator()
+        if self.epoch >= 0:
+            g.manual_seed(self.epoch)
+        if self.replacement:
+            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64, generator=g).tolist())
+        return iter(torch.randperm(n, generator=g).tolist())
+
+    def __len__(self):
+        return self.num_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
 class DistributedBatchSampler(data.sampler.BatchSampler):
     """
     similar to normal implementation of distributed sampler, except implementation is at the
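
A small usage sketch of the sampler added above (illustrative values, not part of the change): `num_samples` only applies when sampling with replacement, which is how `configure_data.make_data_loader` uses it (with `num_samples=batch_size * args.train_iters`, so one pass through the loader spans the whole training schedule).

    # Usage sketch for data_utils.samplers.RandomSampler (illustrative values).
    from data_utils.samplers import RandomSampler

    dataset = list(range(10))                 # anything with a __len__ works here

    # With replacement, num_samples controls how many indices one pass yields.
    sampler = RandomSampler(dataset, replacement=True, num_samples=25)
    assert len(sampler) == 25

    # Without replacement, num_samples must stay unset; one full permutation is drawn.
    perm = RandomSampler(dataset)
    assert len(perm) == len(dataset)
    assert sorted(iter(perm)) == list(range(10))
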
"""PyTorch DataLoader for TFRecords""" +import queue +import threading + import tensorflow as tf tf.enable_eager_execution() import torch +import numpy as np class TFRecordDataLoader(object): - def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq, train, num_workers=2, seed=1): + def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq, train, num_workers=2, seed=1, threaded_dl=False): assert max_preds_per_seq is not None, "--max-preds-per-seq MUST BE SPECIFIED when using tfrecords" tf.set_random_seed(seed) if isinstance(records, str): @@ -55,11 +59,18 @@ def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq, train, n 'num_parallel_batches': num_workers, 'drop_remainder': train} self.dataloader = self.dataset.apply(tf.contrib.data.map_and_batch(self.record_converter, **loader_args)) + self.threaded_dl = threaded_dl + self.num_workers = num_workers def __iter__(self): - data_iter = iter(self.dataloader) - for item in data_iter: - yield convert_tf_example_to_torch_tensors(item) + if self.threaded_dl: + data_iter = iter(MultiprocessLoader(self.dataloader, self.num_workers)) + for item in data_iter: + yield item + else: + data_iter = iter(self.dataloader) + for item in data_iter: + yield convert_tf_example_to_torch_tensors(item) class Record2Example(object): def __init__(self, feature_map): @@ -74,14 +85,37 @@ def __call__(self, record): return example def convert_tf_example_to_torch_tensors(example): - item = {k: torch.from_numpy(v.numpy()) for k,v in example.items()} - mask = torch.zeros_like(item['input_ids']) - mask_labels = torch.ones_like(item['input_ids'])*-1 - for b, row in enumerate(item['masked_lm_positions'].long()): + item = {k: (v.numpy()) for k,v in example.items()} + mask = np.zeros_like(item['input_ids']) + mask_labels = np.ones_like(item['input_ids'])*-1 + for b, row in enumerate(item['masked_lm_positions'].astype(int)): for i, idx in enumerate(row): if item['masked_lm_weights'][b, i] != 0: mask[b, idx] = 1 mask_labels[b, idx] = item['masked_lm_ids'][b, i] - return {'text': item['input_ids'], 'types': item['segment_ids'],'is_random': item['next_sentence_labels'], - 'pad_mask': 1-item['input_mask'], 'mask': mask, 'mask_labels': mask_labels} + output = {'text': item['input_ids'], 'types': item['segment_ids'],'is_random': item['next_sentence_labels'], + 'pad_mask': 1-item['input_mask'], 'mask': mask, 'mask_labels': mask_labels} + return {k: torch.from_numpy(v) for k,v in output.items()} + +class MultiprocessLoader(object): + def __init__(self, dataloader, num_workers=2): + self.dl = dataloader + self.queue_size = 2*num_workers + + def __iter__(self): + output_queue = queue.Queue(self.queue_size) + output_thread = threading.Thread(target=_multiproc_iter, + args=(self.dl, output_queue)) + output_thread.daemon = True + output_thread.start() + + while output_thread.is_alive(): + yield output_queue.get(block=True) + else: + print(RuntimeError('TF record data loader thread exited unexpectedly')) +def _multiproc_iter(dl, output_queue): + data_iter = iter(dl) + for item in data_iter: + tensors = convert_tf_example_to_torch_tensors(item) + output_queue.put(tensors, block=True) \ No newline at end of file diff --git a/pretrain_bert.py b/pretrain_bert.py index 3100e789336..8779dd96d39 100755 --- a/pretrain_bert.py +++ b/pretrain_bert.py @@ -434,6 +434,8 @@ def main(): train_data.batch_sampler.start_iter = total_iters % len(train_data) # For all epochs. 
diff --git a/pretrain_bert.py b/pretrain_bert.py
index 3100e789336..8779dd96d39 100755
--- a/pretrain_bert.py
+++ b/pretrain_bert.py
@@ -434,6 +434,8 @@ def main():
             train_data.batch_sampler.start_iter = total_iters % len(train_data)
         # For all epochs.
         for epoch in range(start_epoch, args.epochs+1):
+            if args.shuffle:
+                train_data.batch_sampler.sampler.set_epoch(epoch+args.seed)
             timers('epoch time').start()
             iteration, skipped = train_epoch(epoch, model, optimizer,
                                              train_data, lr_scheduler,
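
Together with the sampler above, these two lines are what make `--shuffle` deterministic: the sampler is re-seeded with `epoch + args.seed` at the top of every epoch, so repeating or resuming an epoch with the same `--seed` replays the same data order. A rough sketch of that property (simplified; not the training script):

    # Rough sketch: the data order for an epoch is a pure function of (epoch + seed).
    from data_utils.samplers import RandomSampler

    def epoch_order(seed, epoch, dataset_len=16, num_samples=8):
        sampler = RandomSampler(range(dataset_len), replacement=True,
                                num_samples=num_samples)
        sampler.set_epoch(epoch + seed)             # mirrors the call in pretrain_bert.py
        return list(iter(sampler))

    assert epoch_order(seed=1234, epoch=2) == epoch_order(seed=1234, epoch=2)
    print(epoch_order(seed=1234, epoch=2))          # stable across runs and resumes
    print(epoch_order(seed=1234, epoch=3))          # next epoch, different order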