Skip to content

Commit

Permalink
TF prepare_gradient_checkpointing, avoid deep recursion (#1619)
Browse files Browse the repository at this point in the history
Fix #1616
  • Loading branch information
albertz authored Sep 6, 2024
1 parent 94e8e1d commit 7a58317
Showing 1 changed file with 40 additions and 8 deletions.
48 changes: 40 additions & 8 deletions returnn/tf/util/gradient_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,53 @@ def prepare_gradient_checkpointing():
from tensorflow.python.framework import ops

copied_ops = {} # type: typing.Dict[tf.Operation, tf.Operation] # old -> new
copy_op_stack_depth = 0

class _DeepCopyError(Exception):
# noinspection PyShadowingNames
def __init__(self, op: tf.Operation):
super().__init__(f"deep copy err: {op}")
self.op = op

# noinspection PyShadowingNames
def _copy_op(op: tf.Operation) -> tf.Operation:
if op in copied_ops:
return copied_ops[op]

new_inputs = []
for x in op.inputs:
x = _map_tensor(x)
new_inputs.append(x)
nonlocal copy_op_stack_depth
if copy_op_stack_depth >= 1:
# Avoid deep recursions here, as this can get very deep on big graphs.
# So do a flat construction in the loop below.
raise _DeepCopyError(op)

try:
copy_op_stack_depth += 1

new_op = None
copy_op_queue = [op]
while copy_op_queue:
op = copy_op_queue[-1]

try:
new_inputs = []
for x in op.inputs:
x = _map_tensor(x)
new_inputs.append(x)

except _DeepCopyError as exc:
copy_op_queue.append(exc.op)
continue

with tf_util.same_control_flow_ctx(op.outputs[0]), tf.name_scope(""):
new_op = tf_util.copy_op(op, inputs=new_inputs, name=op.name)
_set_wrapped_grad_func(new_op)
copied_ops[op] = new_op

assert op is copy_op_queue[-1]
copy_op_queue.pop(-1)

with tf_util.same_control_flow_ctx(op.outputs[0]), tf.name_scope(""):
new_op = tf_util.copy_op(op, inputs=new_inputs, name=op.name)
_set_wrapped_grad_func(new_op)
copied_ops[op] = new_op
finally:
copy_op_stack_depth -= 1
return new_op

def _map_tensor(x: tf.Tensor) -> tf.Tensor:
Expand Down

0 comments on commit 7a58317

Please sign in to comment.