diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index 37fff943d6..f4b1fb23ae 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -451,11 +451,12 @@ def hook_fn(
                 if is_training:
                     inputs = tuple(i for i in static_input_surface if i.requires_grad)
                     with _none_grad_context_wrapper(inputs):
+                        outputs_requiring_grad = tuple(
+                            o for o in outputs if o is not None and o.requires_grad
+                        )
                         torch.autograd.backward(
-                            tuple(o for o in outputs if o.requires_grad),
-                            grad_tensors=tuple(
-                                torch.empty_like(o) for o in outputs if o.requires_grad
-                            ),
+                            outputs_requiring_grad,
+                            grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad),
                         )

                     grad_inputs = tuple(input.grad for input in inputs)
@@ -616,19 +617,22 @@ def hook_fn(
                 # Note for _reuse_graph_input_output_buffers: grad output is only used
                 # within backward, so we can reuse the same static buffers every time.
                 static_grad_outputs_keys = tuple(
-                    (o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad
+                    (o.shape, o.dtype, o.layout)
+                    for o in static_outputs
+                    if o is not None and o.requires_grad
                 )
                 if static_grad_outputs_keys in static_grad_outputs_dict:
                     static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys]
                 else:
                     static_grad_outputs = tuple(
-                        torch.empty_like(o) if o.requires_grad else None
+                        torch.empty_like(o) if o is not None and o.requires_grad else None
                         for o in static_outputs
                     )
                     static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs
             else:
                 static_grad_outputs = tuple(
-                    torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                    torch.empty_like(o) if o is not None and o.requires_grad else None
+                    for o in static_outputs
                 )
                 if is_training:
                     inputs = tuple(i for i in static_input_surface if i.requires_grad)
@@ -636,7 +640,9 @@ def hook_fn(
                     bwd_graph, pool=mempool
                 ):
                     torch.autograd.backward(
-                        tuple(o for o in static_outputs if o.requires_grad),
+                        tuple(
+                            o for o in static_outputs if o is not None and o.requires_grad
+                        ),
                         grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                         retain_graph=retain_graph_in_backward,
                     )
@@ -719,7 +725,8 @@ def hook_fn(
             ):
                 # For now, assumes all static_outputs require grad
                 static_grad_outputs = tuple(
-                    torch.empty_like(o) if o.requires_grad else None for o in static_outputs
+                    torch.empty_like(o) if o is not None and o.requires_grad else None
+                    for o in static_outputs
                )
                 if is_training:
                     inputs = tuple(i for i in static_input_surface if i.requires_grad)
@@ -727,7 +734,7 @@ def hook_fn(
                 bwd_graph, pool=mempool
             ):
                 torch.autograd.backward(
-                    tuple(o for o in static_outputs if o.requires_grad),
+                    tuple(o for o in static_outputs if o is not None and o.requires_grad),
                     grad_tensors=tuple(o for o in static_grad_outputs if o is not None),
                     retain_graph=retain_graph_in_backward,
                 )
@@ -794,7 +801,7 @@ def forward(ctx, skip_fp8_weight_update, *inputs):
            # Replay forward graph
            fwd_graph.replay()
            assert isinstance(static_outputs, tuple)
-            return tuple(o.detach() for o in static_outputs)
+            return tuple(o.detach() if o is not None else o for o in static_outputs)

        @staticmethod
        @torch.autograd.function.once_differentiable