diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index 7b38d6c87..b8c029cfd 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -1745,7 +1745,7 @@ def dot_product_attention_relative(q, save_weights_to[scope.name] = weights save_weights_to[scope.name + "/logits"] = logits weights = tf.nn.dropout(weights, 1.0 - dropout_rate) - if not tf.get_variable_scope().reuse and make_image_summary: + if not tf.get_variable_scope().reuse and common_layers.should_generate_summaries() and make_image_summary: attention_image_summary(weights, image_shapes) return _relative_attention_inner(weights, v, relations_values, False)