TF prepare_gradient_checkpointing, fix for newer TF (#1620)
Fix #1616
albertz committed Sep 6, 2024
1 parent 7a58317 commit 1b5530d
Showing 1 changed file with 5 additions and 6 deletions.
returnn/tf/util/gradient_checkpoint.py (5 additions, 6 deletions):

--- a/returnn/tf/util/gradient_checkpoint.py
+++ b/returnn/tf/util/gradient_checkpoint.py
@@ -4,8 +4,7 @@

 from __future__ import annotations

-from typing import List, Tuple
-import typing
+from typing import List, Tuple, Dict
 import contextlib
 import weakref
 import tensorflow as tf
@@ -72,7 +71,7 @@ def prepare_gradient_checkpointing():
     """
     from tensorflow.python.framework import ops

-    copied_ops = {}  # type: typing.Dict[tf.Operation, tf.Operation]  # old -> new
+    copied_ops: Dict[Tuple[int, str], tf.Operation] = {}  # graph id, old op name -> new op
     copy_op_stack_depth = 0

     class _DeepCopyError(Exception):
@@ -83,8 +82,8 @@ def __init__(self, op: tf.Operation):

     # noinspection PyShadowingNames
     def _copy_op(op: tf.Operation) -> tf.Operation:
-        if op in copied_ops:
-            return copied_ops[op]
+        if (id(op.graph), op.name) in copied_ops:
+            return copied_ops[(id(op.graph), op.name)]

         nonlocal copy_op_stack_depth
         if copy_op_stack_depth >= 1:
@@ -113,7 +112,7 @@ def _copy_op(op: tf.Operation) -> tf.Operation:
         with tf_util.same_control_flow_ctx(op.outputs[0]), tf.name_scope(""):
            new_op = tf_util.copy_op(op, inputs=new_inputs, name=op.name)
            _set_wrapped_grad_func(new_op)
-        copied_ops[op] = new_op
+        copied_ops[(id(op.graph), op.name)] = new_op

         assert op is copy_op_queue[-1]
         copy_op_queue.pop(-1)
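
Note on the change: the commit replaces tf.Operation objects as keys of the copied_ops memo dict with (id(graph), op.name) tuples. Op names are unique within a graph, so the tuple identifies an op without relying on the op object itself being usable as a dict key, which appears to be what broke with newer TF versions per #1616. Below is a minimal standalone sketch of this keying pattern, not RETURNN code: memo and lookup_or_store are hypothetical names, and it assumes graph mode via tf.compat.v1.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # build ops into a graph, as RETURNN's TF backend does

graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=(None,), name="x")
    y = tf.identity(x, name="y")

memo = {}  # (graph id, op name) -> cached result; the commit stores the copied op here

def lookup_or_store(op):
    """Memoize something per op without using the tf.Operation object as a dict key."""
    key = (id(op.graph), op.name)  # hashable, and unique while the graph is alive
    if key not in memo:
        memo[key] = op  # the real code stores the *copied* op here
    return memo[key]

assert lookup_or_store(y.op) is y.op
assert (id(graph), "y") in memo

Keying by id(graph) rather than the graph object keeps the key cheap and hashable; it stays unambiguous as long as the graph is alive, which holds here because the cached copied ops belong to that same graph.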
