From ea748717de594d5e60c9f4afa93f9d40d3c244ad Mon Sep 17 00:00:00 2001 From: cclauss Date: Fri, 21 Dec 2018 13:31:33 +0100 Subject: [PATCH] Add to_unicode_utf8() to text_encoder.py --- tensor2tensor/data_generators/cnn_dailymail.py | 6 +----- tensor2tensor/data_generators/cola.py | 6 +----- tensor2tensor/data_generators/mrpc.py | 6 +----- tensor2tensor/data_generators/multinli.py | 6 +----- tensor2tensor/data_generators/qnli.py | 6 +----- tensor2tensor/data_generators/quora_qpairs.py | 6 +----- tensor2tensor/data_generators/rte.py | 6 +----- tensor2tensor/data_generators/scitail.py | 6 +----- tensor2tensor/data_generators/sst_binary.py | 6 +----- tensor2tensor/data_generators/stanford_nli.py | 6 +----- tensor2tensor/data_generators/text_encoder.py | 4 ++++ tensor2tensor/data_generators/wiki_revision_utils.py | 10 ++-------- tensor2tensor/data_generators/wnli.py | 6 +----- 13 files changed, 17 insertions(+), 63 deletions(-) diff --git a/tensor2tensor/data_generators/cnn_dailymail.py b/tensor2tensor/data_generators/cnn_dailymail.py index 8da272526..16f0a678d 100644 --- a/tensor2tensor/data_generators/cnn_dailymail.py +++ b/tensor2tensor/data_generators/cnn_dailymail.py @@ -24,7 +24,6 @@ import os import random import tarfile -import six from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder @@ -157,10 +156,7 @@ def fix_run_on_sents(line): summary = [] reading_highlights = False for line in tf.gfile.Open(story_file, "rb"): - if six.PY2: - line = unicode(line.strip(), "utf-8") - else: - line = line.strip().decode("utf-8") + line = text_encoder.to_unicode_utf8(line.strip()) line = fix_run_on_sents(line) if not line: continue diff --git a/tensor2tensor/data_generators/cola.py b/tensor2tensor/data_generators/cola.py index 7a905573c..0f2748fea 100644 --- a/tensor2tensor/data_generators/cola.py +++ b/tensor2tensor/data_generators/cola.py @@ -21,7 +21,6 @@ import os import zipfile -import six from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder @@ -83,10 +82,7 @@ def _maybe_download_corpora(self, tmp_dir): def example_generator(self, filename): for line in tf.gfile.Open(filename, "rb"): - if six.PY2: - line = unicode(line.strip(), "utf-8") - else: - line = line.strip().decode("utf-8") + line = text_encoder.to_unicode_utf8(line.strip()) _, label, _, sent = line.split("\t") yield { "inputs": sent, diff --git a/tensor2tensor/data_generators/mrpc.py b/tensor2tensor/data_generators/mrpc.py index e8c9e3a39..47c8364d2 100644 --- a/tensor2tensor/data_generators/mrpc.py +++ b/tensor2tensor/data_generators/mrpc.py @@ -20,7 +20,6 @@ from __future__ import print_function import os -import six from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder @@ -95,10 +94,7 @@ def download_file(tdir, filepath, url): def example_generator(self, filename, dev_ids, dataset_split): for idx, line in enumerate(tf.gfile.Open(filename, "rb")): if idx == 0: continue # skip header - if six.PY2: - line = unicode(line.strip(), "utf-8") - else: - line = line.strip().decode("utf-8") + line = text_encoder.to_unicode_utf8(line.strip()) l, id1, id2, s1, s2 = line.split("\t") is_dev = [id1, id2] in dev_ids if dataset_split == problem.DatasetSplit.TRAIN and is_dev: diff --git a/tensor2tensor/data_generators/multinli.py b/tensor2tensor/data_generators/multinli.py index 70ee0107f..e3af79f0e 100644 --- a/tensor2tensor/data_generators/multinli.py +++ b/tensor2tensor/data_generators/multinli.py @@ -21,7 +21,6 @@ import os import zipfile -import six from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import lm1b from tensor2tensor.data_generators import problem @@ -87,10 +86,7 @@ def example_generator(self, filename): label_list = self.class_labels(data_dir=None) for idx, line in enumerate(tf.gfile.Open(filename, "rb")): if idx == 0: continue # skip header - if six.PY2: - line = unicode(line.strip(), "utf-8") - else: - line = line.strip().decode("utf-8") + line = text_encoder.to_unicode_utf8(line.strip()) split_line = line.split("\t") # Works for both splits even though dev has some extra human labels. s1, s2 = split_line[8:10] diff --git a/tensor2tensor/data_generators/qnli.py b/tensor2tensor/data_generators/qnli.py index b59db970f..9eeeb2077 100644 --- a/tensor2tensor/data_generators/qnli.py +++ b/tensor2tensor/data_generators/qnli.py @@ -21,7 +21,6 @@ import os import zipfile -import six from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder @@ -85,10 +84,7 @@ def example_generator(self, filename): label_list = self.class_labels(data_dir=None) for idx, line in enumerate(tf.gfile.Open(filename, "rb")): if idx == 0: continue # skip header - if six.PY2: - line = unicode(line.strip(), "utf-8") - else: - line = line.strip().decode("utf-8") + line = text_encoder.to_unicode_utf8(line.strip()) _, s1, s2, l = line.split("\t") inputs = [s1, s2] l = label_list.index(l) diff --git a/tensor2tensor/data_generators/quora_qpairs.py b/tensor2tensor/data_generators/quora_qpairs.py index 5960c2488..21f3702ec 100644 --- a/tensor2tensor/data_generators/quora_qpairs.py +++ b/tensor2tensor/data_generators/quora_qpairs.py @@ -21,7 +21,6 @@ import os import zipfile -import six from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder @@ -84,10 +83,7 @@ def example_generator(self, filename): skipped = 0 for idx, line in enumerate(tf.gfile.Open(filename, "rb")): if idx == 0: continue # skip header - if six.PY2: - line = unicode(line.strip(), "utf-8") - else: - line = line.strip().decode("utf-8") + line = text_encoder.to_unicode_utf8(line.strip()) split_line = line.split("\t") if len(split_line) < 6: skipped += 1 diff --git a/tensor2tensor/data_generators/rte.py b/tensor2tensor/data_generators/rte.py index af7fa41e9..2eff16422 100644 --- a/tensor2tensor/data_generators/rte.py +++ b/tensor2tensor/data_generators/rte.py @@ -21,7 +21,6 @@ import os import zipfile -import six from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder @@ -85,10 +84,7 @@ def example_generator(self, filename): label_list = self.class_labels(data_dir=None) for idx, line in enumerate(tf.gfile.Open(filename, "rb")): if idx == 0: continue # skip header - if six.PY2: - line = unicode(line.strip(), "utf-8") - else: - line = line.strip().decode("utf-8") + line = text_encoder.to_unicode_utf8(line.strip()) _, s1, s2, l = line.split("\t") inputs = [s1, s2] l = label_list.index(l) diff --git a/tensor2tensor/data_generators/scitail.py b/tensor2tensor/data_generators/scitail.py index 90df97ccb..f600fa560 100644 --- a/tensor2tensor/data_generators/scitail.py +++ b/tensor2tensor/data_generators/scitail.py @@ -21,7 +21,6 @@ import os import zipfile -import six from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import lm1b from tensor2tensor.data_generators import problem @@ -83,10 +82,7 @@ def _maybe_download_corpora(self, tmp_dir): def example_generator(self, filename): label_list = self.class_labels(data_dir=None) for line in tf.gfile.Open(filename, "rb"): - if six.PY2: - line = unicode(line.strip(), "utf-8") - else: - line = line.strip().decode("utf-8") + line = text_encoder.to_unicode_utf8(line.strip()) split_line = line.split("\t") s1, s2 = split_line[:2] l = label_list.index(split_line[2]) diff --git a/tensor2tensor/data_generators/sst_binary.py b/tensor2tensor/data_generators/sst_binary.py index a8a391c95..9081fc81a 100644 --- a/tensor2tensor/data_generators/sst_binary.py +++ b/tensor2tensor/data_generators/sst_binary.py @@ -21,7 +21,6 @@ import os import zipfile -import six from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder @@ -84,10 +83,7 @@ def _maybe_download_corpora(self, tmp_dir): def example_generator(self, filename): for idx, line in enumerate(tf.gfile.Open(filename, "rb")): if idx == 0: continue # skip header - if six.PY2: - line = unicode(line.strip(), "utf-8") - else: - line = line.strip().decode("utf-8") + line = text_encoder.to_unicode_utf8(line.strip()) sent, label = line.split("\t") yield { "inputs": sent, diff --git a/tensor2tensor/data_generators/stanford_nli.py b/tensor2tensor/data_generators/stanford_nli.py index a8aa04602..9c99501ec 100644 --- a/tensor2tensor/data_generators/stanford_nli.py +++ b/tensor2tensor/data_generators/stanford_nli.py @@ -21,7 +21,6 @@ import os import zipfile -import six from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import lm1b from tensor2tensor.data_generators import problem @@ -84,10 +83,7 @@ def example_generator(self, filename): label_list = self.class_labels(data_dir=None) for idx, line in enumerate(tf.gfile.Open(filename, "rb")): if idx == 0: continue # skip header - if six.PY2: - line = unicode(line.strip(), "utf-8") - else: - line = line.strip().decode("utf-8") + line = text_encoder.to_unicode_utf8(line.strip()) split_line = line.split("\t") # Works for both splits even though dev has some extra human labels. s1, s2 = split_line[5:7] diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py index 5580a2f22..3bfa01c9c 100644 --- a/tensor2tensor/data_generators/text_encoder.py +++ b/tensor2tensor/data_generators/text_encoder.py @@ -98,6 +98,10 @@ def to_unicode_ignore_errors(s): return to_unicode(s, ignore_errors=True) +def to_unicode_utf8(s): + return unicode(s, "utf-8") if six.PY2 else s.decode("utf-8") + + def strip_ids(ids, ids_to_strip): """Strip ids_to_strip from the end ids.""" ids = list(ids) diff --git a/tensor2tensor/data_generators/wiki_revision_utils.py b/tensor2tensor/data_generators/wiki_revision_utils.py index 027bd162b..9704c068c 100644 --- a/tensor2tensor/data_generators/wiki_revision_utils.py +++ b/tensor2tensor/data_generators/wiki_revision_utils.py @@ -27,18 +27,12 @@ import re import subprocess -import six - from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import text_encoder import tensorflow as tf -def to_unicode(s): - return unicode(s, "utf-8") if six.PY2 else s.decode("utf-8") - - def include_revision(revision_num, skip_factor=1.1): """Decide whether to include a revision. @@ -118,7 +112,7 @@ def get_title(page): assert start_pos != -1 assert end_pos != -1 start_pos += len("") - return to_unicode(page[start_pos:end_pos]) + return text_encoder.to_unicode_utf8(page[start_pos:end_pos]) def get_id(page): @@ -257,7 +251,7 @@ def get_text(revision, strip=True): ret = revision[end_tag_pos:end_pos] if strip: ret = strip_text(ret) - ret = to_unicode(ret) + ret = text_encoder.to_unicode_utf8(ret) return ret diff --git a/tensor2tensor/data_generators/wnli.py b/tensor2tensor/data_generators/wnli.py index cd4de046d..9b94b5b43 100644 --- a/tensor2tensor/data_generators/wnli.py +++ b/tensor2tensor/data_generators/wnli.py @@ -21,7 +21,6 @@ import os import zipfile -import six from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder @@ -88,10 +87,7 @@ def _maybe_download_corpora(self, tmp_dir): def example_generator(self, filename): for idx, line in enumerate(tf.gfile.Open(filename, "rb")): if idx == 0: continue # skip header - if six.PY2: - line = unicode(line.strip(), "utf-8") - else: - line = line.strip().decode("utf-8") + line = text_encoder.to_unicode_utf8(line.strip()) _, s1, s2, l = line.split("\t") inputs = [s1, s2] yield {