Hello,
The `tf.function` decorator on the inner function `per_field_where` (in `tf_agents/utils/nest_utils.py`) triggers a TensorFlow retracing warning:
WARNING:tensorflow:5 out of the last 8 calls to <function where.<locals>.per_field_where at 0x7f07a6484b80> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
The following minimal Python script reproduces the issue:
```python
import tensorflow as tf
import tf_agents

def test(condition, true_outputs, false_outputs):
    return tf_agents.utils.nest_utils.where(condition, true_outputs, false_outputs)

if __name__ == "__main__":
    for _ in range(10):
        condition = tf.convert_to_tensor([True, True], dtype=tf.bool)
        true_outputs = tf.convert_to_tensor([0, 1], dtype=tf.int32)
        false_outputs = tf.convert_to_tensor([2, 3], dtype=tf.int32)
        test(condition, true_outputs, false_outputs)
```
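One way to check where the retracing comes from is to count traces: the Python body of a `tf.function` only runs while it is being traced, so a Python-side counter increments once per trace. The sketch below is a hypothetical, simplified stand-in (the helper names `where_inner`, `where_module_level` and `_per_field_module_level` are made up for illustration, not taken from TF-Agents). It compares a `tf.function` defined inside the calling function with one defined once at module level, using the same inputs as the script above.

```python
import tensorflow as tf

# Incremented only while a tf.function body is being traced.
trace_counts = {"inner": 0, "module_level": 0}


def where_inner(condition, t, f):
    # A new tf.function object is created on every call to where_inner,
    # and a fresh object has an empty trace cache, so every call traces.
    @tf.function
    def per_field(t, f):
        trace_counts["inner"] += 1
        return tf.where(condition, t, f)

    return per_field(t, f)


@tf.function
def _per_field_module_level(t, f, condition):
    # Defined once, so its trace cache is reused across calls with the
    # same input shapes and dtypes.
    trace_counts["module_level"] += 1
    return tf.where(condition, t, f)


def where_module_level(condition, t, f):
    return _per_field_module_level(t, f, condition)


def count_traces(n_calls=10):
    condition = tf.constant([True, True])
    t = tf.constant([0, 1], dtype=tf.int32)
    f = tf.constant([2, 3], dtype=tf.int32)
    for _ in range(n_calls):
        where_inner(condition, t, f)
        where_module_level(condition, t, f)
    return dict(trace_counts)
```

With identical inputs on every iteration, the inner variant traces once per call while the module-level one traces only once.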
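The issue seems to come from the inner (local) function definition: because `per_field_where` is defined inside `where`, every call to `where` creates a new `tf.function` object, and each new object has to be traced again. It could be fixed by either:
- removing the `tf.function` decorator, or
- moving `per_field_where` out of `where` (e.g. to a module-level `_per_field_where`). Since it uses the `condition` and `condition_rank` variables from the enclosing scope, they have to be passed explicitly, so the return statement would become `return tf.nest.map_structure(lambda t, f: _per_field_where(t, f, condition, condition_rank), true_outputs, false_outputs)`.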
Thanks