diff --git a/trax/data/__init__.py b/trax/data/__init__.py
index ec67dc58c..ed126c885 100644
--- a/trax/data/__init__.py
+++ b/trax/data/__init__.py
@@ -14,6 +14,32 @@
 # limitations under the License.
 """Data imports in Trax."""
+import gin
 from trax.data import inputs
 from trax.data import tf_inputs
+
+
+# Ginify
+def data_configure(*args, **kwargs):
+  kwargs['module'] = 'trax.data'
+  return gin.external_configurable(*args, **kwargs)
+
+
+# pylint: disable=invalid-name
+AddLossWeights = data_configure(inputs.AddLossWeights)
+add_loss_weights = inputs.add_loss_weights
+Batch = data_configure(inputs.Batch)
+batch = inputs.batch
+BucketByLength = data_configure(inputs.BucketByLength)
+bucket_by_length = inputs.bucket_by_length
+FilterByLength = data_configure(inputs.FilterByLength)
+Log = data_configure(inputs.Log)
+Serial = data_configure(inputs.Serial)
+Shuffle = data_configure(inputs.Shuffle)
+shuffle = inputs.shuffle
+TFDS = data_configure(tf_inputs.TFDS)
+Tokenize = data_configure(tf_inputs.Tokenize)
+tokenize = tf_inputs.tokenize
+detokenize = tf_inputs.detokenize
+vocab_size = tf_inputs.vocab_size
diff --git a/trax/data/inputs.py b/trax/data/inputs.py
index a03ce9705..a50b193f4 100644
--- a/trax/data/inputs.py
+++ b/trax/data/inputs.py
@@ -14,7 +14,59 @@
 # limitations under the License.
 # Lint as: python3
-"""Trax input pipeline."""
+"""Trax input pipeline.
+
+In Trax we encourage the use of combinators to construct input pipelines, in
+a way that resembles layer combinators. Here is an example of an input
+pipeline for training a sentiment analysis model on the IMDB dataset::
+
+  inputs = data.Serial(
+    data.TFDS('imdb_reviews', keys=('text', 'label'), train=True),
+    data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
+    data.Shuffle(),
+    data.FilterByLength(max_length=2048, length_keys=[0]),
+    data.BucketByLength(boundaries=[ 32, 128, 512, 2048],
+                        batch_sizes=[128, 32, 8, 2, 1],
+                        length_keys=[0]),
+    data.AddLossWeights()
+  )
+
+Each of these combinators creates a python generator of tuples of data
+examples. For example::
+
+  data.TFDS('imdb_reviews', keys=('text', 'label'), train=True),
+
+creates a generator of examples from the TFDS imdb_reviews dataset; see here:
+https://www.tensorflow.org/datasets/catalog/imdb_reviews
+
+As you can see on the website above, this dataset has 'text' and 'label'
+fields and we create tuples containing the text and the label from the
+training split by specifying keys=('text', 'label'), train=True.
+
+The other combinators, like Tokenize and Shuffle, take a generator and output
+another generator; in this way they convert tuples into other tuples or
+shuffle the training stream. For example, Tokenize(..., keys=[0]) will
+tokenize the first element of the tuple, converting it from text to a tensor
+of integers. Shuffle will not change the examples, but will randomize their
+order.
+
+Note that all elements in the data pipeline are just functions on generators,
+so you can use python's `map` and `filter` and other native functions too.
+For example, you can create an input pipeline for a language model reading
+lines from `my_file.txt` as follows::
+
+  inputs = data.Serial(
+    lambda _: open('my_file.txt'),
+    lambda g: map(lambda line: line.strip(), g),
+    data.Tokenize(vocab_file='en_8k.subword'),
+    lambda g: filter(lambda x: x.shape[0] < 513, g),  # At most 512 tokens.
+    data.Shuffle(),
+    lambda g: map(lambda x: (x, x), g),  # Language models have inputs = targets.
+ data.BucketByLength(boundaries=[ 32, 64, 128, 256, 512], + batch_sizes=[ 32, 16, 8, 4, 2, 1]), + data.AddLossWeights(id_to_mask=0) + ) + +""" import math import random @@ -24,189 +76,36 @@ import gin import numpy as np +from trax import fastmath +from trax import shapes from trax.fastmath import numpy as jnp -class Inputs(object): - """Inputs bundle. - - Inputs bundle holds input streams and shapes for a training run. - It contains stream-creating functions that return python generators - of (input_batch, target_batch) tuples. - - * train_stream: training data that will be used for training - may include all the augmentation or selection the training wants - the shape of examples is [batch_fn.batch_size, ...] - * train_eval_stream: training data used for evaluation - examples from training data but usually without augmentation - the shape of examples is [batch_fn.eval_batch_size, ...] - * eval_stream: evaluation data stream - examples from evaluation data, usually without augmentation - the shape of examples is [batch_fn.eval_batch_size, ...] - * input_shape: the shape of inputs - the [...] above, without batch size - * input_dtype: the data type of inputs - * target_shape: the shape of targets - the [...] above, without batch size - * target_dtype: the data type of targets - """ - - def __init__(self, train_stream, eval_stream=None, train_eval_stream=None): - """Initialize a new set of inputs. - - Args: - train_stream: a function taking n_devices (an int) and returning - a python generator of training batches. - eval_stream: a function taking n_devices (an int) and returning - a python generator of validation batches; - if None, then the training generator will be used for evaluation. - train_eval_stream: a function taking n_devices (an int) and returning - a python generator of batches from - the training set used for evaluation (if None, use train_stream). - """ - if not callable(train_stream): - raise ValueError('Trax Inputs should be initialized with a function. ' - 'Did you forget the n_devices argument? If your inputs ' - 'do not use it, try lambda _: [your-inputs].') - - self._train_stream = train_stream - self._eval_stream = eval_stream or self._train_stream - - # TODO(lukaszkaiser): should we get rid of this one day? - self._train_eval_stream = train_eval_stream or self._train_stream - - # Peek into the train stream to get an example shape. 
- example_train_batch = next(train_stream(1)) - self._input_shape = tuple(example_train_batch[0].shape)[1:] - self._input_dtype = example_train_batch[0].dtype - self._target_shape = tuple(example_train_batch[-1].shape)[1:] - self._target_dtype = example_train_batch[-1].dtype - self._example_shape = [x.shape for x in example_train_batch] - self._example_dtype = [x.dtype for x in example_train_batch] - - def train_stream(self, n_devices): - return self._train_stream(n_devices) - - def eval_stream(self, n_devices): - return self._eval_stream(n_devices) - - def train_eval_stream(self, n_devices): - return self._train_stream(n_devices) - - @property - def input_shape(self): - """Example input shape, without batch dimension.""" - return self._input_shape - - @property - def target_shape(self): - """Example target shape, without batch dimension.""" - return self._target_shape - - @property - def input_dtype(self): - """Dtype of the input.""" - return self._input_dtype - - @property - def target_dtype(self): - """Dtype of the target.""" - return self._target_dtype - - @property - def example_shape_dtype(self): - """Shape and Dtype of an example batch.""" - return self._example_shape, self._example_dtype - - -# Batching and input pipeline creation helpers. - - -@gin.configurable() -def batcher(data_streams=gin.REQUIRED, variable_shapes=True, - batch_size_per_device=32, batch_size=None, eval_batch_size=32, - bucket_length=32, buckets=None, - buckets_include_inputs_in_length=False, - batch_shuffle_size=None, max_eval_length=None, - # TODO(afrozm): Unify padding logic. - id_to_mask=None, strict_pad_on_len=False): - """Batcher: create trax Inputs from single-example data-streams.""" - # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming. - # For now leaving the arguments as in batch_fn to reduce gin config changes. - if callable(data_streams): # If we pass a function, e.g., through gin, call. - train_stream, eval_stream = data_streams() - else: - train_stream, eval_stream = data_streams - # pylint: disable=g-long-lambda - batch_train_stream = lambda n_devices: batch_fn( - train_stream(), True, n_devices, variable_shapes, - batch_size_per_device, batch_size, eval_batch_size, - bucket_length, buckets, buckets_include_inputs_in_length, - batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len) - batch_eval_stream = lambda n_devices: batch_fn( - eval_stream(), False, n_devices, variable_shapes, - batch_size_per_device, batch_size, eval_batch_size, - bucket_length, buckets, buckets_include_inputs_in_length, - batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len) - batch_train_eval_stream = lambda n_devices: batch_fn( - train_stream(), False, n_devices, variable_shapes, - batch_size_per_device, batch_size, eval_batch_size, - bucket_length, buckets, buckets_include_inputs_in_length, - batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len) - # pylint: enable=g-long-lambda - return Inputs(train_stream=batch_train_stream, - eval_stream=batch_eval_stream, - train_eval_stream=batch_train_eval_stream) - - -def batch_fn(dataset, training, n_devices, variable_shapes, - batch_size_per_device=32, batch_size=None, eval_batch_size=32, - bucket_length=32, buckets=None, - buckets_include_inputs_in_length=False, - batch_shuffle_size=None, max_eval_length=None, - id_to_mask=None, strict_pad_on_len=False): - """Batching function.""" - # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming. 
- # After that, create a proper doc-string; we may also not need to pass both - # training and eval arguments here, as batcher calls the function separately - # now and it's not under gin-config any more -- consider reducing args. - batch_size = batch_size or batch_size_per_device * n_devices - # If bucketing is not specified, check if target shapes are variable. - cur_batch_size = batch_size if training else eval_batch_size - # Make cur_batch_size divisible by n_devices. - cur_batch_size = max(cur_batch_size // n_devices, 1) * n_devices - # Create heuristic buckets if none are specified. - if buckets is None: - logging.info('Heuristically setting bucketing to %s based on shapes ' - 'of target tensors.', variable_shapes) - if variable_shapes: - buckets = _buckets_for_length( - bucket_length, cur_batch_size, max_eval_length, n_devices, training) - - if buckets: - logging.info('Bucketing with buckets %s.', str(buckets)) - def example_length(x): - """The length function used by bucket_by_sequence_length to bucket.""" - # The input x is a tuple to go on the stack, typically either - # (input, target) or (input, target, mask). - example_inputs, target = x[0], x[1] - # Length is the shape of axis 0 here (no batch yet). - other_length = 0 # We include input length only if asked. - if buckets_include_inputs_in_length: - other_length = example_inputs.shape[0] - return max(target.shape[0], other_length) - boundaries, batch_sizes = buckets - dataset = bucket_by_length( - dataset, example_length, boundaries, batch_sizes, strict_pad_on_len) - else: - logging.info('Not Bucketing cur_batch_size %d.', cur_batch_size) - dataset = batch_data(dataset, cur_batch_size) - if training and batch_shuffle_size is not None: - dataset = shuffle_data(dataset, batch_shuffle_size) - return add_loss_weights(dataset, id_to_mask) +def Serial(*fns): # pylint: disable=invalid-name + """Creates an input pipeline by running all functions one after another.""" + generator = None + for f in fastmath.tree_flatten(fns): + generator = f(generator) + return generator + + +def Log(n_steps_per_example=1, only_shapes=True): # pylint: disable=invalid-name + """Creates a logging component of the input pipeline.""" + def log(stream): + counter = 0 + for example in stream: + item_to_log = example + if only_shapes: + item_to_log = fastmath.nested_map(shapes.signature, example) + if counter % n_steps_per_example == 0: + logging.info(str(item_to_log)) + print(item_to_log) + counter += 1 + yield example + return log -def shuffle_data(samples, queue_size): +def shuffle(samples, queue_size): """Shuffles a sample stream using a random-out next-in queue of given size. Args: @@ -246,7 +145,12 @@ def shuffle_data(samples, queue_size): yield sample -def batch_data(generator, batch_size): +def Shuffle(queue_size=1024): # pylint: disable=invalid-name + """Returns a shuffle function with the given queue size.""" + return lambda g: shuffle(g, queue_size) + + +def batch(generator, batch_size): """Batch and pad generator as in tf.data.Dataset.padded_batch.""" buf = [] # TODO(lukaszkaiser): convert to ValueError @@ -256,12 +160,17 @@ def batch_data(generator, batch_size): if len(buf) == batch_size: # buf is a list of tuples, e.g., [(in1, tgt1), (in2, tgt2), (in3, tgt3)] # batch is a tuple of arrays: ([in1, in2, in3], [tgt1, tgt2, tgt3]) - batch = tuple(np.stack(x) for x in zip(*buf)) + batched_example = tuple(np.stack(x) for x in zip(*buf)) # Note that it's the same shape as each example with added batch dim. 
- yield batch + yield batched_example buf = [] +def Batch(batch_size): # pylint: disable=invalid-name + """Returns a batching function with given batch size.""" + return lambda g: batch(g, batch_size) + + def pad_to_max_dims(tensors, boundary=None, strict_pad_on_len=False): """Pad a tuple of tensors to a joint dimension and return their batch. @@ -323,7 +232,7 @@ def pad_to_max_dims(tensors, boundary=None, strict_pad_on_len=False): for i in range(dim): max_len = max([t.shape[i] for t in tensors]) min_len = min([t.shape[i] for t in tensors]) - if max_len == min_len: # No padding needed. + if max_len == min_len and max_len == boundary: # No padding needed. max_len_to_pad.append(max_len) elif boundary is None: max_len_to_pad.append(max_len) @@ -377,8 +286,6 @@ def _buckets_for_length(bucket_length, batch_size, max_eval_length, n_devices, [max_eval_length] ) bucket_boundaries.append(max_eval_length) - # We will pad to boundaries which pads to bucket_boundary - 1: add 1 here. - bucket_boundaries = [b + 1 for b in bucket_boundaries] bucket_batch_sizes = [batch_size * 4, batch_size * 2, batch_size, batch_size // 2, batch_size // 4, batch_size // 8, @@ -425,18 +332,47 @@ def bucket_by_length(generator, length_fn, boundaries, batch_sizes, length = length_fn(example) # `bucket_idx` will always be < len(boundaries), since boundaries is right # padded by `math.inf`. - bucket_idx = min([i for i, b in enumerate(boundaries) if length < b]) + bucket_idx = min([i for i, b in enumerate(boundaries) if length <= b]) buckets[bucket_idx].append(example) if len(buckets[bucket_idx]) == batch_sizes[bucket_idx]: - batch = zip(*buckets[bucket_idx]) - boundary = boundaries[bucket_idx] - 1 + batched = zip(*buckets[bucket_idx]) + boundary = boundaries[bucket_idx] boundary = None if boundary == math.inf else boundary padded_batch = tuple( - pad_to_max_dims(x, boundary, strict_pad_on_len) for x in batch) + pad_to_max_dims(x, boundary, strict_pad_on_len) for x in batched) yield padded_batch buckets[bucket_idx] = [] +def _length_fn(example, length_axis, length_keys): + """Length is the maximum of shape on length_axis over length_keys.""" + if isinstance(example, (list, tuple)): + return max([example[i].shape[length_axis] for i in length_keys]) + return example.shape[length_axis] + + +def BucketByLength(boundaries, batch_sizes, # pylint: disable=invalid-name + length_keys=None, length_axis=0, strict_pad_on_len=False): + """Returns a function for bucketing inputs, see `bucket_by_length`.""" + length_keys = length_keys or [0, 1] + # In all cases so far, we use a length function of the following form. + length_fn = lambda x: _length_fn(x, length_axis, length_keys) + return lambda g: bucket_by_length( # pylint: disable=g-long-lambda + g, length_fn, boundaries, batch_sizes, strict_pad_on_len) + + +def FilterByLength(max_length, # pylint: disable=invalid-name + length_keys=None, length_axis=0): + """Returns a function that filters out examples longer than `max_length`.""" + length_keys = length_keys or [0, 1] + length_fn = lambda x: _length_fn(x, length_axis, length_keys) + def filtered(gen): + for example in gen: + if length_fn(example) <= max_length: + yield example + return filtered + + def add_loss_weights(generator, id_to_mask=None): """Add weights to inputs without weights and masks by id if requested. 
@@ -470,6 +406,206 @@ def add_loss_weights(generator, id_to_mask=None): yield (example[0], example[1], weights) +def AddLossWeights(id_to_mask=None): # pylint: disable=invalid-name + """Returns a function to add loss weights; see `add_loss_weights`.""" + return lambda g: add_loss_weights(g, id_to_mask=id_to_mask) + + +# Inputs class used for setting up Trainer. +# Note: as we move from Trainer to Loop this class may become obsolete. + + +class Inputs(object): + """Inputs bundle. + + Inputs bundle holds input streams and shapes for a training run. + It contains stream-creating functions that return python generators + of (input_batch, target_batch) tuples. + + * train_stream: training data that will be used for training + may include all the augmentation or selection the training wants + the shape of examples is [batch_fn.batch_size, ...] + * train_eval_stream: training data used for evaluation + examples from training data but usually without augmentation + the shape of examples is [batch_fn.eval_batch_size, ...] + * eval_stream: evaluation data stream + examples from evaluation data, usually without augmentation + the shape of examples is [batch_fn.eval_batch_size, ...] + * input_shape: the shape of inputs + the [...] above, without batch size + * input_dtype: the data type of inputs + * target_shape: the shape of targets + the [...] above, without batch size + * target_dtype: the data type of targets + """ + + def __init__(self, train_stream, eval_stream=None, train_eval_stream=None): + """Initialize a new set of inputs. + + Args: + train_stream: a function taking n_devices (an int) and returning + a python generator of training batches. + eval_stream: a function taking n_devices (an int) and returning + a python generator of validation batches; + if None, then the training generator will be used for evaluation. + train_eval_stream: a function taking n_devices (an int) and returning + a python generator of batches from + the training set used for evaluation (if None, use train_stream). + """ + if not callable(train_stream): + raise ValueError('Trax Inputs should be initialized with a function. ' + 'Did you forget the n_devices argument? If your inputs ' + 'do not use it, try lambda _: [your-inputs].') + + self._train_stream = train_stream + self._eval_stream = eval_stream or self._train_stream + + # TODO(lukaszkaiser): should we get rid of this one day? + self._train_eval_stream = train_eval_stream or self._train_stream + + # Peek into the train stream to get an example shape. 
+ example_train_batch = next(train_stream(1)) + self._input_shape = tuple(example_train_batch[0].shape)[1:] + self._input_dtype = example_train_batch[0].dtype + self._target_shape = tuple(example_train_batch[-1].shape)[1:] + self._target_dtype = example_train_batch[-1].dtype + self._example_shape = [x.shape for x in example_train_batch] + self._example_dtype = [x.dtype for x in example_train_batch] + + def train_stream(self, n_devices): + return self._train_stream(n_devices) + + def eval_stream(self, n_devices): + return self._eval_stream(n_devices) + + def train_eval_stream(self, n_devices): + return self._train_stream(n_devices) + + @property + def input_shape(self): + """Example input shape, without batch dimension.""" + return self._input_shape + + @property + def target_shape(self): + """Example target shape, without batch dimension.""" + return self._target_shape + + @property + def input_dtype(self): + """Dtype of the input.""" + return self._input_dtype + + @property + def target_dtype(self): + """Dtype of the target.""" + return self._target_dtype + + @property + def example_shape_dtype(self): + """Shape and Dtype of an example batch.""" + return self._example_shape, self._example_dtype + + +# Batching and Inputs creation helpers. + + +@gin.configurable() +def make_inputs(train_stream=gin.REQUIRED, eval_stream=None): + """Create Inputs from two streams; mostly for use in gin configs.""" + if isinstance(train_stream, (list, tuple)): + train_stream = Serial(train_stream) + if isinstance(eval_stream, (list, tuple)): + eval_stream = Serial(eval_stream) + eval_stream_fn = None if eval_stream is None else lambda _: eval_stream + return Inputs(train_stream=lambda _: train_stream, + eval_stream=eval_stream_fn) + + +@gin.configurable() +def batcher(data_streams=gin.REQUIRED, variable_shapes=True, + batch_size_per_device=32, batch_size=None, eval_batch_size=32, + bucket_length=32, buckets=None, + buckets_include_inputs_in_length=False, + batch_shuffle_size=None, max_eval_length=None, + # TODO(afrozm): Unify padding logic. + id_to_mask=None, strict_pad_on_len=False): + """Batcher: create trax Inputs from single-example data-streams.""" + # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming. + # For now leaving the arguments as in batch_fn to reduce gin config changes. + if callable(data_streams): # If we pass a function, e.g., through gin, call. 
+ train_stream, eval_stream = data_streams() + else: + train_stream, eval_stream = data_streams + # pylint: disable=g-long-lambda + batch_train_stream = lambda n_devices: batch_fn( + train_stream(), True, n_devices, variable_shapes, + batch_size_per_device, batch_size, eval_batch_size, + bucket_length, buckets, buckets_include_inputs_in_length, + batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len) + batch_eval_stream = lambda n_devices: batch_fn( + eval_stream(), False, n_devices, variable_shapes, + batch_size_per_device, batch_size, eval_batch_size, + bucket_length, buckets, buckets_include_inputs_in_length, + batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len) + batch_train_eval_stream = lambda n_devices: batch_fn( + train_stream(), False, n_devices, variable_shapes, + batch_size_per_device, batch_size, eval_batch_size, + bucket_length, buckets, buckets_include_inputs_in_length, + batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len) + # pylint: enable=g-long-lambda + return Inputs(train_stream=batch_train_stream, + eval_stream=batch_eval_stream, + train_eval_stream=batch_train_eval_stream) + + +def batch_fn(dataset, training, n_devices, variable_shapes, + batch_size_per_device=32, batch_size=None, eval_batch_size=32, + bucket_length=32, buckets=None, + buckets_include_inputs_in_length=False, + batch_shuffle_size=None, max_eval_length=None, + id_to_mask=None, strict_pad_on_len=False): + """Batching function.""" + # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming. + # After that, create a proper doc-string; we may also not need to pass both + # training and eval arguments here, as batcher calls the function separately + # now and it's not under gin-config any more -- consider reducing args. + batch_size = batch_size or batch_size_per_device * n_devices + # If bucketing is not specified, check if target shapes are variable. + cur_batch_size = batch_size if training else eval_batch_size + # Make cur_batch_size divisible by n_devices. + cur_batch_size = max(cur_batch_size // n_devices, 1) * n_devices + # Create heuristic buckets if none are specified. + if buckets is None: + logging.info('Heuristically setting bucketing to %s based on shapes ' + 'of target tensors.', variable_shapes) + if variable_shapes: + buckets = _buckets_for_length( + bucket_length, cur_batch_size, max_eval_length, n_devices, training) + + if buckets: + logging.info('Bucketing with buckets %s.', str(buckets)) + def example_length(x): + """The length function used by bucket_by_sequence_length to bucket.""" + # The input x is a tuple to go on the stack, typically either + # (input, target) or (input, target, mask). + example_inputs, target = x[0], x[1] + # Length is the shape of axis 0 here (no batch yet). + other_length = 0 # We include input length only if asked. + if buckets_include_inputs_in_length: + other_length = example_inputs.shape[0] + return max(target.shape[0], other_length) + boundaries, batch_sizes = buckets + dataset = bucket_by_length( + dataset, example_length, boundaries, batch_sizes, strict_pad_on_len) + else: + logging.info('Not Bucketing cur_batch_size %d.', cur_batch_size) + dataset = batch(dataset, cur_batch_size) + if training and batch_shuffle_size is not None: + dataset = shuffle(dataset, batch_shuffle_size) + return add_loss_weights(dataset, id_to_mask) + + # Example input functions. 
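As a quick illustration of how the new combinators above compose, here is a minimal sketch that runs a pipeline on a synthetic generator; the `synthetic_examples` helper is hypothetical and merely stands in for a real source such as `data.TFDS(...)`::

    import numpy as np

    from trax import data


    def synthetic_examples(_):
      # Toy stream of (input, target) pairs of varying lengths. Serial passes
      # None to the first function in the pipeline, hence the ignored argument.
      for i in range(1000):
        x = np.arange(i % 16 + 1)
        yield (x, x)


    pipeline = data.Serial(
        synthetic_examples,
        data.Shuffle(queue_size=128),
        data.BucketByLength(boundaries=[8, 16],
                            batch_sizes=[4, 2, 1],
                            length_keys=[0]),
        data.AddLossWeights(),
    )

    # Each element is an (inputs, targets, weights) tuple, padded per bucket.
    first_batch = next(pipeline)

Because every combinator is just a function from generator to generator, plain Python `map` and `filter` calls can be interleaved anywhere in the `Serial` list; the new `test_serial_with_python` test below exercises exactly that.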
diff --git a/trax/data/inputs_test.py b/trax/data/inputs_test.py index 381c98de6..373a6ef45 100644 --- a/trax/data/inputs_test.py +++ b/trax/data/inputs_test.py @@ -19,7 +19,7 @@ from absl.testing import absltest from absl.testing import parameterized import numpy as np -from trax.data import inputs +from trax import data class InputsTest(parameterized.TestCase): @@ -31,7 +31,7 @@ class InputsTest(parameterized.TestCase): def test_shuffle_data_raises_error_queue_size(self, queue_size): samples = iter(range(10)) with self.assertRaises(ValueError): - _ = list(inputs.shuffle_data(samples, queue_size)) + _ = list(data.shuffle(samples, queue_size)) @parameterized.named_parameters( ('one', 1), @@ -40,7 +40,7 @@ def test_shuffle_data_raises_error_queue_size(self, queue_size): ) def test_shuffle_data_queue_size(self, queue_size): samples = iter(range(100, 200)) - shuffled_stream = inputs.shuffle_data(samples, queue_size) + shuffled_stream = data.shuffle(samples, queue_size) first_ten = [next(shuffled_stream) for _ in range(10)] # Queue size limits how far ahead/upstream the current sample can reach. @@ -63,33 +63,59 @@ def test_shuffle_data_queue_size(self, queue_size): ) def test_shuffle_data_yields_all_samples(self, queue_size, n_samples): samples = iter(range(n_samples)) - shuffled_stream = inputs.shuffle_data(samples, queue_size) + shuffled_stream = data.shuffle(samples, queue_size) self.assertLen(list(shuffled_stream), n_samples) def test_batch_data(self): dataset = ((i, i+1) for i in range(10)) - batches = inputs.batch_data(dataset, 10) + batches = data.batch(dataset, 10) batch = next(batches) self.assertLen(batch, 2) self.assertEqual(batch[0].shape, (10,)) + def test_serial(self): + dataset = lambda _: ((i, i+1) for i in range(10)) + batches = data.Serial(dataset, data.Shuffle(3), data.Batch(10)) + batch = next(batches) + self.assertLen(batch, 2) + self.assertEqual(batch[0].shape, (10,)) + + def test_serial_with_python(self): + dataset = lambda _: ((i, i+1) for i in range(10)) + batches = data.Serial( + dataset, + lambda g: map(lambda x: (x[0], x[1] + 1), g), + lambda g: filter(lambda x: x[0] % 2 == 1, g), + data.Batch(2) + ) + batch = next(batches) + self.assertLen(batch, 2) + (xs, ys) = batch + # First tuple after filtering is (1, 3) = (1, 2+1). + self.assertEqual(xs[0], 1) + self.assertEqual(ys[0], 3) + # Second tuple after filtering is (3, 5). 
+ self.assertEqual(xs[1], 3) + self.assertEqual(ys[1], 5) + def test_pad_to_max_dims(self): tensors1 = [np.zeros((3, 10)), np.ones((3, 10))] - padded1 = inputs.pad_to_max_dims(tensors1) + padded1 = data.inputs.pad_to_max_dims(tensors1) self.assertEqual(padded1.shape, (2, 3, 10)) tensors2 = [np.zeros((2, 10)), np.ones((3, 9))] - padded2 = inputs.pad_to_max_dims(tensors2) + padded2 = data.inputs.pad_to_max_dims(tensors2) self.assertEqual(padded2.shape, (2, 3, 10)) tensors3 = [np.zeros((8, 10)), np.ones((8, 9))] - padded3 = inputs.pad_to_max_dims(tensors3, 12) - self.assertEqual(padded3.shape, (2, 8, 12)) + padded3 = data.inputs.pad_to_max_dims(tensors3, 12) + self.assertEqual(padded3.shape, (2, 12, 12)) tensors4 = [np.zeros((2, 10)), np.ones((3, 9))] - padded4 = inputs.pad_to_max_dims(tensors4, 12) + padded4 = data.inputs.pad_to_max_dims(tensors4, 12) self.assertEqual(padded4.shape, (2, 4, 12)) def test_pad_to_max_dims_boundary_list(self): tensors = [np.zeros((1, 15, 31)), np.ones((2, 10, 35)), np.ones((4, 2, 3))] - padded_tensors = inputs.pad_to_max_dims(tensors, boundary=(None, 15, 20)) + padded_tensors = data.inputs.pad_to_max_dims( + tensors, boundary=(None, 15, 20)) # no boundary, only max in the first dim, 15 is already the max len in # second dim, last dim padded to multiple of 20. # The outer dim is the batch here. @@ -97,7 +123,7 @@ def test_pad_to_max_dims_boundary_list(self): def test_pad_to_max_dims_strict_pad_on_len(self): tensors = [np.ones((15,)), np.ones((12,)), np.ones((14,))] - padded_tensors = inputs.pad_to_max_dims( + padded_tensors = data.inputs.pad_to_max_dims( tensors, boundary=10, strict_pad_on_len=True) self.assertEqual(padded_tensors.shape, (3, 20)) @@ -109,11 +135,11 @@ def fake_generator(length, num_examples=1): def length_function(example): return max(example[0].shape[0], example[1].shape[0]) - batches = list(inputs.bucket_by_length(fake_generator(5, 6), - length_function, - [20 + 1], - [2], - strict_pad_on_len=True)) + batches = list(data.bucket_by_length(fake_generator(5, 6), + length_function, + [20], + [2], + strict_pad_on_len=True)) # We'll get three batches of 2 examples each. self.assertLen(batches, 3) diff --git a/trax/data/tf_inputs.py b/trax/data/tf_inputs.py index 606137985..2c62e126e 100644 --- a/trax/data/tf_inputs.py +++ b/trax/data/tf_inputs.py @@ -226,6 +226,28 @@ def _train_and_eval_dataset(dataset_name, data_dir, eval_holdout_size, return train, valid, keys +@gin.configurable() +def TFDS(dataset_name, data_dir=None, # pylint: disable=invalid-name + keys=None, train=True, eval_holdout_size=0): + """Returns an iterator of numpy arrays representing the dataset.""" + data_dir = download_and_prepare(dataset_name, data_dir) + + (train_data, eval_data, _) = _train_and_eval_dataset( + dataset_name, data_dir, eval_holdout_size) + dataset = train_data if train else eval_data + + def select_from(example): + return tuple(example[k] for k in keys) + + dataset = dataset.map(select_from) + dataset = dataset.repeat() + + def gen(unused_arg): + for example in fastmath.dataset_as_numpy(dataset): + yield example + return gen + + def _select_features(example, feature_list=None): """Select a subset of features from the example dict.""" feature_list = feature_list or ['inputs', 'targets'] @@ -266,17 +288,17 @@ def _train_and_eval_dataset_v1(problem_name, data_dir, # Tokenization. 
-def tokenize(stream, indices=None, vocab_type='subword', +def tokenize(stream, keys=None, vocab_type='subword', vocab_file=None, vocab_dir=None, n_reserved_ids=0): """Tokenize examples from the stream. This function assumes that `stream` generates either strings or tuples/dicts - containing strings at some `indices`. This function maps these strings to + containing strings at some `keys`. This function maps these strings to numpy arrays of integers -- the tokenized version of each string. Args: stream: A python generator yielding strings, tuples or dicts. - indices: which indices of the tuple/dict to tokenize (by default: all) + keys: which keys of the tuple/dict to tokenize (by default: all) vocab_type: Type of vocabulary, one of: 'subword', 'sentencepiece', 'char'. vocab_file: Name of the vocabulary file. vocab_dir: Directory which contains the vocabulary file. @@ -286,7 +308,7 @@ def tokenize(stream, indices=None, vocab_type='subword', reserved) in the vocab_file. Yields: - Examples from stream with strings at `indices` replaced by np.arrays of + Examples from stream with strings at `keys` replaced by np.arrays of integers -- the tokenized version of these strings. """ vocab = _get_vocab(vocab_type, vocab_file, vocab_dir) @@ -294,7 +316,7 @@ def tokenize(stream, indices=None, vocab_type='subword', if isinstance(example, (list, tuple)): new_example = [] for i, x in enumerate(example): - if indices is None or i in indices: + if keys is None or i in keys: new_example.append(np.array(vocab.encode(x)) + n_reserved_ids) else: new_example.append(x) @@ -302,7 +324,7 @@ def tokenize(stream, indices=None, vocab_type='subword', elif isinstance(example, dict): new_example = {} for k in example.keys(): - if indices is None or k in indices: + if keys is None or k in keys: new_example[k] = np.array(vocab.encode(example[k])) + n_reserved_ids else: new_example[k] = example[k] @@ -311,6 +333,15 @@ def tokenize(stream, indices=None, vocab_type='subword', yield np.array(vocab.encode(example)) + n_reserved_ids +@gin.configurable() +def Tokenize(keys=None, vocab_type='subword', # pylint: disable=invalid-name + vocab_file=None, vocab_dir=None, n_reserved_ids=0): + """Returns a function that maps text to integer arrays; see `tokenize`.""" + return lambda g: tokenize( # pylint: disable=g-long-lambda + g, keys=keys, vocab_type=vocab_type, vocab_file=vocab_file, + vocab_dir=vocab_dir, n_reserved_ids=n_reserved_ids) + + def detokenize(x, vocab_type='subword', vocab_file=None, vocab_dir=None, n_reserved_ids=0): """Maps integer arrays to text; the opposite of `tokenize`. 
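To make the `indices` to `keys` rename concrete, here is a small usage sketch of the updated `tokenize` on a stream of (text, label) tuples; it uses the 'char' vocabulary so that no vocab file is needed, mirroring the character-level tests::

    import numpy as np

    from trax.data import tf_inputs


    def examples():
      # A stream of (text, label) tuples; only key 0 (the text) is tokenized.
      yield ('Cat.', 'Dog.')


    tokenized = list(tf_inputs.tokenize(examples(), keys=[0], vocab_type='char'))
    # The text becomes a numpy array of character ids; the label is left as-is.
    assert np.array_equal(tokenized[0][0], np.array([ord(c) for c in 'Cat.']))
    assert tokenized[0][1] == 'Dog.'

The gin-configurable `Tokenize(...)` wrapper behaves the same way, except that it returns a function over generators so it can be dropped directly into a `data.Serial(...)` pipeline.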
diff --git a/trax/data/tf_inputs_test.py b/trax/data/tf_inputs_test.py index 57e645c5c..4a923c167 100644 --- a/trax/data/tf_inputs_test.py +++ b/trax/data/tf_inputs_test.py @@ -141,7 +141,7 @@ def dataset(): vocab_dir=_TESTDATA, vocab_file='en_8k.subword') self.assertEqual(detok, 'I have a cat.') - def test_tokenize_indices_reservedids(self): + def test_tokenize_keys_reservedids(self): def dataset(): yield ('Cat.', 'Dog.') @@ -151,7 +151,7 @@ def dataset(): self.assertAllEqual(tok_char1[0][1], np.array([ord(c) + 5 for c in 'Dog.'])) tok_char2 = list(tf_inputs.tokenize( - dataset(), indices=[0], vocab_type='char', n_reserved_ids=2)) + dataset(), keys=[0], vocab_type='char', n_reserved_ids=2)) self.assertAllEqual(tok_char2[0][0], np.array([ord(c) + 2 for c in 'Cat.'])) self.assertEqual(tok_char2[0][1], 'Dog.') @@ -163,7 +163,7 @@ def dataset(): self.assertAllEqual(tok_char1[0]['a'], np.array([ord(c) for c in 'Cat.'])) self.assertAllEqual(tok_char1[0]['b'], np.array([ord(c) for c in 'Dog.'])) - tok_char2 = list(tf_inputs.tokenize(dataset(), indices=['a'], + tok_char2 = list(tf_inputs.tokenize(dataset(), keys=['a'], vocab_type='char')) self.assertAllEqual(tok_char2[0]['a'], np.array([ord(c) for c in 'Cat.'])) self.assertEqual(tok_char2[0]['b'], 'Dog.') diff --git a/trax/models/__init__.py b/trax/models/__init__.py index 4d54b9e96..e90ff5aae 100644 --- a/trax/models/__init__.py +++ b/trax/models/__init__.py @@ -14,10 +14,6 @@ # limitations under the License. """Models defined in trax.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - import gin from trax.models import atari_cnn diff --git a/trax/models/reformer/reformer_e2e_test.py b/trax/models/reformer/reformer_e2e_test.py index a222e4f67..e531e6d4b 100644 --- a/trax/models/reformer/reformer_e2e_test.py +++ b/trax/models/reformer/reformer_e2e_test.py @@ -78,7 +78,7 @@ def test_reformer_noencdecattn_wmt_ende(self): gin.bind_parameter('data_streams.data_dir', _TESTDATA) gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device) - gin.bind_parameter('batcher.buckets', ([513], [1, 1])) # batch size 1. + gin.bind_parameter('batcher.buckets', ([512], [1, 1])) # batch size 1. 
gin.bind_parameter('train.steps', steps) gin.bind_parameter('ReformerNoEncDecAttention.n_encoder_layers', n_layers) gin.bind_parameter('ReformerNoEncDecAttention.n_decoder_layers', n_layers) diff --git a/trax/supervised/configs/transformer_imdb_8gb.gin b/trax/supervised/configs/transformer_imdb_8gb.gin index 41604008a..e29f0d3bd 100644 --- a/trax/supervised/configs/transformer_imdb_8gb.gin +++ b/trax/supervised/configs/transformer_imdb_8gb.gin @@ -14,49 +14,73 @@ import trax.models import trax.optimizers +import trax.data +import trax.data.inputs import trax.data.tf_inputs +import trax.supervised.lr_schedules import trax.supervised.trainer_lib -# Parameters for batcher: +# Parameters for the inputs pipeline: # ============================================================================== -batcher.data_streams = @tf_inputs.data_streams -batcher.batch_size_per_device = 128 -batcher.eval_batch_size = 128 -batcher.max_eval_length = 2048 +make_inputs.train_stream = [ + @train/data.TFDS(), + @data.Tokenize(), + @data.Shuffle(), + @train/data.FilterByLength(), + @data.BucketByLength(), + @data.AddLossWeights(), +] +train/data.TFDS.dataset_name = 'imdb_reviews' +train/data.TFDS.keys = ('text', 'label') +data.Tokenize.vocab_file = 'en_8k.subword' +data.Tokenize.keys = [0] # Tokenize only the first element (text, not label). +train/data.FilterByLength.max_length = 1024 +train/data.FilterByLength.length_keys = [0] +data.BucketByLength.boundaries = [32, 64, 128, 256, 512, 1024, 2048] +data.BucketByLength.batch_sizes = [128, 64, 32, 16, 8, 1, 1, 1] +data.BucketByLength.length_keys = [0] +make_inputs.eval_stream = [ + @eval/data.TFDS(), + @data.Tokenize(), + @data.Shuffle(), + @eval/data.FilterByLength(), + @data.BucketByLength(), + @data.AddLossWeights(), +] +eval/data.TFDS.dataset_name = 'imdb_reviews' +eval/data.TFDS.keys = ('text', 'label') +eval/data.TFDS.train = False +eval/data.FilterByLength.max_length = 2048 +eval/data.FilterByLength.length_keys = [0] # Parameters for data_streams: # ============================================================================== data_streams.data_dir = None data_streams.dataset_name = 't2t_sentiment_imdb' data_streams.input_name = 'targets' -data_streams.preprocess_fn = @trax.data.tf_inputs.lm1b_preprocess -# Parameters for multifactor: +# Parameters for warmup_and_rsqrt_decay: # ============================================================================== -multifactor.constant = 0.1 -multifactor.factors = 'constant * linear_warmup * rsqrt_decay' -multifactor.warmup_steps = 8000 - -# Parameters for lm1b_preprocess: -# ============================================================================== -lm1b_preprocess.max_target_length = 512 -lm1b_preprocess.max_eval_target_length = 2048 +lr_schedules.warmup_and_rsqrt_decay.max_value = 0.01 +lr_schedules.warmup_and_rsqrt_decay.n_warmup_steps = 1000 # Parameters for train: # ============================================================================== -train.eval_frequency = 100 -train.eval_steps = 10 +train.eval_frequency = 500 +train.eval_steps = 20 train.model = @trax.models.TransformerEncoder -train.steps = 1000 +train.steps = 10000 +train.inputs = @trax.data.inputs.make_inputs +train.lr_schedule_fn = @lr_schedules.warmup_and_rsqrt_decay -# Parameters for TransformerLM: +# Parameters for TransformerEncoder: # ============================================================================== TransformerEncoder.d_model = 512 TransformerEncoder.d_ff = 2048 TransformerEncoder.dropout = 0.1 TransformerEncoder.max_len = 
2048 TransformerEncoder.mode = 'train' -TransformerEncoder.n_classes = 10 +TransformerEncoder.n_classes = 2 TransformerEncoder.n_heads = 8 TransformerEncoder.n_layers = 6 -TransformerEncoder.vocab_size = 32000 +TransformerEncoder.vocab_size = 8192 diff --git a/trax/supervised/training.py b/trax/supervised/training.py index d6ac05334..165fd9b6c 100644 --- a/trax/supervised/training.py +++ b/trax/supervised/training.py @@ -105,6 +105,9 @@ def __init__(self, model, tasks, eval_model=None, eval_tasks=None, eval_at: Function (integer --> boolean) that says, for training step n, whether that step should run evals. If None, run when checkpointing. """ + # Handle single task case without lists too. + if not isinstance(tasks, (list, tuple)): + tasks = [tasks] assert len(tasks) == 1, 'Multitask training not supported yet.' task = tasks[0] if eval_tasks is None:
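For reference, here is a sketch of how the new `make_inputs` helper ties a combinator pipeline to the existing `Inputs`-based trainer, mirroring the train stream configured in the updated transformer_imdb_8gb.gin; it assumes TFDS can download imdb_reviews and that 'en_8k.subword' is available in the default vocab location::

    from trax import data

    train_stream = [
        data.TFDS('imdb_reviews', keys=('text', 'label'), train=True),
        data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
        data.Shuffle(),
        data.FilterByLength(max_length=1024, length_keys=[0]),
        data.BucketByLength(boundaries=[32, 64, 128, 256, 512, 1024, 2048],
                            batch_sizes=[128, 64, 32, 16, 8, 1, 1, 1],
                            length_keys=[0]),
        data.AddLossWeights(),
    ]

    # make_inputs runs a list of combinators through data.Serial and wraps the
    # result in Inputs, so the Trainer can consume it via train.inputs.
    imdb_inputs = data.inputs.make_inputs(train_stream=train_stream)
    first_batch = next(imdb_inputs.train_stream(1))  # (inputs, targets, weights)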