Skip to content

Commit

Permalink
refine GetTensorListFromArgs (PaddlePaddle#64045)
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghuancoder authored and yinfan98 committed May 7, 2024
1 parent 658f0e1 commit 0613058
Showing 1 changed file with 40 additions and 4 deletions.
44 changes: 40 additions & 4 deletions paddle/fluid/pybind/eager_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1389,8 +1389,17 @@ std::vector<paddle::Tensor> GetTensorListFromArgs(
arg_idx));
}
for (Py_ssize_t i = 0; i < len; i++) {
PyObject* tensor_obj = PyList_GetItem(list, i);
PADDLE_ENFORCE_EQ(
PyObject_TypeCheck(tensor_obj, p_tensor_type),
true,
platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors",
op_type,
arg_name,
arg_idx));
paddle::Tensor& tensor =
reinterpret_cast<TensorObject*>(PyList_GetItem(list, i))->tensor;
reinterpret_cast<TensorObject*>(tensor_obj)->tensor;
if (local_mesh) {
ConvertToDistTensor(&tensor, local_mesh);
} else {
Expand Down Expand Up @@ -1422,8 +1431,17 @@ std::vector<paddle::Tensor> GetTensorListFromArgs(
arg_idx));
}
for (Py_ssize_t i = 0; i < len; i++) {
PyObject* tensor_obj = PyTuple_GetItem(list, i);
PADDLE_ENFORCE_EQ(
PyObject_TypeCheck(tensor_obj, p_tensor_type),
true,
platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors",
op_type,
arg_name,
arg_idx));
paddle::Tensor& tensor =
reinterpret_cast<TensorObject*>(PyTuple_GetItem(list, i))->tensor;
reinterpret_cast<TensorObject*>(tensor_obj)->tensor;
if (local_mesh) {
ConvertToDistTensor(&tensor, local_mesh);
} else {
Expand Down Expand Up @@ -1495,8 +1513,17 @@ paddle::optional<std::vector<paddle::Tensor>> GetOptionalTensorListFromArgs(
arg_idx));
}
for (Py_ssize_t i = 0; i < len; i++) {
PyObject* tensor_obj = PyList_GetItem(list, i);
PADDLE_ENFORCE_EQ(
PyObject_TypeCheck(tensor_obj, p_tensor_type),
true,
platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors",
op_type,
arg_name,
arg_idx));
paddle::Tensor& tensor =
reinterpret_cast<TensorObject*>(PyList_GetItem(list, i))->tensor;
reinterpret_cast<TensorObject*>(tensor_obj)->tensor;
if (local_mesh) {
ConvertToDistTensor(&tensor, local_mesh);
} else {
Expand Down Expand Up @@ -1528,8 +1555,17 @@ paddle::optional<std::vector<paddle::Tensor>> GetOptionalTensorListFromArgs(
arg_idx));
}
for (Py_ssize_t i = 0; i < len; i++) {
PyObject* tensor_obj = PyTuple_GetItem(list, i);
PADDLE_ENFORCE_EQ(
PyObject_TypeCheck(tensor_obj, p_tensor_type),
true,
platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors",
op_type,
arg_name,
arg_idx));
paddle::Tensor& tensor =
reinterpret_cast<TensorObject*>(PyTuple_GetItem(list, i))->tensor;
reinterpret_cast<TensorObject*>(tensor_obj)->tensor;
if (local_mesh) {
ConvertToDistTensor(&tensor, local_mesh);
} else {
Expand Down

0 comments on commit 0613058

Please sign in to comment.