diff --git a/text/augmentation/sent_level_augment.py b/text/augmentation/sent_level_augment.py index 44f8827..e1cfb9c 100644 --- a/text/augmentation/sent_level_augment.py +++ b/text/augmentation/sent_level_augment.py @@ -22,6 +22,7 @@ import math import random +import six from absl import flags import numpy as np @@ -117,7 +118,7 @@ def back_translation(examples, aug_ops, sub_set, aug_copy_num, text_b=text_b, label=ori_example.label) aug_examples += [example] - if np.random.random() < 0.0001: + if np.random.random() < 0.0001:  # log a ~0.01% sample on both PY2 and PY3 tf.logging.info("\tori:\n\t\t{:s}\n\t\t{:s}\n\t\t{:s}\n".format( ori_example.text_a, ori_example.text_b, ori_example.label)) tf.logging.info("\tnew:\n\t\t{:s}\n\t\t{:s}\n\t\t{:s}\n".format( diff --git a/text/bert/modeling.py b/text/bert/modeling.py index b0fa8bb..d8553de 100644 --- a/text/bert/modeling.py +++ b/text/bert/modeling.py @@ -312,7 +312,10 @@ def get_activation(activation_string): # We assume that anything that's not a string is already an activation # function, so we just return it.
- if not isinstance(activation_string, (str, unicode)): + + # six.string_types is (str, unicode) on Python 2 and (str,) on Python 3, + # so any non-string (assumed to already be an activation fn) falls through. + if not isinstance(activation_string, six.string_types): + return activation_string if not activation_string: @@ -964,7 +967,9 @@ def assert_rank(tensor, expected_rank, name=None): name = tensor.name expected_rank_dict = {} - if isinstance(expected_rank, (int, long)): + # six.integer_types is (int, long) on Python 2 and (int,) on Python 3; + # a bare integer denotes a single permitted rank. + if isinstance(expected_rank, six.integer_types): expected_rank_dict[expected_rank] = True else: for x in expected_rank: diff --git a/text/preprocess.py b/text/preprocess.py index ef5aab7..2f10bb0 100644 --- a/text/preprocess.py +++ b/text/preprocess.py @@ -21,6 +21,7 @@ import copy import json import os +import six from absl import app from absl import flags @@ -266,7 +267,7 @@ def convert_examples_to_features( # st = " ".join([str(x) for x in tokens]) st = "" for x in tokens: - if isinstance(x, unicode): + if six.PY2 and isinstance(x, unicode): st += x.encode("ascii", "replace") + " " else: st += str(x) + " " diff --git a/text/utils/tokenization.py b/text/utils/tokenization.py index e0369ac..3450331 100644 --- a/text/utils/tokenization.py +++ b/text/utils/tokenization.py @@ -34,14 +34,21 @@ def load_vocab(vocab_file): """Loads a vocabulary file into a dictionary.""" vocab = collections.OrderedDict() index = 0 - with open_reader(vocab_file) as reader: - while True: - token = reader.readline() - if not token: - break - token = token.strip() - vocab[token] = index - index += 1 + if six.PY2: + reader = open_reader(vocab_file) + else: + reader = tf.gfile.GFile(vocab_file, "r") + # Mirror the old `with` block: close the reader even if reading raises. + try: + while True: + token = reader.readline() + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + finally: + reader.close() return vocab @@ -265,11 +270,12 @@ def _is_punctuation(char): def _convert_to_unicode_or_throw(text): """Converts `text` to Unicode (if
it's not already), assuming utf-8 input.""" - if isinstance(text, str): - text = text.decode("utf-8", "ignore") - if not isinstance(text, unicode): - raise ValueError("`text` must be of type `unicode` or `str`, but is " - "actually of type: %s" % (type(text).__name__)) + if isinstance(text, bytes): + # Python 2 `str` and Python 3 `bytes` are both byte strings; decode both. + text = text.decode("utf-8", "ignore") + if not isinstance(text, six.text_type): + raise ValueError("`text` must be of type `unicode` or `str`, but is " + "actually of type: %s" % (type(text).__name__)) return text