Skip to content

Commit

Permalink
Set value with scalar (PaddlePaddle#60452)
Browse files Browse the repository at this point in the history
* set_value with scalar

* fix ut
  • Loading branch information
zoooo0820 authored and Wanglongzhi2001 committed Jan 7, 2024
1 parent a8d4bd3 commit cf3746f
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 47 deletions.
92 changes: 52 additions & 40 deletions paddle/fluid/pybind/eager_method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1613,58 +1613,70 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
&use_strided_slice);

// step2: Parse values
PADDLE_ENFORCE(
PyCheckTensor(value_obj),
platform::errors::InvalidArgument("The value must be a Tensor"));

std::vector<phi::Scalar> values;
paddle::Tensor value_tensor =
reinterpret_cast<TensorObject*>(value_obj)->tensor;
dealWithValues(tensor, value_obj, &values, has_advanced_index);

if (!has_advanced_index) {
// use set_value OP if there is no advanced index

// Release gil and do tracing
py::gil_scoped_release release;
// use inplace set_value_ operator
if (value_tensor.initialized() &&
(self->tensor.dtype() != value_tensor.dtype())) {
if (egr::Controller::Instance().GetAMPLevel() !=
paddle::imperative::AmpLevel::O0) {
paddle::small_vector<std::vector<paddle::Tensor>,
egr::kSlotSmallVectorSize>
tmps = {{self->tensor}, {value_tensor}};
auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
self->tensor = egr::EagerAmpAutoCast(
self->tensor.name(), self->tensor, amp_dtype, "set_value");
value_tensor = egr::EagerAmpAutoCast(
value_tensor.name(), value_tensor, amp_dtype, "set_value");
}
if (value_tensor.initialized()) {
if (self->tensor.dtype() != value_tensor.dtype()) {
value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
if (egr::Controller::Instance().GetAMPLevel() !=
paddle::imperative::AmpLevel::O0) {
paddle::small_vector<std::vector<paddle::Tensor>,
egr::kSlotSmallVectorSize>
tmps = {{self->tensor}, {value_tensor}};
auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
self->tensor = egr::EagerAmpAutoCast(
self->tensor.name(), self->tensor, amp_dtype, "set_value");
value_tensor = egr::EagerAmpAutoCast(
value_tensor.name(), value_tensor, amp_dtype, "set_value");
}
if (self->tensor.dtype() != value_tensor.dtype()) {
value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
}
}
}

// step3.1: Only basic indexing, use OP set_value.
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) {
ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor);
}
self->tensor = set_value_with_tensor__ad_func(self->tensor,
value_tensor,
slice_starts,
slice_ends,
slice_strides,
slice_axes,
decrease_axis,
none_axes);
if (PyCheckTensor(value_obj)) {
// pass the stop_gradient from value to tensor.
// pass stop gradient should be done after CheckInplace in
// set_value__dygraph_function.
if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
// step3.1: Only basic indexing, use OP set_value.
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) {
ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor);
}
self->tensor = set_value_with_tensor__ad_func(self->tensor,
value_tensor,
slice_starts,
slice_ends,
slice_strides,
slice_axes,
decrease_axis,
none_axes);
if (PyCheckTensor(value_obj)) {
// pass the stop_gradient from value to tensor.
// pass stop gradient should be done after CheckInplace in
// set_value__dygraph_function.
if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
}
}
} else {
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, self->tensor)) {
ConvertAllInputsToDistTensor(mesh, self->tensor);
}
self->tensor = set_value__ad_func(self->tensor,
slice_starts,
slice_ends,
slice_strides,
slice_axes,
decrease_axis,
none_axes,
{1},
values);
}
} else {
// step3.2: Case for there are advanced indexing.
Expand Down
101 changes: 101 additions & 0 deletions paddle/fluid/pybind/slice_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
#include "paddle/fluid/framework/scope_guard.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

Expand Down Expand Up @@ -531,5 +533,104 @@ static void ParseBoolAndBroadcastIndices(
}
}

