diff --git a/tensor2tensor/data_generators/wikisum/utils_test.py b/tensor2tensor/data_generators/wikisum/utils_test.py index 36397bae6..e713d1938 100644 --- a/tensor2tensor/data_generators/wikisum/utils_test.py +++ b/tensor2tensor/data_generators/wikisum/utils_test.py @@ -24,23 +24,24 @@ import tensorflow as tf -pkg_dir, _ = os.path.split(__file__) +pkg_dir = os.path.abspath(__file__) +pkg_dir, _ = os.path.split(pkg_dir) _TESTDATA = os.path.join(pkg_dir, "test_data") def _get_testdata(filename): - with tf.gfile.Open(os.path.join(_TESTDATA, filename)) as f: + with tf.io.gfile.GFile(filename) as f: return f.read() class UtilsTest(tf.test.TestCase): def test_filter_paragraph(self): - for bad in tf.gfile.Glob(os.path.join(_TESTDATA, "para_bad*.txt")): + for bad in tf.io.gfile.glob(os.path.join(_TESTDATA, "para_bad*.txt")): for p in _get_testdata(bad).split("\n"): self.assertTrue(utils.filter_paragraph(p), msg="Didn't filter %s" % p) - for good in tf.gfile.Glob(os.path.join(_TESTDATA, "para_good*.txt")): + for good in tf.io.gfile.glob(os.path.join(_TESTDATA, "para_good*.txt")): for p in _get_testdata(good).split("\n"): p = _get_testdata(good) self.assertFalse(utils.filter_paragraph(p), msg="Filtered %s" % p)