Skip to content

Commit

Permalink
add_share_external_data_interface
Browse files (browse the repository at this point in the history)
  • Loading branch information
JZZ-NOTE committed Feb 23, 2022
1 parent 003b8bf commit ad93ebb
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 30 deletions.
28 changes: 14 additions & 14 deletions paddle/fluid/inference/api/analysis_predictor_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,19 +420,19 @@ TEST(Tensor, CpuShareExternalData) {
auto w3 = predictor->GetInputHandle("forthw");

std::vector<std::vector<int64_t>> input_data(4, {0, 1, 2, 3});
w0->ShareExternalData<int64_t>(input_data[0].data(), {4, 1}, PlaceType::kCPU);
w1->ShareExternalData<int64_t>(input_data[1].data(), {4, 1}, PlaceType::kCPU);
w2->ShareExternalData<int64_t>(input_data[2].data(), {4, 1}, PlaceType::kCPU);
w3->ShareExternalData<int64_t>(input_data[3].data(), {4, 1}, PlaceType::kCPU);

predictor->Run();
w0->ShareExternalData<int64_t>(input_data[0].data(), {4, 1});
w1->ShareExternalData<int64_t>(input_data[1].data(), {4, 1});
w2->ShareExternalData<int64_t>(input_data[2].data(), {4, 1});
w3->ShareExternalData<int64_t>(input_data[3].data(), {4, 1});

auto out = predictor->GetOutputHandle("fc_1.tmp_2");
auto out_shape = out->shape();
std::vector<float> out_data;
out_data.resize(std::accumulate(out_shape.begin(), out_shape.end(), 1,
std::multiplies<int>()));
out->ShareExternalData<float>(out_data.data(), out_shape, PlaceType::kCPU);
out->ShareExternalData<float>(out_data.data(), out_shape);

predictor->Run();

PlaceType place;
int size = 0;
Expand Down Expand Up @@ -463,12 +463,10 @@ TEST(Tensor, GpuShareExternalData) {
cudaMemcpyHostToDevice);
}

w0->ShareExternalData<int64_t>(input_gpu[0], {4, 1}, PlaceType::kGPU);
w1->ShareExternalData<int64_t>(input_gpu[1], {4, 1}, PlaceType::kGPU);
w2->ShareExternalData<int64_t>(input_gpu[2], {4, 1}, PlaceType::kGPU);
w3->ShareExternalData<int64_t>(input_gpu[3], {4, 1}, PlaceType::kGPU);

predictor->Run();
w0->ShareExternalData<int64_t>(input_gpu[0], {4, 1});
w1->ShareExternalData<int64_t>(input_gpu[1], {4, 1});
w2->ShareExternalData<int64_t>(input_gpu[2], {4, 1});
w3->ShareExternalData<int64_t>(input_gpu[3], {4, 1});

auto out = predictor->GetOutputHandle("fc_1.tmp_2");
auto out_shape = out->shape();
Expand All @@ -477,7 +475,9 @@ TEST(Tensor, GpuShareExternalData) {
std::multiplies<int>()) *
sizeof(float);
cudaMalloc(reinterpret_cast<void**>(out_data), out_size * sizeof(float));
out->ShareExternalData<float>(out_data, out_shape, PlaceType::kGPU);
out->ShareExternalData<float>(out_data, out_shape);

predictor->Run();

PlaceType place;
int size = 0;
Expand Down
24 changes: 9 additions & 15 deletions paddle/fluid/inference/api/details/zero_copy_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,20 +248,20 @@ paddle::experimental::DataLayout LayoutConvert(DataLayout layout) {

template <typename T>
void Tensor::ShareExternalData(const T *data, const std::vector<int> &shape,
PlaceType place, DataLayout layout) {
DataLayout layout) {
EAGER_GET_TENSOR(paddle::framework::LoDTensor)
size_t size =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
sizeof(T);
phi::DenseTensorMeta meta(DataTypeInfo<T>().TYPE, phi::make_ddim(shape),
LayoutConvert(layout));
if (place == PlaceType::kCPU) {
if (place_ == PlaceType::kCPU) {
phi::DenseTensor dtensor(
std::make_shared<phi::Allocation>(const_cast<T *>(data), size,
paddle::platform::CPUPlace()),
meta);
*tensor = std::move(dtensor);
} else if (place == PlaceType::kGPU) {
} else if (place_ == PlaceType::kGPU) {
phi::DenseTensor dtensor(
std::make_shared<phi::Allocation>(const_cast<T *>(data), size,
paddle::platform::CUDAPlace(device_)),
Expand Down Expand Up @@ -403,23 +403,17 @@ template PD_INFER_DECL void Tensor::CopyFromCpu<int8_t>(const int8_t *data);
template PD_INFER_DECL void Tensor::CopyFromCpu<float16>(const float16 *data);

template PD_INFER_DECL void Tensor::ShareExternalData<float>(
const float *data, const std::vector<int> &shape, PlaceType placem,
DataLayout layout);
const float *data, const std::vector<int> &shape, DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int64_t>(
const int64_t *data, const std::vector<int> &shape, PlaceType place,
DataLayout layout);
const int64_t *data, const std::vector<int> &shape, DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int32_t>(
const int32_t *data, const std::vector<int> &shape, PlaceType place,
DataLayout layout);
const int32_t *data, const std::vector<int> &shape, DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<uint8_t>(
const uint8_t *data, const std::vector<int> &shape, PlaceType place,
DataLayout layout);
const uint8_t *data, const std::vector<int> &shape, DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<int8_t>(
const int8_t *data, const std::vector<int> &shape, PlaceType place,
DataLayout layout);
const int8_t *data, const std::vector<int> &shape, DataLayout layout);
template PD_INFER_DECL void Tensor::ShareExternalData<float16>(
const float16 *data, const std::vector<int> &shape, PlaceType place,
DataLayout layout);
const float16 *data, const std::vector<int> &shape, DataLayout layout);

template PD_INFER_DECL void Tensor::CopyToCpu<float>(float *data) const;
template PD_INFER_DECL void Tensor::CopyToCpu<int64_t>(int64_t *data) const;
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/inference/api/paddle_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ class PD_INFER_DECL Tensor {
/// \param layout The layout of data. Only NCHW is supported now.
template <typename T>
void ShareExternalData(const T* data, const std::vector<int>& shape,
PlaceType place,
DataLayout layout = DataLayout::kNCHW);

/// \brief Experimental interface.
Expand Down

0 comments on commit ad93ebb

Please sign in to comment.