static paddle::Tensor dealWithValues(const paddle::Tensor& tensor,
PyObject* value_obj,
std::vector<phi::Scalar>* values,
const bool trans_to_tensor) {
paddle::Tensor value_tensor;
if (PyCheckTensor(value_obj)) {
value_tensor = reinterpret_cast<TensorObject*>(value_obj)->tensor;
} else if (py::isinstance<py::array>(value_obj)) {
paddle::Tensor value_tensor_tmp(
std::make_shared<phi::DenseTensor>(),
egr::Controller::Instance().GenerateUniqueName());
py::object value_obj_tmp(py::handle(value_obj), true);
py::object value = value_obj_tmp;
if (tensor.dtype() == phi::DataType::FLOAT32) {
if (!py::isinstance<py::array_t<float>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<float>(value_obj_tmp);
}
} else if (tensor.dtype() == phi::DataType::FLOAT64) {
if (!py::isinstance<py::array_t<double>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<double>(value_obj_tmp);
}
} else if (tensor.dtype() == phi::DataType::INT32) {
if (!py::isinstance<py::array_t<int32_t>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<int32_t>(value_obj_tmp);
}
} else if (tensor.dtype() == phi::DataType::INT64) {
if (!py::isinstance<py::array_t<int64_t>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<int64_t>(value_obj_tmp);
}
} else if (tensor.dtype() == phi::DataType::BOOL) {
if (!py::isinstance<py::array_t<bool>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<bool>(value_obj_tmp);
}
} else if (tensor.dtype() == phi::DataType::COMPLEX64) {
if (!py::isinstance<py::array_t<std::complex<float>>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<std::complex<float>>(
value_obj_tmp);
}
} else if (tensor.dtype() == phi::DataType::COMPLEX128) {
if (!py::isinstance<py::array_t<std::complex<double>>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<std::complex<double>>(
value_obj_tmp);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When assign a numpy.np value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, float64, complex64, complex128, int32 or int64, "
"please check the type of tensor."));
}
SetTensorFromPyArray(
static_cast<phi::DenseTensor*>(value_tensor_tmp.impl().get()),
value,
tensor.place(),
false);
value_tensor = value_tensor_tmp;
} else {
py::object value_obj_tmp(py::handle(value_obj), true);
// convert the value to self data type
if (py::isinstance<py::float_>(value_obj_tmp) ||
py::isinstance<py::int_>(value_obj_tmp) ||
py::isinstance<py::bool_>(value_obj_tmp) ||
PyComplex_Check(value_obj)) {
if (tensor.dtype() == phi::DataType::FLOAT32 ||
tensor.dtype() == phi::DataType::FLOAT16 ||
tensor.dtype() == phi::DataType::BFLOAT16) {
values->push_back(value_obj_tmp.cast<float>());
} else if (tensor.dtype() == phi::DataType::FLOAT64) {
values->push_back(value_obj_tmp.cast<double>());
} else if (tensor.dtype() == phi::DataType::INT32 ||
tensor.dtype() == phi::DataType::INT16 ||
tensor.dtype() == phi::DataType::INT8 ||
tensor.dtype() == phi::DataType::UINT8) {
values->push_back(value_obj_tmp.cast<float>());
} else if (tensor.dtype() == phi::DataType::INT64) {
values->push_back(value_obj_tmp.cast<double>());
} else if (tensor.dtype() == phi::DataType::BOOL) {
values->push_back(value_obj_tmp.cast<bool>());
} else if (tensor.dtype() == phi::DataType::COMPLEX64) {
values->push_back(value_obj_tmp.cast<std::complex<float>>());
} else if (tensor.dtype() == phi::DataType::COMPLEX128) {
values->push_back(value_obj_tmp.cast<std::complex<double>>());
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Value type error. The assign value allows "
"Tensor, numpy.ndarray, integer, float, complex or bool, "
"but received %s.",
Py_TYPE(value_obj)));
}

if (trans_to_tensor) {
value_tensor =
full_ad_func({1}, (*values)[0], tensor.dtype(), tensor.place());
}
}
return value_tensor;
}

} // namespace pybind
} // namespace paddle
11 changes: 4 additions & 7 deletions python/paddle/base/dygraph/tensor_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,7 +975,7 @@ def __array__(self, dtype=None):
array = array.astype(dtype)
return array

def pre_deal_index_and_value(self, item, value=None):
def pre_deal_index(self, item):
# since in pybind there is no effiency way to transfer Py_Tuple/Py_List/Py_Range to Tensor
# we call this function in python level.
item = list(item) if isinstance(item, tuple) else [item]
Expand All @@ -985,17 +985,14 @@ def pre_deal_index_and_value(self, item, value=None):
elif isinstance(slice_item, range):
item[i] = paddle.to_tensor(list(slice_item))

if value is not None and not isinstance(value, Variable):
value = paddle.to_tensor(value, dtype=self.dtype)

return tuple(item), value
return tuple(item)

def __getitem__(self, item):
item, _ = pre_deal_index_and_value(self, item)
item = pre_deal_index(self, item)
return self._getitem_dygraph(item)

def __setitem__(self, item, value):
item, value = pre_deal_index_and_value(self, item, value)
item = pre_deal_index(self, item)
return self._setitem_dygraph(item, value)

@framework.dygraph_only
Expand Down

0 comments on commit cf3746f

Please sign in to comment.