Skip to content

Commit

Permalink
[Phi] Support construct Scalar by using Non-CPU Tensosr (#41528)
Browse files Browse the repository at this point in the history
* support construct scalar using non-cpu tensor

* fix bugs when run unittest

* fix compile bugs

* fix bugs when run ci

* fix compile bugs

* fix bugs when move copy

* perfect unit test

* perfect unittest

* update according to comment

* add target dependency
  • Loading branch information
YuanRisheng authored Apr 13, 2022
1 parent b239043 commit fe214af
Show file tree
Hide file tree
Showing 19 changed files with 446 additions and 95 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/platform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,13 @@ add_subdirectory(profiler)

cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto ${GPU_CTX_DEPS})
if(WITH_GPU)
nv_library(profiler SRCS profiler.cc profiler.cu DEPS os_info device_tracer gpu_info enforce dynload_cuda new_profiler)
nv_library(profiler SRCS profiler.cc profiler.cu DEPS os_info device_tracer gpu_info enforce dynload_cuda new_profiler stats)
nv_library(device_memory_aligment SRCS device_memory_aligment.cc DEPS cpu_info gpu_info place)
elseif(WITH_ROCM)
hip_library(profiler SRCS profiler.cc profiler.cu DEPS os_info device_tracer gpu_info enforce new_profiler)
hip_library(profiler SRCS profiler.cc profiler.cu DEPS os_info device_tracer gpu_info enforce new_profiler stats)
hip_library(device_memory_aligment SRCS device_memory_aligment.cc DEPS cpu_info gpu_info place)
else()
cc_library(profiler SRCS profiler.cc DEPS os_info device_tracer enforce new_profiler)
cc_library(profiler SRCS profiler.cc DEPS os_info device_tracer enforce new_profiler stats)
cc_library(device_memory_aligment SRCS device_memory_aligment.cc DEPS cpu_info place)
endif()

Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ add_subdirectory(tools)
add_subdirectory(tests)

# make an unity target for compile deps
set(PHI_DEPS convert_utils dense_tensor phi_context kernel_factory kernel_context arg_map_context infermeta lod_utils op_compat_infos sparse_csr_tensor sparse_coo_tensor string_tensor)
set(PHI_DEPS convert_utils dense_tensor phi_context kernel_factory kernel_context arg_map_context infermeta lod_utils op_compat_infos sparse_csr_tensor sparse_coo_tensor string_tensor api_scalar)
get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS)
set(PHI_DEPS ${PHI_DEPS} ${phi_kernels})

Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/api/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS phi_tensor_raw phi_conte
cc_library(api_gen_utils SRCS api_gen_utils.cc DEPS phi_tensor_raw selected_rows sparse_csr_tensor sparse_coo_tensor)
cc_library(phi_data_transform SRCS data_transform.cc DEPS phi_tensor_raw transfer_layout_kernel cast_kernel data_device_transform)
cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform)
cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform)
cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform tensor_copy)

cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform api_custom_impl)
cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform phi_function_api api_custom_impl global_utils)
Expand All @@ -173,3 +173,5 @@ cc_library(sparse_bw_api SRCS ${sparse_bw_api_source_file} DEPS phi_tensor_raw p
cc_library(phi_dygraph_api SRCS ${dygraph_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform phi_function_api sparse_api)
cc_library(strings_api SRCS ${strings_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils)
cc_library(phi_tensor SRCS tensor_method.cc DEPS phi_tensor_raw phi_function_api api_gen_utils kernel_dispatch infermeta sparse_api strings_api)
cc_library(tensor_copy SRCS tensor_copy.cc DEPS phi_tensor_raw copy_kernel kernel_dispatch api_gen_utils)
cc_library(api_scalar SRCS scalar.cc DEPS tensor_copy)
30 changes: 2 additions & 28 deletions paddle/phi/api/lib/api_custom_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/tensor_copy.h"
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/compat/convert_utils.h"
Expand Down Expand Up @@ -424,35 +425,8 @@ std::vector<std::vector<Tensor>> conv2d_grad_impl(
}

Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set =
kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place));
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"copy", kernel_key);

VLOG(6) << "copy API kernel key: " << kernel_key;
VLOG(6) << "copy API kernel: " << kernel;

auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());

auto dense_x = TensorToDenseTensor(x);

