diff --git a/AUTHORS b/AUTHORS index 38e5bc724..b4762f933 100644 --- a/AUTHORS +++ b/AUTHORS @@ -5,3 +5,4 @@ # of contributors, see the revision history in source control. Google Inc. +Artit Wangperawong diff --git a/README.md b/README.md index 1a031086b..7e72ce13b 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,7 @@ pip install tensor2tensor && t2t-trainer \ ### Contents * [Suggested Datasets and Models](#suggested-datasets-and-models) + * [Mathematical Language Understanding](#mathematical-language-understanding) * [Story, Question and Answer](#story-question-and-answer) * [Image Classification](#image-classification) * [Image Generation](#image-generation) @@ -79,6 +80,24 @@ hyperparameters that we know works well in our setup. We usually run either on Cloud TPUs or on 8-GPU machines; you might need to modify the hyperparameters if you run on a different setup. +### Mathematical Language Understanding + +For evaluating mathematical expressions at the character level involving addition, subtraction and multiplication of both positive and negative decimal numbers with variable digits assigned to symbolic variables, use + +* the [MLU](https://art.wangperawong.com/mathematical_language_understanding_train.tar.gz) dataset: + `--problem=mathematical_language_understanding` + +You can try solving the problem with different transformer models and hyperparameters as described in the [paper](https://arxiv.org/abs/1812.02825): +* Standard transformer: +`--model=transformer` +`--hparams_set=transformer_tiny` +* Universal transformer: +`--model=universal_transformer` +`--hparams_set=universal_transformer_tiny` +* Adaptive universal transformer: +`--model=universal_transformer` +`--hparams_set=adaptive_universal_transformer_tiny` + ### Story, Question and Answer For answering questions based on a story, use @@ -464,5 +483,6 @@ T2T](https://research.googleblog.com/2017/06/accelerating-deep-learning-research * [Fast Decoding in Sequence Models
using Discrete Latent Variables](https://arxiv.org/abs/1803.03382) * [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235) * [Universal Transformers](https://arxiv.org/abs/1807.03819) +* [Attending to Mathematical Language with Transformers](https://arxiv.org/abs/1812.02825) *Note: This is not an official Google product.* diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py index b3f7263f1..688197ab0 100644 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -50,6 +50,7 @@ "tensor2tensor.data_generators.lm1b", "tensor2tensor.data_generators.lm1b_imdb", "tensor2tensor.data_generators.lm1b_mnli", + "tensor2tensor.data_generators.mathematical_language_understanding", "tensor2tensor.data_generators.mnist", "tensor2tensor.data_generators.mrpc", "tensor2tensor.data_generators.mscoco", diff --git a/tensor2tensor/data_generators/babi_qa.py b/tensor2tensor/data_generators/babi_qa.py index a11eaddca..882d17778 100644 --- a/tensor2tensor/data_generators/babi_qa.py +++ b/tensor2tensor/data_generators/babi_qa.py @@ -459,7 +459,6 @@ def hparams(self, defaults, unused_model_hparams): if "context" in p.vocab_size: del p.vocab_size["context"] - def _problems_to_register(): """Problems for which we want to create
datasets. diff --git a/tensor2tensor/data_generators/mathematical_language_understanding.py b/tensor2tensor/data_generators/mathematical_language_understanding.py new file mode 100644 index 000000000..53d4ddacb --- /dev/null +++ b/tensor2tensor/data_generators/mathematical_language_understanding.py @@ -0,0 +1,104 @@ +# coding=utf-8 +# Copyright 2018 Artit Wangperawong artitw@gmail.com +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Data generators for the Mathematical Language Understanding dataset. + +The training and test data were generated by assigning symbolic variables +either positive or negative decimal integers and then describing the algebraic +operation to perform. We restrict our variable assignments to the range +x,y->[-1000,1000) and the operations to the set {+,-,*}. To ensure that the +model embraces symbolic variables, the order in which x and y appears in the +expression is randomly chosen. For instance, an input string contrasting from +the example shown above might be y=129,x=531,x-y. Each input string is +accompanied by its target string, which is the evaluation of the mathematical +expression. For this study, all targets considered are decimal integers +represented at the character level. About 12 million unique samples were thus +generated and randomly split into training and test sets at an approximate +ratio of 9:1, respectively. + +For more information check the following paper: +Artit Wangperawong. 
Attending to Mathematical Language with Transformers, +arXiv:1812.02825. +Available at: https://arxiv.org/abs/1812.02825 + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tarfile +from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_problems +from tensor2tensor.utils import registry + +import tensorflow as tf + +@registry.register_problem +class MathematicalLanguageUnderstanding(text_problems.Text2TextProblem): + URL = "https://art.wangperawong.com/mathematical_language_understanding_train.tar.gz" + + @property + def vocab_type(self): + return text_problems.VocabType.CHARACTER + + @property + def dataset_splits(self): + return [{ + "split": problem.DatasetSplit.TRAIN, + "shards": 10, + }, { + "split": problem.DatasetSplit.EVAL, + "shards": 1, + }] + + @property + def is_generate_per_split(self): + return False + + def generate_samples(self, data_dir, tmp_dir, dataset_split): + """Downloads and extracts the dataset and generates examples. + + Args: + data_dir: The base directory where data and vocab files are stored. + tmp_dir: temp directory to download and extract the dataset
+ + Returns: + data generator + """ + + if not tf.gfile.Exists(tmp_dir): + tf.gfile.MakeDirs(tmp_dir) + + if not tf.gfile.Exists(data_dir): + tf.gfile.MakeDirs(data_dir) + + # Download and extract + compressed_filename = os.path.basename(self.URL) + download_path = generator_utils.maybe_download(tmp_dir, compressed_filename, + self.URL) + + with tarfile.open(download_path, "r:gz") as tar: + tar.extractall(tmp_dir) + + filepath = os.path.join(tmp_dir, "mathematical_language_understanding_train.txt") + + with open(filepath, "r") as fp: + for line in fp: + prob, ans = line.strip().split(":") + yield {"inputs": prob, "targets": ans} +