@@ -846,17 +846,18 @@ void CastPyArg2AttrScalars(PyObject* obj,
846846std::vector<std::string> CastPyArg2Strings (PyObject* obj,
847847 const std::string& op_type,
848848 ssize_t arg_pos) {
849- std::vector<std::string> value ;
849+ std::vector<std::string_view> views ;
850850 if (PyList_Check (obj)) {
851851 Py_ssize_t len = PyList_Size (obj);
852+ views.reserve (len);
852853 PyObject* item = nullptr ;
853854 for (Py_ssize_t i = 0 ; i < len; i++) {
854855 item = PyList_GetItem (obj, i);
855856 if (PyObject_CheckString (item)) {
856857 Py_ssize_t size = 0 ;
857858 const char * data = nullptr ;
858859 data = PyUnicode_AsUTF8AndSize (item, &size);
859- value .emplace_back (std::string (data, (size_t )size)); // NOLINT
860+ views .emplace_back (std::string_view (data, (size_t )size)); // NOLINT
860861 } else {
861862 PADDLE_THROW (common::errors::InvalidType (
862863 " %s(): argument (position %d) must be "
@@ -869,14 +870,15 @@ std::vector<std::string> CastPyArg2Strings(PyObject* obj,
869870 }
870871 } else if (PyTuple_Check (obj)) {
871872 Py_ssize_t len = PyTuple_Size (obj);
873+ views.reserve (len);
872874 PyObject* item = nullptr ;
873875 for (Py_ssize_t i = 0 ; i < len; i++) {
874876 item = PyTuple_GetItem (obj, i);
875877 if (PyObject_CheckString (item)) {
876878 Py_ssize_t size = 0 ;
877879 const char * data = nullptr ;
878880 data = PyUnicode_AsUTF8AndSize (item, &size);
879- value .emplace_back (std::string (data, (size_t )size)); // NOLINT
881+ views .emplace_back (std::string_view (data, (size_t )size)); // NOLINT
880882 } else {
881883 PADDLE_THROW (common::errors::InvalidType (
882884 " %s(): argument (position %d) must be "
@@ -895,7 +897,11 @@ std::vector<std::string> CastPyArg2Strings(PyObject* obj,
895897 arg_pos + 1 ,
896898 ((PyTypeObject*)obj->ob_type )->tp_name )); // NOLINT
897899 }
898-
900+ std::vector<std::string> value;
901+ value.reserve (views.size ());
902+ for (const auto & view : views) {
903+ value.emplace_back (view);
904+ }
899905 return value;
900906}
901907
@@ -1207,16 +1213,15 @@ void ConstructAttrMapForLegacyRunProgram(
12071213void ConstructAttrMapForRunProgram (
12081214 const std::string& op_type,
12091215 PyObject* args,
1210- ssize_t attr_start,
1211- ssize_t attr_end,
1216+ ssize_t arg_pos,
12121217 paddle::framework::AttributeMap& attrs) { // NOLINT
1213- PADDLE_ENFORCE_EQ ((attr_end - attr_start) % 2 ,
1214- 0 ,
1215- common::errors::InvalidArgument (
1216- " The number of arguments for attributes should be even "
1217- " but attr_start = %d, attr_end = %d. " ,
1218- attr_start,
1219- attr_end));
1218+ PyObject* attrs_dict = PyTuple_GET_ITEM (args, arg_pos);
1219+ if (! PyDict_Check (attrs_dict)) {
1220+ PADDLE_THROW ( common::errors::InvalidArgument (
1221+ " %s(): argument must be dict, but got %s " ,
1222+ op_type ,
1223+ reinterpret_cast <PyTypeObject*>(attrs_dict-> ob_type )-> tp_name ));
1224+ }
12201225
12211226 using CastFuncType = void (*)(PyObject*,
12221227 paddle::framework::AttributeMap&,
@@ -1246,34 +1251,29 @@ void ConstructAttrMapForRunProgram(
12461251 {" cuda_graph_dispatch_key" , CastPyArg2AttrLong},
12471252 };
12481253
1249- PyObject* obj = nullptr ;
1250- for ( ssize_t arg_pos = attr_start; arg_pos < attr_end; arg_pos += 2 ) {
1251- VLOG ( 3 ) << " Start Process " << arg_pos;
1254+ PyObject *key, *value ;
1255+ Py_ssize_t pos = 0 ;
1256+ while ( PyDict_Next (attrs_dict, &pos, &key, &value)) {
12521257 Py_ssize_t key_len = 0 ;
12531258 const char * key_ptr = nullptr ;
1254- obj = PyTuple_GET_ITEM (args, arg_pos);
1255- if (PyObject_CheckString (obj)) {
1256- key_ptr = PyUnicode_AsUTF8AndSize (obj, &key_len);
1259+ if (PyObject_CheckString (key)) {
1260+ key_ptr = PyUnicode_AsUTF8AndSize (key, &key_len);
12571261 } else {
12581262 PADDLE_THROW (common::errors::InvalidArgument (
1259- " %s(): argument (position %d) must be str, but got %s" ,
1263+ " %s(): dict key must be str, but got %s" ,
12601264 op_type,
1261- arg_pos,
1262- ((PyTypeObject*)obj->ob_type )->tp_name )); // NOLINT
1265+ reinterpret_cast <PyTypeObject*>(key->ob_type )->tp_name ));
12631266 }
12641267 std::string_view key_view (key_ptr, static_cast <size_t >(key_len));
1265- VLOG (3 ) << " Start Process " << key_view;
1266- obj = PyTuple_GET_ITEM (args, arg_pos + 1 );
12671268 auto it = kAttrFuncMap .find (std::string (key_view));
12681269 if (it != kAttrFuncMap .end ()) {
1269- // Call Cast function
1270- it->second (obj, attrs, std::string (key_view), op_type, arg_pos);
1270+ it->second (value, attrs, std::string (key_view), op_type, 0 );
12711271 } else {
12721272 PADDLE_THROW (common::errors::InvalidArgument (
12731273 " Attribute key %.*s is not recognized for operator %s." ,
12741274 static_cast <int >(key_view.size ()),
12751275 key_view.data (),
1276- op_type.c_str ())); // NOLINT
1276+ op_type.c_str ()));
12771277 }
12781278 }
12791279}
0 commit comments