Tensor out;
auto kernel_out = SetKernelOutput(kernel_key.backend(), &out);
phi::MetaTensor meta_out(kernel_out);
phi::UnchangedInferMeta(*dense_x, &meta_out);

using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
phi::Place,
bool,
phi::DenseTensor*);

auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();

(*kernel_fn)(*dev_ctx, *dense_x, place, blocking, kernel_out);

copy(x, place, blocking, &out);
return out;
}

Expand Down
48 changes: 48 additions & 0 deletions paddle/phi/api/lib/scalar.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/common/scalar.h"

#include "paddle/phi/api/lib/tensor_copy.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"

namespace paddle {
namespace experimental {

template <>
ScalarBase<Tensor>::ScalarBase(const Tensor& tensor_in)
: dtype_(tensor_in.dtype()) { // NOLINT
PADDLE_ENFORCE_EQ(tensor_in.numel(),
1,
phi::errors::InvalidArgument(
"The Scalar only supports Tensor with 1 element, but "
"now Tensor has `%d` elements",
tensor_in.numel()));
if (tensor_in.place() == PlaceType::kGPU) {
Tensor dst_tensor;
copy(tensor_in, phi::CPUPlace(), true, &dst_tensor);
GetDataFromTensor(dst_tensor);
} else if (tensor_in.place() == PlaceType::kCPU) {
GetDataFromTensor(tensor_in);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Now, it is not supported to construct Scalar using tensor that its "
"PlaceType is (%d)",
static_cast<int>(tensor_in.place())));
}
}

} // namespace experimental
} // namespace paddle
57 changes: 57 additions & 0 deletions paddle/phi/api/lib/tensor_copy.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/api/lib/tensor_copy.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace experimental {

void copy(const Tensor& src, Place place, bool blocking, Tensor* dst) {
auto kernel_key_set = ParseKernelKeyByInputArgs(src);
kernel_key_set.backend_set =
kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place));
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"copy", kernel_key);

VLOG(6) << "copy API kernel key: " << kernel_key;
VLOG(6) << "copy API kernel: " << kernel;

auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());

auto dense_x = TensorToDenseTensor(src);

auto kernel_out = SetKernelOutput(kernel_key.backend(), dst);
phi::MetaTensor meta_out(kernel_out);
phi::UnchangedInferMeta(*dense_x, &meta_out);

using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
phi::Place,
bool,
phi::DenseTensor*);

auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, *dense_x, place, blocking, kernel_out);
}

} // namespace experimental
} // namespace paddle
25 changes: 25 additions & 0 deletions paddle/phi/api/lib/tensor_copy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/api/include/tensor.h"

namespace paddle {
namespace experimental {

void copy(const Tensor& src, Place place, bool blocking, Tensor* dst);

} // namespace experimental
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/phi/api/lib/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
cc_library(phi_api_utils SRCS storage.cc tensor_utils.cc DEPS
tensor_base convert_utils dense_tensor lod_tensor selected_rows_utils place var_type_traits scalar string_tensor)
tensor_base convert_utils dense_tensor lod_tensor selected_rows_utils place var_type_traits string_tensor scalar)
2 changes: 1 addition & 1 deletion paddle/phi/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
cc_library(phi_place SRCS place.cc)
cc_library(scalar SRCS scalar.cc DEPS phi_enforce)
cc_library(scalar SRCS scalar.cc DEPS phi_enforce tensor)
23 changes: 17 additions & 6 deletions paddle/phi/common/scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,32 @@ limitations under the License. */

#include "paddle/phi/common/scalar.h"

#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"

#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace experimental {

