Skip to content

Commit e04ce5e

Browse files
committed
Fix TensorBoard callback warning
1 parent 33af75a commit e04ce5e

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

keras/callbacks.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -431,14 +431,15 @@ class TensorBoard(Callback):
431431
histograms for the layers of the model. If set to 0,
432432
histograms won't be computed.
433433
'''
434-
def __init__(self, log_dir='./logs', histogram_freq=0):
434+
def __init__(self, log_dir='./logs', histogram_freq=0, write_graph=True):
435435
super(Callback, self).__init__()
436436
if K._BACKEND != 'tensorflow':
437437
raise Exception('TensorBoard callback only works '
438438
'with the TensorFlow backend.')
439439
self.log_dir = log_dir
440440
self.histogram_freq = histogram_freq
441441
self.merged = None
442+
self.write_graph = write_graph
442443

443444
def _set_model(self, model):
444445
import tensorflow as tf
@@ -457,8 +458,16 @@ def _set_model(self, model):
457458
tf.histogram_summary('{}_out'.format(layer),
458459
layer.output)
459460
self.merged = tf.merge_all_summaries()
460-
self.writer = tf.train.SummaryWriter(self.log_dir,
461-
self.sess.graph_def)
461+
if self.write_graph:
462+
tf_version = tuple(int(i) for i in tf.__version__.split('.'))
463+
if tf_version >= (0, 8, 0):
464+
self.writer = tf.train.SummaryWriter(self.log_dir,
465+
self.sess.graph)
466+
else:
467+
self.writer = tf.train.SummaryWriter(self.log_dir,
468+
self.sess.graph_def)
469+
else:
470+
self.writer = tf.train.SummaryWriter(self.log_dir)
462471

463472
def on_epoch_end(self, epoch, logs={}):
464473
import tensorflow as tf

0 commit comments

Comments
 (0)