Skip to content

Commit 3ad243c

Browse files
authored
[Dy2St] Optimize ConstructAttrMapForRunProgram performance (#73682)
1 parent fdbfe0d commit 3ad243c

File tree

4 files changed, with 42 additions and 60 deletions.

paddle/fluid/pybind/eager_custom_python_api.h

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -96,8 +96,7 @@ static PyObject *eager_api_run_program(PyObject *self,
9696
}
9797
framework::AttributeMap attrs;
9898
VLOG(6) << "Start PIR ConstructAttrMapFromPyArgs";
99-
ConstructAttrMapForRunProgram(
100-
"run_program", args, 4, PyTuple_GET_SIZE(args), attrs);
99+
ConstructAttrMapForRunProgram("run_program", args, 4, attrs);
101100

102101
VLOG(6) << "Finish Pir ConstructAttrMapFromPyArgs";
103102
tstate = PyEval_SaveThread();

paddle/fluid/pybind/op_function_common.cc

Lines changed: 27 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -846,17 +846,18 @@ void CastPyArg2AttrScalars(PyObject* obj,
846846
std::vector<std::string> CastPyArg2Strings(PyObject* obj,
847847
const std::string& op_type,
848848
ssize_t arg_pos) {
849-
std::vector<std::string> value;
849+
std::vector<std::string_view> views;
850850
if (PyList_Check(obj)) {
851851
Py_ssize_t len = PyList_Size(obj);
852+
views.reserve(len);
852853
PyObject* item = nullptr;
853854
for (Py_ssize_t i = 0; i < len; i++) {
854855
item = PyList_GetItem(obj, i);
855856
if (PyObject_CheckString(item)) {
856857
Py_ssize_t size = 0;
857858
const char* data = nullptr;
858859
data = PyUnicode_AsUTF8AndSize(item, &size);
859-
value.emplace_back(std::string(data, (size_t)size)); // NOLINT
860+
views.emplace_back(std::string_view(data, (size_t)size)); // NOLINT
860861
} else {
861862
PADDLE_THROW(common::errors::InvalidType(
862863
"%s(): argument (position %d) must be "
@@ -869,14 +870,15 @@ std::vector<std::string> CastPyArg2Strings(PyObject* obj,
869870
}
870871
} else if (PyTuple_Check(obj)) {
871872
Py_ssize_t len = PyTuple_Size(obj);
873+
views.reserve(len);
872874
PyObject* item = nullptr;
873875
for (Py_ssize_t i = 0; i < len; i++) {
874876
item = PyTuple_GetItem(obj, i);
875877
if (PyObject_CheckString(item)) {
876878
Py_ssize_t size = 0;
877879
const char* data = nullptr;
878880
data = PyUnicode_AsUTF8AndSize(item, &size);
879-
value.emplace_back(std::string(data, (size_t)size)); // NOLINT
881+
views.emplace_back(std::string_view(data, (size_t)size)); // NOLINT
880882
} else {
881883
PADDLE_THROW(common::errors::InvalidType(
882884
"%s(): argument (position %d) must be "
@@ -895,7 +897,11 @@ std::vector<std::string> CastPyArg2Strings(PyObject* obj,
895897
arg_pos + 1,
896898
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
897899
}
898-
900+
std::vector<std::string> value;
901+
value.reserve(views.size());
902+
for (const auto& view : views) {
903+
value.emplace_back(view);
904+
}
899905
return value;
900906
}
901907

@@ -1207,16 +1213,15 @@ void ConstructAttrMapForLegacyRunProgram(
12071213
void ConstructAttrMapForRunProgram(
12081214
const std::string& op_type,
12091215
PyObject* args,
1210-
ssize_t attr_start,
1211-
ssize_t attr_end,
1216+
ssize_t arg_pos,
12121217
paddle::framework::AttributeMap& attrs) { // NOLINT
1213-
PADDLE_ENFORCE_EQ((attr_end - attr_start) % 2,
1214-
0,
1215-
common::errors::InvalidArgument(
1216-
"The number of arguments for attributes should be even "
1217-
"but attr_start = %d, attr_end = %d.",
1218-
attr_start,
1219-
attr_end));
1218+
PyObject* attrs_dict = PyTuple_GET_ITEM(args, arg_pos);
1219+
if (!PyDict_Check(attrs_dict)) {
1220+
PADDLE_THROW(common::errors::InvalidArgument(
1221+
"%s(): argument must be dict, but got %s",
1222+
op_type,
1223+
reinterpret_cast<PyTypeObject*>(attrs_dict->ob_type)->tp_name));
1224+
}
12201225

12211226
using CastFuncType = void (*)(PyObject*,
12221227
paddle::framework::AttributeMap&,
@@ -1246,34 +1251,29 @@ void ConstructAttrMapForRunProgram(
12461251
{"cuda_graph_dispatch_key", CastPyArg2AttrLong},
12471252
};
12481253

1249-
PyObject* obj = nullptr;
1250-
for (ssize_t arg_pos = attr_start; arg_pos < attr_end; arg_pos += 2) {
1251-
VLOG(3) << "Start Process " << arg_pos;
1254+
PyObject *key, *value;
1255+
Py_ssize_t pos = 0;
1256+
while (PyDict_Next(attrs_dict, &pos, &key, &value)) {
12521257
Py_ssize_t key_len = 0;
12531258
const char* key_ptr = nullptr;
1254-
obj = PyTuple_GET_ITEM(args, arg_pos);
1255-
if (PyObject_CheckString(obj)) {
1256-
key_ptr = PyUnicode_AsUTF8AndSize(obj, &key_len);
1259+
if (PyObject_CheckString(key)) {
1260+
key_ptr = PyUnicode_AsUTF8AndSize(key, &key_len);
12571261
} else {
12581262
PADDLE_THROW(common::errors::InvalidArgument(
1259-
"%s(): argument (position %d) must be str, but got %s",
1263+
"%s(): dict key must be str, but got %s",
12601264
op_type,
1261-
arg_pos,
1262-
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
1265+
reinterpret_cast<PyTypeObject*>(key->ob_type)->tp_name));
12631266
}
12641267
std::string_view key_view(key_ptr, static_cast<size_t>(key_len));
1265-
VLOG(3) << "Start Process " << key_view;
1266-
obj = PyTuple_GET_ITEM(args, arg_pos + 1);
12671268
auto it = kAttrFuncMap.find(std::string(key_view));
12681269
if (it != kAttrFuncMap.end()) {
1269-
// Call Cast function
1270-
it->second(obj, attrs, std::string(key_view), op_type, arg_pos);
1270+
it->second(value, attrs, std::string(key_view), op_type, 0);
12711271
} else {
12721272
PADDLE_THROW(common::errors::InvalidArgument(
12731273
"Attribute key %.*s is not recognized for operator %s.",
12741274
static_cast<int>(key_view.size()),
12751275
key_view.data(),
1276-
op_type.c_str())); // NOLINT
1276+
op_type.c_str()));
12771277
}
12781278
}
12791279
}

paddle/fluid/pybind/op_function_common.h

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -218,8 +218,7 @@ void ConstructAttrMapForLegacyRunProgram(
218218
void ConstructAttrMapForRunProgram(
219219
const std::string& op_type,
220220
PyObject* args,
221-
ssize_t attr_start,
222-
ssize_t attr_end,
221+
ssize_t arg_pos,
223222
paddle::framework::AttributeMap& attrs); // NOLINT
224223

225224
unsigned long GetUnsignedLongFromArgs( // NOLINT

python/paddle/jit/dy2static/pir_partial_program.py

Lines changed: 13 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -731,9 +731,9 @@ def __init__(
731731
@staticmethod
732732
def run_impl(partial_program_layer, inputs, parameters, outputs, attrs):
733733
_C_ops.run_program(
734-
inputs,
735-
parameters,
736-
outputs,
734+
PartialProgramLayer._valid_vars(inputs),
735+
PartialProgramLayer._valid_vars(parameters),
736+
PartialProgramLayer._valid_vars(outputs),
737737
partial_program_layer._create_scope_vec(
738738
cache_key=(
739739
PartialProgramLayer._calc_scope_cache_key(
@@ -745,23 +745,21 @@ def run_impl(partial_program_layer, inputs, parameters, outputs, attrs):
745745
),
746746
use_scope_cache=True,
747747
),
748-
*PartialProgramLayer._dict_attributes_to_op_fn_attrs(attrs),
748+
attrs,
749749
)
750750

751751
def __call__(self, inputs):
752752
"""
753753
Execute static graph by Interpreter and Return dynamic Tensors.
754754
"""
755755
attrs = self._prepare_attributes(in_sot_mode=False)
756-
inputs = self._valid_vars(self._prepare_inputs(inputs))
757-
parameters = self._valid_vars(self._params)
756+
inputs = self._prepare_inputs(inputs)
758757
out_vars = self._prepare_outputs()
759-
outputs = self._valid_vars(out_vars)
760758

761759
self.call_run_impl_with_hook(
762760
inputs,
763-
parameters,
764-
outputs,
761+
self._params,
762+
out_vars,
765763
attrs,
766764
)
767765

@@ -773,15 +771,12 @@ def sot_call(self, inputs):
773771
In sot, inputs and outputs of partial program only contain tensors, so we can skip some step to speed up
774772
"""
775773
attrs = self._prepare_attributes(in_sot_mode=True)
776-
inputs = self._valid_vars(inputs)
777-
parameters = self._valid_vars(self._params)
778774
out_vars = self._prepare_outputs()
779-
outputs = self._valid_vars(out_vars)
780775

781776
self.call_run_impl_with_hook(
782777
inputs,
783-
parameters,
784-
outputs,
778+
self._params,
779+
out_vars,
785780
attrs,
786781
)
787782
return self._outputs.quick_restore(out_vars)
@@ -1198,27 +1193,15 @@ def _append_backward(
11981193
return whole_program
11991194

12001195
def _prepare_attributes(self, in_sot_mode=False):
1201-
attrs = {
1196+
return {
12021197
'forward_program': self.program.forward_program,
12031198
'backward_program': self.program.backward_program,
12041199
'is_test': not self.training,
12051200
'program_id': self.program_id,
12061201
'in_sot_mode': in_sot_mode,
12071202
'cuda_graph_state': CUDAGraphState.DISABLE, # default value for not use cuda graph
12081203
'cuda_graph_dispatch_key': 0, # default value for not use cuda graph
1209-
}
1210-
attrs |= self.program.program_attr.items()
1211-
return attrs
1212-
1213-
@staticmethod
1214-
def _dict_attributes_to_op_fn_attrs(attrs):
1215-
op_fn_attrs = []
1216-
for k, v in attrs.items():
1217-
if k == "cuda_graph_state":
1218-
v = int(v)
1219-
op_fn_attrs.append(k)
1220-
op_fn_attrs.append(v)
1221-
return op_fn_attrs
1204+
} | self.program.program_attr
12221205

12231206
def _prepare_inputs(self, inputs):
12241207
"""
@@ -1360,7 +1343,8 @@ def _check_params_all_inited(self, main_program):
13601343
)
13611344
param_and_buffer_names_set.add(var.name)
13621345

1363-
def _valid_vars(self, vars):
1346+
@staticmethod
1347+
def _valid_vars(vars):
13641348
return vars if vars else None
13651349

13661350

0 commit comments

Comments
 (0)