diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index cd96d6350..81878134f 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -2017,7 +2017,7 @@ def summarize_features(features, num_shards=1): with tf.name_scope("input_stats"): for (k, v) in sorted(six.iteritems(features)): - if isinstance(v, tf.Tensor) and v.get_shape().ndims > 1: + if isinstance(v, tf.Tensor) and v.get_shape().ndims > 1 and v.dtype != tf.string: tf.summary.scalar("%s_batch" % k, tf.shape(v)[0] // num_shards) tf.summary.scalar("%s_length" % k, tf.shape(v)[1]) nonpadding = tf.to_float(tf.not_equal(v, 0))