Extend CheckpointFunction to track all tensor input/output #1148

Open · wants to merge 3 commits into base: main
142 changes: 105 additions & 37 deletions fairscale/nn/checkpoint/checkpoint_activations.py
@@ -161,6 +161,77 @@ def checkpoint_wrapper(
return module


def dfs_simplified(entity):
"""
A helper function that takes a Python container (tuple, list, dict) and replaces
any tensor with its shape; the main purpose is printing and debugging.
"""
if isinstance(entity, tuple):
return tuple(dfs_simplified(value) for value in entity)
elif isinstance(entity, list):
return [dfs_simplified(value) for value in entity]
elif isinstance(entity, dict):
return {key: dfs_simplified(value) for key, value in entity.items()}
elif isinstance(entity, torch.Tensor):
return entity.shape
else:
return entity
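
For illustration only (the sample container below is made up), dfs_simplified keeps the container structure and replaces each tensor with its torch.Size:

import torch

batch = {"x": torch.randn(2, 3), "meta": [torch.zeros(4), "label", 7]}
print(dfs_simplified(batch))
# {'x': torch.Size([2, 3]), 'meta': [torch.Size([4]), 'label', 7]}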


SimpleEntity = collections.namedtuple("SimpleEntity", ["is_tensor", "value"])


def serialize_tensors(inputs: Any) -> Tuple[Tuple[torch.Tensor, ...], Any]:
"""
Given a Python container (tuple, list, dict) that may contain tensors, this
function extracts the tensors into a tuple and returns a second container in
which each tensor is replaced by its index in that tuple.
"""
tensors = []

def dfs(entity):
if isinstance(entity, tuple):
return tuple(dfs(value) for value in entity)
elif isinstance(entity, list):
return [dfs(value) for value in entity]
elif isinstance(entity, dict):
return {key: dfs(value) for key, value in entity.items()}
elif isinstance(entity, torch.Tensor):
tensors.append(entity)
return SimpleEntity(True, len(tensors)-1)
else:
return SimpleEntity(False, entity)

non_tensors = dfs(inputs)

return tuple(tensors), non_tensors
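
As a quick sketch (the sample values are invented), serialize_tensors pulls the tensors out of a container in depth-first order and leaves index placeholders behind:

import torch

t = torch.randn(2, 3)
tensors, non_tensors = serialize_tensors({"x": t, "n": 5})
# tensors == (t,)
# non_tensors == {"x": SimpleEntity(is_tensor=True, value=0),
#                 "n": SimpleEntity(is_tensor=False, value=5)}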


def deserialize_tensors(tensors: Tuple[torch.Tensor, ...], non_tensors: Any) -> Any:
"""
The inverse of serialize_tensors: given a tuple of tensors and a container
holding tensor indices, return a container in which each index is replaced
by the corresponding tensor.
"""
def dfs(entity):
# Check SimpleEntity first, since it is a subclass of tuple.
if isinstance(entity, SimpleEntity):
if entity.is_tensor:
return tensors[entity.value]
else:
return entity.value
elif isinstance(entity, tuple):
return tuple(dfs(value) for value in entity)
elif isinstance(entity, list):
return [dfs(value) for value in entity]
elif isinstance(entity, dict):
return {key: dfs(value) for key, value in entity.items()}
else:
raise RuntimeError(f"Unexpected type {type(entity)}")

return dfs(non_tensors)
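
And for the inverse direction (again with made-up inputs), deserialize_tensors rebuilds the original container from the tensor tuple and the placeholder structure, so a round trip preserves tensor identity:

import torch

a, b = torch.ones(2), torch.zeros(3)
inputs = (None, (a, {"scale": 0.5}), {"bias": b})
tensors, non_tensors = serialize_tensors(inputs)
restored = deserialize_tensors(tensors, non_tensors)
assert restored[1][0] is a and restored[2]["bias"] is b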


def _checkpointed_forward(
original_forward: Any, weak_self: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any
) -> Any:
@@ -173,8 +244,8 @@ def _checkpointed_forward(
# Autograd Functions in PyTorch work best with positional args, since
# the backward must return gradients (or None) for every input argument.
# We can flatten keyword arguments to make this easier.
args = (module,) + args
kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
tensor_inputs, non_tensor_inputs = serialize_tensors((module, args, kwargs))

parent_ctx_dict: Dict[str, Any] = {
"offload": offload_to_cpu,
}
@@ -189,7 +260,7 @@
# We get around this by saving the desired requires_grad value in output and
# detaching the output if needed.
output = CheckpointFunction.apply(
torch.tensor([], requires_grad=True), original_forward, parent_ctx_dict, kwarg_keys, *flat_args
torch.tensor([], requires_grad=True), original_forward, parent_ctx_dict, non_tensor_inputs, *tensor_inputs
)
output_requires_grad = parent_ctx_dict["output_requires_grad"]
if not isinstance(output, torch.Tensor):
@@ -198,10 +269,9 @@
# requires_grad
output = [x.detach() if not output_requires_grad else x for x in output]

packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
if packed_non_tensor_outputs:
output = unpack_non_tensors(output, packed_non_tensor_outputs)

non_tensor_outputs = parent_ctx_dict["non_tensor_outputs"]
if non_tensor_outputs:
output = deserialize_tensors(output, non_tensor_outputs)
else:
# If output should not require grad, then detach it, since otherwise it will
# always have requires_grad = True due to our dummy tensor input above that
@@ -256,32 +326,28 @@ def forward( # type: ignore
dummy_tensor_requires_grad: torch.Tensor,
run_function: Any,
parent_ctx_dict: Dict[str, Any],
kwarg_keys: Tuple[str, ...],
*args: Any,
**kwargs: Any
non_tensor_inputs: Tuple[Any, ...],
*tensor_inputs: torch.Tensor,
) -> Any:
torch_checkpoint.check_backward_validity(args)
torch_checkpoint.check_backward_validity(tensor_inputs)

ctx.run_function = run_function
ctx.kwarg_keys = kwarg_keys
ctx.non_tensor_inputs = non_tensor_inputs
ctx.fwd_rng_state = get_rng_state()
ctx.had_autocast_in_fwd = is_autocast_enabled()

tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
if parent_ctx_dict["offload"]:
ctx.fwd_device = tuple(x.device for x in tensor_inputs)
ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
tensor_inputs = tuple(x.to("cpu", non_blocking=True) for x in tensor_inputs)
ctx.save_for_backward(*(x.to("cpu", non_blocking=True) for x in tensor_inputs))
else:
ctx.fwd_device, ctx.grad_requirements = None, None

ctx.save_for_backward(*tensor_inputs)
ctx.packed_non_tensor_inputs = packed_non_tensor_inputs
ctx.fwd_device = None
ctx.grad_requirements = None
ctx.save_for_backward(*tensor_inputs)

with torch.no_grad(), enable_checkpointing():
unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
outputs = run_function(*unpacked_args, **unpacked_kwargs)
the_module = unpacked_args[0]
the_module, args, kwargs = deserialize_tensors(tensor_inputs, non_tensor_inputs)
outputs = run_function(the_module, *args, **kwargs)

# Because we run with torch.no_grad(), we can't actually access
# outputs.requires_grad. Instead, we manually compute it by
@@ -303,13 +369,14 @@
# Autograd Functions don't like non-Tensor outputs. We can split the
# non-Tensor and Tensor outputs, returning the former by reference
# through *parent_ctx_dict* and returning the latter directly.
outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs

return outputs
tensor_outputs, non_tensor_outputs = serialize_tensors(outputs)
parent_ctx_dict["non_tensor_outputs"] = non_tensor_outputs
return tensor_outputs
else:
return outputs

@staticmethod
def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
def backward(ctx: Any, *grad_outputs: Any) -> Tuple[Optional[Tensor], ...]:
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")

@@ -319,7 +386,7 @@ def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
tensor_inputs = tuple(t.to(ctx.fwd_device[i], non_blocking=True) for i, t in enumerate(tensor_inputs))
for i, need_grad in enumerate(ctx.grad_requirements):
tensor_inputs[i].requires_grad = need_grad
inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)
non_tensor_inputs = ctx.non_tensor_inputs

# Store the current states.
bwd_rng_state = get_rng_state()
@@ -328,26 +395,27 @@ def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
set_rng_state(ctx.fwd_rng_state)

with torch.enable_grad(), enable_recomputing(), autocast(ctx.had_autocast_in_fwd):
unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
tensor_outputs, _ = split_non_tensors(outputs)
the_module, args, kwargs = deserialize_tensors(tensor_inputs, non_tensor_inputs)
outputs = ctx.run_function(the_module, *args, **kwargs)
tensor_outputs, _ = serialize_tensors(outputs)

# Set the states back to what they were at the start of this function.
set_rng_state(bwd_rng_state)

# Run backward() with only Tensors that require grad
outputs_with_grad = []
args_with_grad = []
assert len(tensor_outputs) == len(grad_outputs)
tensor_outputs_with_grad = []
grad_outputs_with_grad = []
for i in range(len(tensor_outputs)):
if tensor_outputs[i].requires_grad:
outputs_with_grad.append(tensor_outputs[i])
args_with_grad.append(args[i])
tensor_outputs_with_grad.append(tensor_outputs[i])
grad_outputs_with_grad.append(grad_outputs[i])

if len(outputs_with_grad) == 0:
if len(tensor_outputs_with_grad) == 0:
raise RuntimeError("None of the outputs have requires_grad=True, " "this checkpoint() is not necessary")

torch.autograd.backward(outputs_with_grad, args_with_grad)
torch.autograd.backward(tensor_outputs_with_grad, grad_outputs_with_grad)

grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
grads = tuple(inp.grad for inp in tensor_inputs)

return (None, None, None, None) + grads
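
For context, a minimal end-to-end sketch of what this change is meant to enable, assuming the usual fairscale import path and a made-up toy module whose forward takes tensors nested inside a dict plus a keyword argument:

import torch
import torch.nn as nn
from fairscale.nn.checkpoint import checkpoint_wrapper

class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, inputs, extras=None):
        # `inputs` is a dict of tensors; `extras` is an optional keyword tensor.
        out = self.linear(inputs["x"])
        if extras is not None:
            out = out + extras
        return out

module = checkpoint_wrapper(Toy())
x = torch.randn(2, 4, requires_grad=True)
e = torch.randn(2, 4, requires_grad=True)
module({"x": x}, extras=e).sum().backward()
# With inputs tracked through serialize_tensors, both x.grad and e.grad
# should now be populated after the backward pass.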