From 1b5530d4da28b4c12772f5256799e359d5f737ce Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 6 Sep 2024 23:45:38 +0200 Subject: [PATCH] TF prepare_gradient_checkpointing, fix for newer TF (#1620) Fix #1616 --- returnn/tf/util/gradient_checkpoint.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/returnn/tf/util/gradient_checkpoint.py b/returnn/tf/util/gradient_checkpoint.py index 18ce76f35..72d445db6 100644 --- a/returnn/tf/util/gradient_checkpoint.py +++ b/returnn/tf/util/gradient_checkpoint.py @@ -4,8 +4,7 @@ from __future__ import annotations -from typing import List, Tuple -import typing +from typing import List, Tuple, Dict import contextlib import weakref import tensorflow as tf @@ -72,7 +71,7 @@ def prepare_gradient_checkpointing(): """ from tensorflow.python.framework import ops - copied_ops = {} # type: typing.Dict[tf.Operation, tf.Operation] # old -> new + copied_ops: Dict[Tuple[int, str], tf.Operation] = {} # graph id, old op name -> new op copy_op_stack_depth = 0 class _DeepCopyError(Exception): @@ -83,8 +82,8 @@ def __init__(self, op: tf.Operation): # noinspection PyShadowingNames def _copy_op(op: tf.Operation) -> tf.Operation: - if op in copied_ops: - return copied_ops[op] + if (id(op.graph), op.name) in copied_ops: + return copied_ops[(id(op.graph), op.name)] nonlocal copy_op_stack_depth if copy_op_stack_depth >= 1: @@ -113,7 +112,7 @@ def _copy_op(op: tf.Operation) -> tf.Operation: with tf_util.same_control_flow_ctx(op.outputs[0]), tf.name_scope(""): new_op = tf_util.copy_op(op, inputs=new_inputs, name=op.name) _set_wrapped_grad_func(new_op) - copied_ops[op] = new_op + copied_ops[(id(op.graph), op.name)] = new_op assert op is copy_op_queue[-1] copy_op_queue.pop(-1)