From e44f905185cb29ed0cd11ff7f54f61c3cf66f80e Mon Sep 17 00:00:00 2001 From: Tanguy Urvoy Date: Fri, 23 Aug 2019 17:37:47 +0200 Subject: [PATCH] Update generator_utils.py Hi, `isinstance(v[0], six.integer_types)` is False for `np.int64` type causing algorithmic_sort_problem data generation to fail. --- tensor2tensor/data_generators/generator_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py index 43d4ff14d..4773db565 100644 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -38,6 +38,7 @@ from tensor2tensor.utils import mlperf_log import tensorflow as tf +import numpy as np UNSHUFFLED_SUFFIX = "-unshuffled" @@ -48,7 +49,7 @@ def to_example(dictionary): for (k, v) in six.iteritems(dictionary): if not v: raise ValueError("Empty generated field: %s" % str((k, v))) - if isinstance(v[0], six.integer_types): + if isinstance(v[0], six.integer_types) or isinstance(v[0], np.int64): features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v)) elif isinstance(v[0], float): features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=v))