diff --git a/tensor2tensor/data_generators/text_problems.py b/tensor2tensor/data_generators/text_problems.py
index 4067a65aa..f807e6f51 100644
--- a/tensor2tensor/data_generators/text_problems.py
+++ b/tensor2tensor/data_generators/text_problems.py
@@ -17,6 +17,7 @@
 * Text2TextProblem: input=text, target=text.
 * Text2ClassProblem: input=text, target=class.
+* Text2RealProblem: input=text, target=float.
 * Text2SelfProblem (for language modeling): target=text
 * QuestionAndContext2TextProblem: input=text, context=text, target=text.
@@ -605,6 +606,94 @@ def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
     yield {"inputs": inputs, "targets": [label]}
 
 
+class Text2RealProblem(Text2TextProblem):
+  """Base class for text regression problems with one or more tasks.
+
+  Suitable for text-based problems where targets are continuous, real values.
+  When ntasks = 1, each text example is mapped to a single scalar value. When
+  ntasks > 1, each text example is mapped to a 1-d vector of length ntasks.
+  """
+
+  @property
+  def ntasks(self):
+    """Set to n > 1 for multitask regression."""
+    return 1
+
+  def generate_samples(self, data_dir, tmp_dir, dataset_split):
+    """Generate samples of text and real-valued target pairs.
+
+    Each yielded dict will be a single example. The inputs should be raw text.
+    The target should be a list containing ntasks floats.
+
+    Args:
+      data_dir: final data directory. Typically only used in this method to
+        copy over user-supplied vocab files (for example, if vocab_type ==
+        VocabType.TOKEN).
+      tmp_dir: temporary directory that you can use for downloading and
+        scratch.
+      dataset_split: problem.DatasetSplit, which data split to generate
+        samples for (for example, training and evaluation).
+
+    Yields:
+      {"inputs": text, "targets": [x1, x2, ..., xN]} where N is ntasks
+    """
+    raise NotImplementedError()
+
+  def generate_text_for_vocab(self, data_dir, tmp_dir):
+    for i, sample in enumerate(
+        self.generate_samples(data_dir, tmp_dir, problem.DatasetSplit.TRAIN)):
+      yield sample["inputs"]
+      if self.max_samples_for_vocab and (i + 1) >= self.max_samples_for_vocab:
+        break
+
+  def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
+    generator = self.generate_samples(data_dir, tmp_dir, dataset_split)
+    encoder = self.get_or_create_vocab(data_dir, tmp_dir)
+    for sample in generator:
+      inputs = encoder.encode(sample["inputs"])
+      inputs.append(text_encoder.EOS_ID)
+      yield {"inputs": inputs, "targets": sample["targets"]}
+
+  def feature_encoders(self, data_dir):
+    encoder = self.get_or_create_vocab(data_dir, None, force_get=True)
+
+    return {
+        "inputs": encoder,
+        "targets": text_encoder.RealEncoder(),
+    }
+
+  def hparams(self, defaults, unused_model_hparams):
+    p = defaults
+    p.modality = {
+        "inputs": modalities.ModalityType.SYMBOL,
+        "targets": modalities.ModalityType.REAL_L2_LOSS,
+    }
+    p.vocab_size = {
+        "inputs": self._encoders["inputs"].vocab_size,
+        "targets": self.ntasks
+    }
+    p.target_space_id = problem.SpaceID.REAL
+    p.add_hparam("regression_targets", True)
+
+  def max_length(self, model_hparams):
+    return model_hparams.batch_size * self.ntasks
+
+  def preprocess_example(self, example, unused_mode, unused_hparams):
+    example = problem.preprocess_example_common(example, unused_mode,
+                                                unused_hparams)
+    example["targets"] = tf.reshape(example["targets"], [1, 1, self.ntasks])
+    return example
+
+  def example_reading_spec(self):
+    data_fields = {
+        "inputs": tf.VarLenFeature(tf.int64),
+        "targets": tf.FixedLenFeature([self.ntasks], tf.float32),
+    }
+    data_items_to_decoders = None
+    return (data_fields, data_items_to_decoders)
+
+  def eval_metrics(self):
+    metrics_list = [metrics.Metrics.RMSE]
+    if self.ntasks == 1:
+      metrics_list.append(metrics.Metrics.PEARSON)
+    return metrics_list
+
+
 def txt_line_iterator(txt_path):
   """Iterate through lines of file."""
   with tf.gfile.Open(txt_path) as f:
@@ -692,6 +781,21 @@ def text2class_txt_iterator(source_txt_path, label_txt_path, class_strs=None):
     yield {"inputs": inputs, "label": label}
 
 
+def text2real_txt_iterator(source_txt_path, target_txt_path):
+  """Yield dicts for Text2RealProblem.generate_samples from lines of files.
+
+  Args:
+    source_txt_path: txt file with record per line.
+    target_txt_path: txt file with float (or space-separated float list for
+      multitask) per line.
+
+  Yields:
+    {"inputs": inputs, "targets": targets}
+  """
+  for inputs, targets in zip(
+      txt_line_iterator(source_txt_path), txt_line_iterator(target_txt_path)):
+    targets = [float(x) for x in targets.split(" ")]
+    yield {"inputs": inputs, "targets": targets}
+
+
 def text2text_txt_tab_iterator(txt_path):
   """Yield dicts for Text2TextProblem.generate_samples from lines of txt_path.
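Usage sketch: a minimal concrete problem built on the new base class, assuming the usual registry-based problem definition in tensor2tensor. The problem name and data file names below are hypothetical stand-ins; only Text2RealProblem, text2real_txt_iterator, and registry.register_problem come from the library.

```python
import os

from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry


@registry.register_problem
class ScoreRegression(text_problems.Text2RealProblem):
  """Maps each line of input text to a single real-valued score."""

  @property
  def ntasks(self):
    return 1  # Set > 1 for multitask regression (vector-valued targets).

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    del data_dir, dataset_split  # This toy sketch uses one fixed split.
    # Hypothetical files: one sentence per line, one float per line.
    return text_problems.text2real_txt_iterator(
        os.path.join(tmp_dir, "inputs.txt"),
        os.path.join(tmp_dir, "targets.txt"))
```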
diff --git a/tensor2tensor/data_generators/text_problems_test.py b/tensor2tensor/data_generators/text_problems_test.py
index 51f948fcf..720f1ba68 100644
--- a/tensor2tensor/data_generators/text_problems_test.py
+++ b/tensor2tensor/data_generators/text_problems_test.py
@@ -94,6 +94,13 @@ def setUpClass(cls):
     tf.gfile.Copy(cls.targets_file,
                   os.path.join(cls.tmp_dir, "targets.eval.txt"))
 
+    cls.targets_regr = [[1.23, 2.34], [4.56, 5.67]]
+    cls.targets_regr_file = os.path.join(cls.tmp_dir, "targets_regr.train.txt")
+    with tf.gfile.Open(cls.targets_regr_file, "w") as f:
+      for targets in cls.targets_regr:
+        f.write(" ".join([str(x) for x in targets]) + "\n")
+
+
   def testTxtLineIterator(self):
     lines = [line for line in text_problems.txt_line_iterator(self.inputs_file)]
     self.assertEqual(lines, self.inputs)
@@ -136,6 +143,16 @@ def testText2ClassTxtIteratorWithStrs(self):
     self.assertEqual(inputs, self.inputs)
     self.assertEqual(labels, self.labels)
 
+  def testText2RealTxtIterator(self):
+    inputs = []
+    targets = []
+    for entry in text_problems.text2real_txt_iterator(self.inputs_file,
+                                                      self.targets_regr_file):
+      inputs.append(entry["inputs"])
+      targets.append(entry["targets"])
+    self.assertEqual(inputs, self.inputs)
+    self.assertEqual(targets, self.targets_regr)
+
   def testText2TextTxtTabIterator(self):
     inputs = []
     targets = []
diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py
index eef24dba1..c1c75f121 100644
--- a/tensor2tensor/models/transformer.py
+++ b/tensor2tensor/models/transformer.py
@@ -462,7 +462,8 @@ def _fast_decode_tpu(self,
     if self.has_input:
       inputs_shape = common_layers.shape_list(features["inputs"])
-      if target_modality == modalities.ModalityType.CLASS_LABEL:
+      if (target_modality == modalities.ModalityType.CLASS_LABEL or
+          self._problem_hparams.get("regression_targets")):
         decode_length = 1
       else:
         decode_length = (
@@ -704,7 +705,8 @@ def _fast_decode(self,
           " of the dataset when decoding.")
     if self.has_input:
       inputs_shape = common_layers.shape_list(features["inputs"])
-      if target_modality == modalities.ModalityType.CLASS_LABEL:
+      if (target_modality == modalities.ModalityType.CLASS_LABEL or
+          self._problem_hparams.get("regression_targets")):
         decode_length = 1
       else:
         decode_length = (
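Why one step suffices: a regression target, like a class label, occupies a single output position (the whole [1, 1, ntasks] vector produced by preprocess_example), so fast decoding can stop after one step. A standalone sketch of the gating rule shared by both _fast_decode variants, with illustrative names that are not part of the patch:

```python
def choose_decode_length(is_class_label, problem_hparams, inputs_length,
                         extra_length):
  """Illustrative stand-in for the patched branch above."""
  # Class labels and regression targets are emitted at one position,
  # so a single decode step produces the complete prediction.
  if is_class_label or problem_hparams.get("regression_targets"):
    return 1
  # Sequence targets still decode in proportion to the input length.
  return inputs_length + extra_length
```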
diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py
index fe33e6315..2b9e419ef 100644
--- a/tensor2tensor/utils/t2t_model.py
+++ b/tensor2tensor/utils/t2t_model.py
@@ -806,8 +806,10 @@ def infer(self,
     if self._problem_hparams:
       target_modality = self._problem_hparams.modality["targets"]
-      if target_modality == modalities.ModalityType.CLASS_LABEL:
-        beam_size = 1  # No use to run beam-search for a single class.
+      if (target_modality == modalities.ModalityType.CLASS_LABEL or
+          self._problem_hparams.get("regression_targets")):
+        # No use to run beam-search for classification or regression.
+        beam_size = 1
     if beam_size == 1:
       log_info("Greedy Decoding")
       results = self._greedy_infer(features, decode_length, use_tpu)
@@ -1064,7 +1066,8 @@ def infer_step(i, recent_output, recent_logits, unused_loss):
     initial_output = tf.slice(initial_output, [0, 0, 0, 0],
                               common_layers.shape_list(initial_output))
     target_modality = self._problem_hparams.modality["targets"]
-    if target_modality == modalities.ModalityType.CLASS_LABEL:
+    if (target_modality == modalities.ModalityType.CLASS_LABEL or
+        self._problem_hparams.get("regression_targets")):
       decode_length = 1
     else:
       if "partial_targets" in features:
@@ -1243,7 +1246,8 @@ def infer_step(recent_output, recent_logits, unused_loss):
     initial_output = tf.slice(initial_output, [0, 0, 0, 0],
                               common_layers.shape_list(initial_output))
     target_modality = self._problem_hparams.modality["targets"]
-    if target_modality == modalities.ModalityType.CLASS_LABEL:
+    if (target_modality == modalities.ModalityType.CLASS_LABEL or
+        self._problem_hparams.get("regression_targets")):
       decode_length = 1
     else:
       if "partial_targets" in features:
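A quick check of the hparams plumbing, assuming tf.contrib.training.HParams (which backs _problem_hparams in this codebase): HParams.get returns None for keys that were never added, so existing problems that never call add_hparam("regression_targets", True) keep their current behavior, while Text2RealProblem subclasses fall back to a single greedy pass.

```python
from tensorflow.contrib.training import HParams

classification_hp = HParams(some_flag=0)  # Never defines the new hparam.
regression_hp = HParams(regression_targets=True)  # Set by Text2RealProblem.

# .get() returns None for absent keys, so the patched conditions stay
# falsy for every existing problem and trigger only for regression.
assert classification_hp.get("regression_targets") is None
assert regression_hp.get("regression_targets")
```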