diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
index 62ba3ee5033e2..16578584cb7c1 100644
--- a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
+++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
@@ -89,6 +89,10 @@ def FindForwardName(string):
     return string[:-5]
 
 
+def IsGradName(string):
+    return string.endswith("_grad")
+
+
 def IsPlainTensorType(string):
     plain_tensor_types = ['Tensor&', 'Tensor', 'const Tensor&', 'const Tensor']
     if string in plain_tensor_types:
@@ -166,6 +170,12 @@ def GetForwardFunctionName(string):
     return f"{string}_final_state_dygraph_function"
 
 
+def TransformGradVarNameForDoubleGradGeneration(string):
+    if IsGradName(string):
+        string = "grad_" + string[:-5]
+    return string
+
+
 ######################
 ### Yaml Parsers   ###
 ######################
diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
index 8d061c8929ae6..a601784042163 100644
--- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
@@ -31,6 +31,7 @@
 from codegen_utils import ParseYamlForward, ParseYamlBackward
 from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
 from codegen_utils import ops_to_fill_zero_for_empty_grads
+from codegen_utils import TransformGradVarNameForDoubleGradGeneration
 from codegen_utils import AssertMessage
 
 
@@ -146,15 +147,38 @@ class {} : public egr::GradNodeBase {{
 }};
 """
 
-FUNCTION_TEMPLATE = \
+GRAD_FUNCTION_TEMPLATE = \
 """
 std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
+  // Fill Zero For GradIn Tensors
   {}
+
+  // Apply Gradient Hooks
   auto hooked_grads = ApplyGradientHooks(grads);
+
+  // Collect GradIn Tensors, Attrs and Recovered TensorWrappers
+  {}
 
   // Call grad_api function
   VLOG(3) << \"Final State Running: \" << \"{}\";
-  auto grad_api_returns = {}{}({});
+  {}
+
+  // Get Output
+  {}
+
+  // Get GradIn autograd_meta
+  {}
+
+  // Get GradOut autograd_meta
+  {}
+
+  // Compute Require Grad
+  {}
+
+  // Create Grad Node
+  {}
+
+  // Return
   {}
 }}
 """
@@ -170,11 +194,14 @@ class {} : public egr::GradNodeBase {{
   // Get Input AutoGradMeta
 {}
   // Forward API Call
+{}
+  // Get Outputs
 {}
   // Get Output AutoGradMeta
 {}
   bool trace_backward = egr::Controller::Instance().HasGrad();
   bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
+  // Check Inplace & Bump Inplace Version
   {}
   {}
 
@@ -225,6 +252,7 @@ class {} : public egr::GradNodeBase {{
 #include "paddle/phi/api/backward/sparse_bw_api.h"
 #include "paddle/fluid/imperative/tracer.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/profiler/event_tracing.h"
 #include "paddle/fluid/eager/utils.h"
 #include "paddle/fluid/eager/api/utils/global_utils.h"
 #include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
@@ -689,14 +717,11 @@ def GenerateNodeCreationCodes(self):
                 assert name in forward_outputs_position_map.keys(
                 ), AssertMessage(name, forward_outputs_position_map.keys())
                 fwd_output_pos = forward_outputs_position_map[name][1]
-                tw_name = f"std::get<{fwd_output_pos}>(api_result)"
-            else:
-                tw_name = f"api_result"
 
             if is_optional:
-                set_tensor_wrappers = f" if({tw_name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({tw_name}.get_ptr()), false);"
+                set_tensor_wrappers = f" if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), false);"
set_tensor_wrappers = f" if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), false);" else: - set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);" + set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);" set_tensor_wrappers_list.append(set_tensor_wrappers) set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list) @@ -729,12 +754,8 @@ def GenerateNodeCreationCodes(self): set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});" set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);" - if num_outputs == 1: - set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);" - set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});" - else: - set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));" - set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});" + set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad({name});" + set_grad_in_meta = f" grad_node->SetGradInMeta({name}, {pos});" set_out_rank_list.append(set_out_rank) set_history_list.append(set_history) set_grad_in_meta_list.append(set_grad_in_meta) @@ -898,20 +919,24 @@ def GenerateForwardDefinition(self, is_inplaced): function_name = GetIntermediateAPIFunctionName(function_name) forward_call_str = f"auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});" - - # Get return type list & outputs num_outputs = len(forward_outputs_position_map.keys()) - len( intermediate_outputs) + + # Get Outputs + get_outputs_str = "" + for name, (rtype, pos) in forward_outputs_position_map.items(): + if num_outputs == 1 and len(intermediate_outputs) == 0: + get_outputs_str += f"auto& {name} = api_result;\n" + else: + get_outputs_str += f"auto& {name} = std::get<{pos}>(api_result);\n" + + # Get return type list & outputs returns_type_list = ["" for i in range(num_outputs)] returns_list = ["" for i in range(num_outputs)] for name, (rtype, pos) in forward_outputs_position_map.items(): if name in intermediate_outputs: continue - if num_outputs == 1 and len(intermediate_outputs) == 0: - returns_list[0] = f"api_result" - else: - # Tuple api_result - returns_list[pos] = f"std::get<{pos}>(api_result)" + returns_list[pos] = f"{name}" if IsPlainTensorType(rtype): returns_type_list[pos] = "paddle::experimental::Tensor" @@ -956,26 +981,24 @@ def GenerateForwardDefinition(self, is_inplaced): output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) if num_fwd_outputs == 1: if IsPlainTensorType(rtype): - output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);" + output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});" else: assert IsVectorTensorType(rtype) - output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n" + output_autograd_meta = f" std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n" output_autograd_meta += f" std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};" else: # Tuple api_result if IsPlainTensorType(rtype): - output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));" + output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = 
                 else:
                     assert IsVectorTensorType(rtype)
-                    output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
+                    output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
                     output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
 
             outputs_autograd_meta_list.append(output_autograd_meta)
-
-        # 3. ComputeRequireGrad & PassStopGradient
         outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
 
-        # 4. Check Inplace
+        # 3. Check Inplace
         check_inplace_str = ""
         bump_inplace_version_str = ""
         if is_inplaced:
 
@@ -1015,7 +1038,7 @@ def GenerateForwardDefinition(self, is_inplaced):
         self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
             returns_type_str, forward_function_name, inputs_args_definition_str,
             dygraph_event_str, amp_logic_str, inputs_autograd_meta_str,
-            forward_call_str, outputs_autograd_meta_str,
+            forward_call_str, get_outputs_str, outputs_autograd_meta_str,
             compute_require_grad_args_str, check_inplace_str,
             bump_inplace_version_str, node_creation_str, returns_str)
 
         self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"
 
@@ -1083,13 +1106,18 @@ def run(self):
 
 
 class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
-    def __init__(self, forward_api_contents, grad_api_contents, namespace):
+    def __init__(self,
+                 forward_api_contents,
+                 grad_api_contents,
+                 namespace,
+                 next_grad_api_contents=None):
         DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
                                               grad_api_contents, namespace)
 
         # Generated Results
         self.node_declaration_str = ""
         self.node_definition_str = ""
+        self.next_grad_api_contents = next_grad_api_contents
 
     def GenerateNodeDeclaration(self):
         forward_op_name = self.forward_api_name
 
@@ -1151,7 +1179,7 @@ def GenerateNodeDeclaration(self):
 
         logging.info(f"Generated Node Declaration: {self.node_declaration_str}")
 
-    def GenerateNodeDefinition(self):
+    def GenerateNodeDefinition(self, grad_node_creation_str):
         namespace = self.namespace
         forward_api_name = self.forward_api_name
         backward_api_name = self.backward_api_name
 
@@ -1165,62 +1193,183 @@ def GenerateNodeDefinition(self):
         grad_api_args_len = len(backward_forward_inputs_map.keys()) + len(
             backward_grad_inputs_map.keys()) + len(backward_attrs_list)
         grad_api_args = ["" for i in range(grad_api_args_len)]
+        get_grad_in_args_list = []
+
+        # Fill Grad Ins with Zero
+        fill_zero_str = ""
+        if forward_api_name in ops_to_fill_zero_for_empty_grads:
+            fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"
+
+        # Grad Ins from TensorWrappers
         for name, (_, is_fwd_input,
                    grad_api_position), in backward_forward_inputs_map.items():
             tensor_wrapper_name = GetSavedName(name)
+            transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
+                name)
+
             is_optional = (name in self.optional_inputs)
             if is_optional:
-                grad_api_args[
-                    grad_api_position] = f"egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr)"
+                tensor_wrapper_recover_str = f"auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr);"
             else:
-                grad_api_args[
-                    grad_api_position] = f"egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr)"
-        for _, (ttype, fwd_position,
-                grad_api_position) in backward_grad_inputs_map.items():
+                tensor_wrapper_recover_str = f"auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr);"
f"auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr);" + grad_api_args[grad_api_position] = transformed_tensor_name + get_grad_in_args_list.append(tensor_wrapper_recover_str) + + # Grad Ins from grads + for name, (ttype, fwd_position, + grad_api_position) in backward_grad_inputs_map.items(): + transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration( + name) + if IsPlainTensorType(ttype): - grad_api_args[ - grad_api_position] = f"hooked_grads[{fwd_position}][0]" + get_tensor_str = f"auto& {transformed_tensor_name} = hooked_grads[{fwd_position}][0];" else: assert IsVectorTensorType(ttype) - grad_api_args[ - grad_api_position] = f"hooked_grads[{fwd_position}]" + get_tensor_str = f"auto& {transformed_tensor_name} = hooked_grads[{fwd_position}];" + grad_api_args[grad_api_position] = transformed_tensor_name + get_grad_in_args_list.append(get_tensor_str) + # Grad Attrs for name, _, _, grad_api_position in backward_attrs_list: saved_attribute_name = GetSavedName(name) - grad_api_args[grad_api_position] = f"this->{saved_attribute_name}" + get_attr_str = f"auto& {name} = this->{saved_attribute_name};" + + grad_api_args[grad_api_position] = name + get_grad_in_args_list.append(get_attr_str) + + get_grad_in_args_str = "\n".join(get_grad_in_args_list) grad_api_args_str = ", ".join(grad_api_args) + # Grad Function Call String + grad_api_namespace = f"paddle::experimental::{namespace}" + grad_function_call_str = f"auto grad_api_result = {grad_api_namespace}{backward_api_name}({grad_api_args_str});" + + # Get Grad Outputs + get_outputs_str = "" + num_outputs = len(backward_grad_outputs_map.keys()) + for name, (ttype, fwd_position, + grad_api_position) in backward_grad_outputs_map.items(): + transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration( + name) + + if num_outputs == 1: + get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result;" + else: + get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result[{fwd_position}];" + get_outputs_str += get_tensor_str + "\n" + + # Prepare for Node Creation if Necessary + inputs_autograd_meta_str = "" + outputs_autograd_meta_str = "" + compute_require_grad_str = "" + if len(grad_node_creation_str) > 0: + # 1. Get Input AutoGradMeta + inputs_autograd_meta_list = [] + compute_require_grad_args_list = ["trace_backward"] + for name, (ttype, pos, + grad_api_position) in backward_grad_inputs_map.items(): + transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration( + name) + + input_autograd_meta_name = GetAutoGradMetaName( + transformed_tensor_name) + if IsPlainTensorType(ttype): + input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});" + else: + assert IsVectorTensorType(ttype) + input_autograd_meta_vec_name = GetAutoGradMetaVectorName( + transformed_tensor_name) + input_autograd_meta = f" std::vector {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n" + input_autograd_meta += f" std::vector* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" + + inputs_autograd_meta_list.append(input_autograd_meta) + compute_require_grad_args_list.append(input_autograd_meta_name) + + # 2. 
+            for name, (ttype, _, pos), in backward_forward_inputs_map.items():
+                transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
+                    name)
+
+                input_autograd_meta_name = GetAutoGradMetaName(
+                    transformed_tensor_name)
+                if IsPlainTensorType(ttype):
+                    input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});"
+                else:
+                    assert IsVectorTensorType(ttype)
+                    input_autograd_meta_vec_name = GetAutoGradMetaVectorName(
+                        transformed_tensor_name)
+                    input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n"
+                    input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
+
+                inputs_autograd_meta_list.append(input_autograd_meta)
+                compute_require_grad_args_list.append(input_autograd_meta_name)
+            inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
+            compute_require_grad_args_str = ",".join(
+                compute_require_grad_args_list)
+
+            # 3. Get Output AutoGradMeta
+            outputs_autograd_meta_list = []
+            num_fwd_outputs = len(backward_grad_outputs_map.keys())
+            for name, (rtype, pos, _) in backward_grad_outputs_map.items():
+                transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
+                    name)
+
+                output_autograd_meta_name = GetAutoGradMetaName(
+                    transformed_tensor_name)
+                output_autograd_meta_vec_name = GetAutoGradMetaVectorName(
+                    transformed_tensor_name)
+                if num_fwd_outputs == 1:
+                    if IsPlainTensorType(rtype):
+                        output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});"
+                    else:
+                        assert IsVectorTensorType(rtype)
+                        output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n"
+                        output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
+                else:
+                    # Tuple api_result
+                    if IsPlainTensorType(rtype):
+                        output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});"
+                    else:
+                        assert IsVectorTensorType(rtype)
+                        output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n"
+                        output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
+
+                outputs_autograd_meta_list.append(output_autograd_meta)
+            outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
+
+            compute_require_grad_str = "bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"
+            compute_require_grad_str += f"bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({compute_require_grad_args_str});"
+
         # Construct grad_api returns
         num_bwd_outputs = len(backward_grad_outputs_map.keys())
         slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys())
         returns_str = f"std::vector<std::vector<paddle::experimental::Tensor>> returns({slot_num_bwd_outputs});\n"
-        for _, (ttype, fwd_position,
-                grad_api_position) in backward_grad_outputs_map.items():
+        for name, (ttype, fwd_position,
+                   grad_api_position) in backward_grad_outputs_map.items():
+            transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
+                name)
+
             # Infer Grad API Return Type
             if num_bwd_outputs == 1:
                 # Single tensor output, return as is
                 if IsPlainTensorType(ttype):
-                    returns_str += "returns[0] = { grad_api_returns };\n"
+                    returns_str += f"returns[0] = {{ {transformed_tensor_name} }};\n"
                 else:
                     assert IsVectorTensorType(ttype)
-                    returns_str += "returns[0] = grad_api_returns;\n"
+                    returns_str += f"returns[0] = {transformed_tensor_name};\n"
             else:
                 # Rearrange output order accordingly
-                returns_str += f"returns[{fwd_position}] = grad_api_returns[{grad_api_position}];\n"
+                returns_str += f"returns[{fwd_position}] = {transformed_tensor_name};\n"
         returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
         returns_str += f"return returns;\n"
 
         grad_node_name = GetGradNodeName(forward_api_name)
 
-        fill_zero_str = ""
-        if forward_api_name in ops_to_fill_zero_for_empty_grads:
-            fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"
-
-        grad_api_namespace = f"paddle::experimental::{namespace}"
-
-        self.node_definition_str = FUNCTION_TEMPLATE.format(
-            grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace,
-            backward_api_name, grad_api_args_str, returns_str)
+        self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
+            grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
+            grad_function_call_str, get_outputs_str, inputs_autograd_meta_str,
+            outputs_autograd_meta_str, compute_require_grad_str,
+            grad_node_creation_str, returns_str)
 
         logging.info(f"Generated Node Definition: {self.node_definition_str}")
 
@@ -1231,7 +1380,22 @@ def run(self):
         ## Code Generation ##
         #####################
         self.GenerateNodeDeclaration()
-        self.GenerateNodeDefinition()
+
+        namespace = self.namespace
+        grad_node_creation_str = ""
+        next_grad_api_contents = self.next_grad_api_contents
+        if next_grad_api_contents:
+            forward_api_contents = self.grad_api_contents
+            forward_api_contents['api'] = forward_api_contents['backward_api']
+            backward_api_contents = next_grad_api_contents
+
+            next_node_generator = DygraphFunctionGeneratorBase(
+                forward_api_contents, backward_api_contents, namespace)
+            next_node_generator.run()
+            next_node_generator.GenerateNodeCreationCodes()
+            grad_node_creation_str = next_node_generator.node_creation_str
+
+        self.GenerateNodeDefinition(grad_node_creation_str)
 
 
 class DygraphYamlGenerator(YamlGeneratorBase):
@@ -1278,18 +1442,34 @@ def GenerateCode(self):
                 forward_api_contents)
             if backward_api_contents is None: continue
 
+            # Generate Dygraph Forward Function
             function_generator = DygraphForwardFunctionGenerator(
                 forward_api_contents, backward_api_contents, namespace)
             function_generator.run()
 
-            node_generator = DygraphNodeGenerator(
-                forward_api_contents, backward_api_contents, namespace)
-            node_generator.run()
-
             self.forward_definition_str += function_generator.forward_definition_str + "\n"
             self.forward_declaration_str += function_generator.forward_declaration_str + "\n"
-            self.node_declaration_str += node_generator.node_declaration_str + "\n"
-            self.node_definition_str += node_generator.node_definition_str + "\n"
+
+            while True:
+                next_grad_api_contents = self.GetBackwardAPIContents(
+                    backward_api_contents)
+
+                node_generator = DygraphNodeGenerator(
+                    forward_api_contents, backward_api_contents, namespace,
+                    next_grad_api_contents)
+                node_generator.run()
+                self.node_declaration_str += node_generator.node_declaration_str + "\n"
+                self.node_definition_str += node_generator.node_definition_str + "\n"
+
+                if next_grad_api_contents is None: break
+
+                # Detect if there exists higher-order GradNode
+                forward_api_contents = backward_api_contents
+
+                # Fake forward_api_content
+                forward_api_contents['api'] = forward_api_contents[
+                    'backward_api']
+                backward_api_contents = next_grad_api_contents
 
         if len(namespace) > 0:
             if namespace.endswith("::"):
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml
index 4f2b76db1a27f..b54afd2e133f3 100644
--- a/python/paddle/utils/code_gen/backward.yaml
+++ b/python/paddle/utils/code_gen/backward.yaml
@@ -649,6 +649,16 @@
   kernel :
     func : put_along_axis_grad
 
+- backward_api : relu_double_grad
+  forward : relu_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x)
+  args : (Tensor out, Tensor grad_x_grad)
+  output : Tensor(out_grad), Tensor(grad_out_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [out, out]
+  kernel :
+    func : relu_double_grad
+
 - backward_api : relu_grad
   forward : relu (Tensor x) -> Tensor(out)
   args : (Tensor out, Tensor out_grad)
@@ -658,6 +668,7 @@
     param : [out]
   kernel :
     func : relu_grad
+  backward: relu_double_grad
 
 - backward_api : reshape_grad
   forward : reshape_with_xshape (Tensor x, ScalarArray shape) -> Tensor(out), Tensor(xshape)
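
Editorial note, not part of the patch: the sketch below is a minimal illustration of how the pieces added above are meant to cooperate. IsGradName and TransformGradVarNameForDoubleGradGeneration are copied from the codegen_utils.py hunk; backward_yaml and get_backward_api_contents are hypothetical stand-ins for the parsed backward.yaml entries and for DygraphYamlGenerator.GetBackwardAPIContents, which this diff does not show.

# illustrative_double_grad_chain.py -- hand-written sketch, not generated code
def IsGradName(string):
    return string.endswith("_grad")

def TransformGradVarNameForDoubleGradGeneration(string):
    # "out_grad" -> "grad_out": grad-node locals get forward-style names so the
    # next-order GradNode can be generated through the same code paths.
    if IsGradName(string):
        string = "grad_" + string[:-5]
    return string

# Hand-written stand-in for parsed backward.yaml entries (illustrative only).
backward_yaml = {
    "relu_grad": {"backward_api": "relu_grad", "backward": "relu_double_grad"},
    "relu_double_grad": {"backward_api": "relu_double_grad"},
}

def get_backward_api_contents(api_contents):
    # Hypothetical counterpart of GetBackwardAPIContents: follow the optional
    # `backward:` field to the next-order grad api, or return None.
    next_name = api_contents.get("backward")
    return backward_yaml.get(next_name) if next_name else None

# Mirrors the new while-loop in DygraphYamlGenerator.GenerateCode: keep
# emitting grad nodes until no higher-order backward api is registered.
backward_api_contents = backward_yaml["relu_grad"]
while True:
    next_grad_api_contents = get_backward_api_contents(backward_api_contents)
    print("generate GradNode for", backward_api_contents["backward_api"])
    if next_grad_api_contents is None:
        break
    backward_api_contents = next_grad_api_contents

print(TransformGradVarNameForDoubleGradGeneration("out_grad"))  # grad_out
print(TransformGradVarNameForDoubleGradGeneration("x"))         # x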