// NOTE(xiongkun): why we put definition here?
// test_custom_op can't include enforce.h, because enforce.h includes gflags.
// so we decouple the include dependence of enforce.h by link.
void ThrowTensorConvertError(int num) {
PADDLE_ENFORCE_EQ(num,
// The Tensor must have one dim
template <>
ScalarBase<phi::DenseTensor>::ScalarBase(const phi::DenseTensor& tensor_in)
: dtype_(tensor_in.dtype()) { // NOLINT
PADDLE_ENFORCE_EQ(tensor_in.numel(),
1,
phi::errors::InvalidArgument(
"The Scalar only supports Tensor with 1 element, but "
"now Tensor has `%d` elements",
num));
tensor_in.numel()));
auto cpu_place = phi::CPUPlace();
if (!paddle::platform::is_same_place(tensor_in.place(), cpu_place)) {
phi::DenseTensor tensor;
framework::TensorCopySync(tensor_in, cpu_place, &tensor);
GetDataFromTensor(tensor);
} else {
GetDataFromTensor(tensor_in);
}
}

} // namespace experimental
Expand Down
90 changes: 44 additions & 46 deletions paddle/phi/common/scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ limitations under the License. */
namespace paddle {
namespace experimental {

void ThrowTensorConvertError(int);

template <typename T>
class ScalarBase {
public:
Expand Down Expand Up @@ -105,50 +103,7 @@ class ScalarBase {
}

// The Tensor must have one dim
ScalarBase(const T& tensor) : dtype_(tensor.dtype()) { // NOLINT
is_from_tensor_ = true;
ThrowTensorConvertError(tensor.numel());
switch (dtype_) {
case DataType::FLOAT32:
data_.f32 = tensor.template data<float>()[0];
break;
case DataType::FLOAT64:
data_.f64 = tensor.template data<double>()[0];
break;
case DataType::FLOAT16:
data_.f16 = tensor.template data<float16>()[0];
break;
case DataType::BFLOAT16:
data_.bf16 = tensor.template data<bfloat16>()[0];
break;
case DataType::INT32:
data_.i32 = tensor.template data<int32_t>()[0];
break;
case DataType::INT64:
data_.i64 = tensor.template data<int64_t>()[0];
break;
case DataType::INT16:
data_.i16 = tensor.template data<int16_t>()[0];
break;
case DataType::INT8:
data_.i8 = tensor.template data<int8_t>()[0];
break;
case DataType::UINT8:
data_.ui8 = tensor.template data<uint8_t>()[0];
break;
case DataType::BOOL:
data_.b = tensor.template data<bool>()[0];
break;
case DataType::COMPLEX64:
data_.c64 = tensor.template data<complex64>()[0];
break;
case DataType::COMPLEX128:
data_.c128 = tensor.template data<complex128>()[0];
break;
default:
PD_THROW("Invalid tensor data type `", dtype_, "`.");
}
}
ScalarBase(const T& tensor_in); // NOLINT

template <typename OtherT>
ScalarBase(const ScalarBase<OtherT>& other) {
Expand Down Expand Up @@ -200,6 +155,49 @@ class ScalarBase {
private:
template <typename T1, typename T2>
friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst);
void GetDataFromTensor(const T& tensor) {
is_from_tensor_ = true;
switch (dtype_) {
case DataType::FLOAT32:
data_.f32 = tensor.template data<float>()[0];
break;
case DataType::FLOAT64:
data_.f64 = tensor.template data<double>()[0];
break;
case DataType::FLOAT16:
data_.f16 = tensor.template data<float16>()[0];
break;
case DataType::BFLOAT16:
data_.bf16 = tensor.template data<bfloat16>()[0];
break;
case DataType::INT32:
data_.i32 = tensor.template data<int32_t>()[0];
break;
case DataType::INT64:
data_.i64 = tensor.template data<int64_t>()[0];
break;
case DataType::INT16:
data_.i16 = tensor.template data<int16_t>()[0];
break;
case DataType::INT8:
data_.i8 = tensor.template data<int8_t>()[0];
break;
case DataType::UINT8:
data_.ui8 = tensor.template data<uint8_t>()[0];
break;
case DataType::BOOL:
data_.b = tensor.template data<bool>()[0];
break;
case DataType::COMPLEX64:
data_.c64 = tensor.template data<complex64>()[0];
break;
case DataType::COMPLEX128:
data_.c128 = tensor.template data<complex128>()[0];
break;
default:
PD_THROW("Invalid tensor data type `", dtype_, "`.");
}
}

private:
bool is_from_tensor_{false};
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ cc_library(string_tensor SRCS string_tensor.cc DEPS convert_utils tensor_meta te

cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor)
cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor)
cc_library(selected_rows SRCS selected_rows_impl.cc DEPS dense_tensor phi_enforce ddim memcpy)
cc_library(selected_rows SRCS selected_rows_impl.cc selected_rows.cc DEPS tensor_base dense_tensor phi_enforce ddim memcpy)
cc_library(phi_device_context SRCS device_context.cc DEPS dense_tensor selected_rows)

cc_library(custom_kernel SRCS custom_kernel.cc DEPS kernel_factory)
Expand Down
Loading

0 comments on commit fe214af

Please sign in to comment.