diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 32c520711d978..c777c78b8119e 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -977,6 +977,9 @@ struct SetAttrDescVisitor { void operator()(const std::vector &v) const { // just do nothing. } + void operator()(const std::shared_ptr &v) const { + // just do nothing. + } void operator()(const std::vector &v) const { std::vector var_names; for (auto var : v) { diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index 61f133ceb082a..5147a298e6d4d 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -67,7 +67,8 @@ using Attribute = paddle::variant, ::pir::Block*, - std::vector<::pir::Value>>; + std::vector<::pir::Value>, + std::shared_ptr<::pir::Program>>; using AttributeMap = std::unordered_map; using OpCreator = diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index f8f1424ded243..3800eab7c79cc 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -36,6 +36,7 @@ #include "paddle/phi/common/complex.h" #include "paddle/pir/include/core/block.h" #include "paddle/pir/include/core/op_result.h" +#include "paddle/pir/include/core/program.h" #include "paddle/pir/include/core/value.h" namespace paddle { @@ -858,6 +859,24 @@ void CastPyArg2AttrIRBlock(PyObject* obj, attrs[key] = reinterpret_cast<::pir::Block*&>(vh[0]); } +void CastPyArg2AttrIRProgram(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, + const std::string& op_type, + ssize_t arg_pos) { + VLOG(1) << "After Process shared_ptr"; + ::pybind11::object o = ::pybind11::reinterpret_steal<::pybind11::object>(obj); + // ::pybind11::object o = + // ::pybind11::reinterpret_borrow<::pybind11::object>(obj); + // ::pybind11::detail::instance* inst = + // (::pybind11::detail::instance*)obj; // NOLINT + // void** vh = inst->simple_layout ? inst->simple_value_holder + // : + // &inst->nonsimple.values_and_holders[0]; + // attrs[key] = reinterpret_cast>(vh[0]); + attrs[key] = o.cast>(); +} + void CastPyArg2AttrValues(PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT const std::string& key, @@ -1020,11 +1039,11 @@ void ConstructAttrMapForRunProgram( if (std::set({"cuda_graph_capture_mode"}).count(key)) { CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos); - } else if (std::set({"global_block", - "forward_global_block", - "backward_global_block"}) - .count(key)) { + } else if (std::set({"global_block"}).count(key)) { CastPyArg2AttrIRBlock(obj, attrs, key, op_type, arg_pos); + } else if (std::set({"forward_program", "backward_program"}) + .count(key)) { + CastPyArg2AttrIRProgram(obj, attrs, key, op_type, arg_pos); } else if (std::set({"is_test", "use_interpretorcore"}) .count(key)) { CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos); diff --git a/python/paddle/jit/dy2static/pir_partial_program.py b/python/paddle/jit/dy2static/pir_partial_program.py index 3e0a098118931..406f7436f961d 100644 --- a/python/paddle/jit/dy2static/pir_partial_program.py +++ b/python/paddle/jit/dy2static/pir_partial_program.py @@ -870,10 +870,10 @@ def _prune_unused_params(self, program): def _prepare_attributes(self): attrs = [ - 'forward_global_block', - self.program.forward_program.global_block(), - 'backward_global_block', - self.program.backward_program.global_block(), + 'forward_program', + self.program.forward_program, + 'backward_program', + self.program.backward_program, 'is_test', not self.training, 'program_id', diff --git a/test/dygraph_to_static/test_no_gradient.py b/test/dygraph_to_static/test_no_gradient.py index 1bd3a02f54ede..391ee176dfb58 100644 --- a/test/dygraph_to_static/test_no_gradient.py +++ b/test/dygraph_to_static/test_no_gradient.py @@ -15,7 +15,7 @@ import unittest import numpy -from dygraph_to_static_utils import Dy2StTestBase +from dygraph_to_static_utils import Dy2StTestBase, test_ast_only, test_pir_only import paddle @@ -33,6 +33,8 @@ def main_func(x, index): class TestNoGradientCase(Dy2StTestBase): + @test_ast_only + @test_pir_only def test_no_gradient(self): paddle.disable_static() x = paddle.randn([10, 3])