@@ -431,14 +431,15 @@ class TensorBoard(Callback):
431
431
histograms for the layers of the model. If set to 0,
432
432
histograms won't be computed.
433
433
'''
434
- def __init__ (self , log_dir = './logs' , histogram_freq = 0 ):
434
+ def __init__ (self , log_dir = './logs' , histogram_freq = 0 , write_graph = True ):
435
435
super (Callback , self ).__init__ ()
436
436
if K ._BACKEND != 'tensorflow' :
437
437
raise Exception ('TensorBoard callback only works '
438
438
'with the TensorFlow backend.' )
439
439
self .log_dir = log_dir
440
440
self .histogram_freq = histogram_freq
441
441
self .merged = None
442
+ self .write_graph = write_graph
442
443
443
444
def _set_model (self , model ):
444
445
import tensorflow as tf
@@ -457,8 +458,16 @@ def _set_model(self, model):
457
458
tf .histogram_summary ('{}_out' .format (layer ),
458
459
layer .output )
459
460
self .merged = tf .merge_all_summaries ()
460
- self .writer = tf .train .SummaryWriter (self .log_dir ,
461
- self .sess .graph_def )
461
+ if self .write_graph :
462
+ tf_version = tuple (int (i ) for i in tf .__version__ .split ('.' ))
463
+ if tf_version >= (0 , 8 , 0 ):
464
+ self .writer = tf .train .SummaryWriter (self .log_dir ,
465
+ self .sess .graph )
466
+ else :
467
+ self .writer = tf .train .SummaryWriter (self .log_dir ,
468
+ self .sess .graph_def )
469
+ else :
470
+ self .writer = tf .train .SummaryWriter (self .log_dir )
462
471
463
472
def on_epoch_end (self , epoch , logs = {}):
464
473
import tensorflow as tf
0 commit comments