Add cast api && Change copy related api to copy_to && add more test (#4)
* fix compile error

* wrap framework tensor with LoDTensor

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* add CustomTensor default constructor

* add size() for CustomTensor

* make size const for CustomTensor

* refactor place related api to circle the concept

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* fix compile error

* make place const

* make Tensor copy

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* debug CustomTensor core

* remove additional head of framework

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* use back to shared ptr for custom tensor

* add gpu test

* merge latest cwh code in

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* adjust ut code of custom op

* hid share data from and to

* rename CustomTensor to Tensor

* support multi dtype

* remove lod, make reshape lowercase, add copy test and refactor copy api

* remove lod, make reshape lowercase, add copy test and refactor copy api

* remove lod, make reshape lowercase, add copy test and refactor copy api

* remove lod, make reshape lowercase, add copy test and refactor copy api

* fix copy to error

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add more test

* add type cast

* add cast and make copy to api

* add cast and make copy to api

* add cast and make copy to api

* add cast and make copy to api

* merge cwh code

* merge cwh code

* merge cwh code

* merge cwh code

* merge cwh code

* add more error log

* add more error log

* polish code

* used for test

* remove test comment

* remove test comment
JiabinYang authored Feb 8, 2021
1 parent 8dc6c94 commit e5cc53b
Showing 7 changed files with 350 additions and 210 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/extension/include/dtype.h
@@ -32,6 +32,7 @@ enum DataType {
INT16,
UINT8,
INT8,
BOOL,
// TODO(JiabinYang) support more data types if needed.
};

13 changes: 5 additions & 8 deletions paddle/fluid/extension/include/tensor.h
@@ -59,17 +59,11 @@ class Tensor {

/// \brief Copy the host memory to tensor data.
/// It's usually used to set the input tensor data.
/// \param data The pointer of the data, from which
/// \param PlaceType of target place, from which
/// the tensor will copy.
template <typename T>
Tensor copy_to_gpu();

/// \brief Copy the tensor data to the host memory.
/// It's usually used to get the output tensor data.
/// \param[out] data The tensor will copy the data to
/// the address.
template <typename T>
Tensor copy_to_cpu();
Tensor copy_to(const PlaceType& place);

/// \brief Return the shape of the Tensor.
std::vector<int> shape() const;
@@ -89,6 +83,9 @@
/// \return Place.
const PlaceType& place() const;

/// \brief Cast datatype from one to another
Tensor cast(const DataType& target_type);

private:
friend class framework::CustomTensorUtils;
mutable std::shared_ptr<void> tensor_;
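To make the headline change concrete, here is a minimal usage sketch of the reworked interface. It is illustrative only: the helper name is invented, a CUDA-enabled build is assumed for the GPU copy, the include path follows this PR's layout, and a DataType::INT32 enumerator is assumed to exist alongside the entries shown above.

#include <cstdint>

#include "paddle/fluid/extension/include/tensor.h"  // path as of this PR

// Hypothetical helper showing the new calls; not part of the commit.
paddle::Tensor MakeCastedGpuCopy() {
  paddle::Tensor x(paddle::PlaceType::kCPU);
  x.reshape({2, 3});                      // lowercase reshape, per this PR
  auto* x_data = x.mutable_data<float>();
  for (int64_t i = 0; i < x.size(); ++i) {
    x_data[i] = static_cast<float>(i);    // fill the CPU tensor
  }
  // One templated copy_to replaces the old copy_to_cpu()/copy_to_gpu() pair.
  paddle::Tensor x_gpu = x.copy_to<float>(paddle::PlaceType::kGPU);
  // The new cast API converts the element type while keeping shape and place.
  return x_gpu.cast(paddle::DataType::INT32);  // DataType::INT32 assumed
}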
243 changes: 166 additions & 77 deletions paddle/fluid/extension/src/tensor.cc
@@ -18,10 +18,76 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {

template <typename InType, typename OutType>
struct CastDataTypeFunctor {
HOSTDEVICE inline OutType operator()(InType in) const {
return static_cast<OutType>(in);
}
};

template <typename InType>
struct CastDataType {
CastDataType(const framework::Tensor &in, framework::Tensor *out,
const platform::DeviceContext *ctx)
: in_(in), out_(out), ctx_(ctx) {}
const framework::Tensor in_;
framework::Tensor *out_;
const platform::DeviceContext *ctx_;

template <typename OutType>
void apply() {
auto *in_begin = in_.data<InType>();
auto *in_end = in_begin + in_.numel();
auto *out_begin = out_->mutable_data<OutType>(in_.place());

if (platform::is_cpu_place(in_.place())) {
platform::Transform<platform::CPUDeviceContext> trans;
auto *context = static_cast<const platform::CPUDeviceContext *>(ctx_);
trans(*context, in_begin, in_end, out_begin,
CastDataTypeFunctor<InType, OutType>());
#ifdef __NVCC__
} else if (platform::is_gpu_place(in_.place())) {
platform::Transform<platform::CUDADeviceContext> trans;
auto *context = static_cast<const platform::CUDADeviceContext *>(ctx_);
trans(*context, in_begin, in_end, out_begin,
CastDataTypeFunctor<InType, OutType>());
context->Wait();
#endif
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Place type is not supported when casting data type."));
}
}
};
template <typename T>
void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
int64_t ele_size) {
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
int device_num = paddle::platform::GetCurrentDeviceId();
platform::CUDAPlace gpu_place(device_num);
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
if ((src_plc == PlaceType::kGPU) && (dst_plc == PlaceType::kCPU)) {
memory::Copy(platform::CPUPlace(), static_cast<void *>(dst), gpu_place, src,
ele_size, dev_ctx->stream());
} else if ((src_plc == PlaceType::kGPU) && (dst_plc == PlaceType::kGPU)) {
memory::Copy(gpu_place, static_cast<void *>(dst), gpu_place, src, ele_size,
dev_ctx->stream());
} else if ((src_plc == PlaceType::kCPU) && (dst_plc == PlaceType::kGPU)) {
memory::Copy(gpu_place, static_cast<void *>(dst), platform::CPUPlace(), src,
ele_size, dev_ctx->stream());
} else {
PADDLE_THROW("Only GPU related Copy can reach this func.");
}
cudaStreamSynchronize(dev_ctx->stream());
#endif
}

#define GET_CASTED_TENSOR \
if (!tensor_) { \
tensor_ = std::make_shared<framework::LoDTensor>(); \
@@ -55,12 +121,12 @@ T *Tensor::mutable_data() {
case static_cast<int>(PlaceType::kCPU): {
return tensor->mutable_data<T>(platform::CPUPlace());
}
case static_cast<int>(PlaceType::kGPU): {
#ifdef PADDLE_WITH_CUDA
case static_cast<int>(PlaceType::kGPU): {
int device_num = platform::GetCurrentDeviceId();
return tensor->mutable_data<T>(platform::CUDAPlace(device_num));
#endif
}
#endif
default:
PADDLE_THROW(platform::errors::Unavailable(
"CustomOp unsupported place: %d", static_cast<int>(place_)));
@@ -99,99 +165,62 @@ DataType Tensor::type() const {
return DataType::COMPLEX64;
} else if (type == framework::proto::VarType::COMPLEX128) {
return DataType::COMPLEX128;
} else if (type == framework::proto::VarType::BOOL) {
return DataType::BOOL;
}
return DataType::FLOAT32;
}

template <typename T>
Tensor Tensor::copy_to_gpu() {
#ifdef PADDLE_WITH_CUDA
Tensor Tensor::copy_to(const PlaceType &target_place) {
GET_CASTED_TENSOR;
PADDLE_ENFORCE_GE(tensor->numel(), 0,
platform::errors::PreconditionNotMet(
"You should call Tensor::Reshape(const "
"std::vector<int> &shape)"
"function before copying data from cpu."));
size_t ele_size = tensor->numel() * sizeof(T);
Tensor target = Tensor(PlaceType::kGPU);
auto *p_src_data = tensor->data<T>();
auto src_place = place();
Tensor target = Tensor(target_place);
target.reshape(shape());
auto *p_target_data = target.template mutable_data<T>();
auto p_src_data = tensor->data<T>();

platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
int device_num = platform::GetCurrentDeviceId();
platform::CUDAPlace gpu_place(device_num);
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
if (platform::is_cpu_place(tensor->place())) {
memory::Copy(gpu_place, static_cast<void *>(p_target_data),
platform::CPUPlace(), p_src_data, ele_size, dev_ctx->stream());
} else {
memory::Copy(gpu_place, static_cast<void *>(p_target_data), gpu_place,
p_src_data, ele_size, dev_ctx->stream());
}
cudaStreamSynchronize(dev_ctx->stream());
return target;
#else
PADDLE_THROW(
platform::errors::Unavailable("PaddlePaddle is not compiled with CUDA"));
#endif
return Tensor(PlaceType::kGPU);
}

template <typename T>
Tensor Tensor::copy_to_cpu() {
GET_CASTED_TENSOR;
auto ele_num = tensor->numel();
auto *t_data = tensor->data<T>();
auto t_place = tensor->place();
Tensor target = Tensor(PlaceType::kCPU);
target.reshape(shape());
auto *p_target_data = target.template mutable_data<T>();
if (platform::is_cpu_place(t_place)) {
std::memcpy(static_cast<void *>(p_target_data), t_data,
ele_num * sizeof(T));
if ((src_place == PlaceType::kCPU) && (target_place == PlaceType::kCPU)) {
std::memcpy(static_cast<void *>(p_target_data), p_src_data, ele_size);
} else if ((src_place == PlaceType::kGPU) &&
(target_place == PlaceType::kCPU)) {
GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
} else if ((src_place == PlaceType::kCPU) &&
(target_place == PlaceType::kGPU)) {
GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
} else if ((src_place == PlaceType::kGPU) &&
(target_place == PlaceType::kGPU)) {
GpuCopy<T>(p_src_data, p_target_data, src_place, target_place, ele_size);
} else {
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, t_place);
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
memory::Copy(platform::CPUPlace(), static_cast<void *>(p_target_data),
gpu_place, t_data, ele_num * sizeof(T), dev_ctx->stream());

cudaStreamSynchronize(dev_ctx->stream());
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle is not compiled with CUDA."));
#endif
"Not supported place transform of place: %d to place: %d",
static_cast<int>(src_place), static_cast<int>(target_place)));
}
return target;
}

template Tensor Tensor::copy_to_gpu<paddle::platform::float16>();
template Tensor Tensor::copy_to_gpu<paddle::platform::bfloat16>();
template Tensor Tensor::copy_to_gpu<paddle::platform::complex64>();
template Tensor Tensor::copy_to_gpu<paddle::platform::complex128>();
template Tensor Tensor::copy_to_gpu<float>();
template Tensor Tensor::copy_to_gpu<double>();
template Tensor Tensor::copy_to_gpu<int64_t>();
template Tensor Tensor::copy_to_gpu<int32_t>();
template Tensor Tensor::copy_to_gpu<uint8_t>();
template Tensor Tensor::copy_to_gpu<int8_t>();
template Tensor Tensor::copy_to_gpu<int16_t>();

template Tensor Tensor::copy_to_cpu<paddle::platform::float16>();
template Tensor Tensor::copy_to_cpu<paddle::platform::bfloat16>();
template Tensor Tensor::copy_to_cpu<paddle::platform::complex64>();
template Tensor Tensor::copy_to_cpu<paddle::platform::complex128>();
template Tensor Tensor::copy_to_cpu<float>();
template Tensor Tensor::copy_to_cpu<double>();
template Tensor Tensor::copy_to_cpu<int64_t>();
template Tensor Tensor::copy_to_cpu<int32_t>();
template Tensor Tensor::copy_to_cpu<uint8_t>();
template Tensor Tensor::copy_to_cpu<int8_t>();
template Tensor Tensor::copy_to_cpu<int16_t>();
template Tensor Tensor::copy_to<paddle::platform::float16>(
const PlaceType &target_place);
template Tensor Tensor::copy_to<paddle::platform::bfloat16>(
const PlaceType &target_place);
template Tensor Tensor::copy_to<paddle::platform::complex64>(
const PlaceType &target_place);
template Tensor Tensor::copy_to<paddle::platform::complex128>(
const PlaceType &target_place);
template Tensor Tensor::copy_to<float>(const PlaceType &target_place);
template Tensor Tensor::copy_to<double>(const PlaceType &target_place);
template Tensor Tensor::copy_to<int64_t>(const PlaceType &target_place);
template Tensor Tensor::copy_to<int32_t>(const PlaceType &target_place);
template Tensor Tensor::copy_to<uint8_t>(const PlaceType &target_place);
template Tensor Tensor::copy_to<int8_t>(const PlaceType &target_place);
template Tensor Tensor::copy_to<int16_t>(const PlaceType &target_place);
template Tensor Tensor::copy_to<bool>(const PlaceType &target_place);

template float *Tensor::data<float>() const;
template double *Tensor::data<double>() const;
@@ -208,6 +237,7 @@ Tensor::data<paddle::platform::complex128>() const;
template paddle::platform::complex64 *
Tensor::data<paddle::platform::complex64>() const;
template int16_t *Tensor::data<int16_t>() const;
template bool *Tensor::data<bool>() const;

template float *Tensor::mutable_data<float>();
template double *Tensor::mutable_data<double>();
@@ -224,6 +254,7 @@ Tensor::mutable_data<paddle::platform::complex128>();
template paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>();
template int16_t *Tensor::mutable_data<int16_t>();
template bool *Tensor::mutable_data<bool>();

template float *Tensor::mutable_data<float>(const PlaceType &place);
template double *Tensor::mutable_data<double>(const PlaceType &place);
@@ -240,6 +271,7 @@ Tensor::mutable_data<paddle::platform::complex128>(const PlaceType &place);
template paddle::platform::complex64 *
Tensor::mutable_data<paddle::platform::complex64>(const PlaceType &place);
template int16_t *Tensor::mutable_data<int16_t>(const PlaceType &place);
template bool *Tensor::mutable_data<bool>(const PlaceType &place);

std::vector<int> Tensor::shape() const {
GET_CASTED_TENSOR
@@ -261,6 +293,62 @@ const PlaceType &Tensor::place() const {
return place_;
}

Tensor Tensor::cast(const DataType &target_type) {
GET_CASTED_TENSOR;
Tensor rlt = Tensor(place());
rlt.reshape(this->shape());
auto rlt_tensor_ = static_cast<framework::LoDTensor *>(rlt.tensor_.get());
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto ctx = pool.Get(tensor->place());
auto src_type = tensor->type();
auto dst_type =
framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(target_type);
switch (src_type) {
case framework::proto::VarType::FP16:
framework::VisitDataType(
dst_type, CastDataType<platform::float16>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::BF16:
framework::VisitDataType(dst_type, CastDataType<platform::bfloat16>(
*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::FP32:
framework::VisitDataType(dst_type,
CastDataType<float>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::FP64:
framework::VisitDataType(dst_type,
CastDataType<double>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::INT32:
framework::VisitDataType(dst_type,
CastDataType<int>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::INT64:
framework::VisitDataType(
dst_type, CastDataType<int64_t>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::BOOL:
framework::VisitDataType(dst_type,
CastDataType<bool>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::INT16:
framework::VisitDataType(
dst_type, CastDataType<int16_t>(*tensor, rlt_tensor_, ctx));
break;
case framework::proto::VarType::UINT8:
framework::VisitDataType(
dst_type, CastDataType<u_int8_t>(*tensor, rlt_tensor_, ctx));
break;
// TODO(JiabinYang): Support Complex later
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when casting data type.",
framework::DataTypeToString(src_type)));
}
return rlt;
}

int64_t Tensor::size() const {
GET_CASTED_TENSOR;
return tensor->numel();
@@ -273,12 +361,13 @@ void CustomTensorUtils::ShareDataTo(const paddle::Tensor &src, void *dst) {
*static_cast<framework::LoDTensor *>(src.tensor_.get()));
}

void CustomTensorUtils::ShareDataFrom(void *src, const paddle::Tensor &dst) {
void CustomTensorUtils::ShareDataFrom(const void *src,
const paddle::Tensor &dst) {
if (!dst.tensor_) {
dst.tensor_ = std::make_shared<framework::LoDTensor>();
}
auto *tensor = static_cast<framework::LoDTensor *>(dst.tensor_.get());
tensor->ShareDataWith(*static_cast<framework::LoDTensor *>(src));
tensor->ShareDataWith(*static_cast<const framework::LoDTensor *>(src));
}

} // namespace framework
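The casting machinery added in tensor.cc relies on a double dispatch: cast() switches once over the source dtype, and framework::VisitDataType then dispatches on the target dtype by invoking CastDataType<InType>::apply<OutType>(). A minimal self-contained sketch of that pattern — illustrative names only, not Paddle APIs — looks like this:

#include <cstdint>
#include <iostream>
#include <vector>

enum class DType { kFloat32, kInt64 };

// Stand-in for framework::VisitDataType: picks OutType at runtime.
template <typename Visitor>
void VisitDType(DType t, Visitor v) {
  switch (t) {
    case DType::kFloat32: v.template apply<float>(); break;
    case DType::kInt64:   v.template apply<int64_t>(); break;
  }
}

// Stand-in for CastDataType: InType is fixed at compile time.
template <typename InType>
struct CastVisitor {
  const std::vector<InType>& in;
  explicit CastVisitor(const std::vector<InType>& i) : in(i) {}
  template <typename OutType>
  void apply() const {
    for (InType v : in) std::cout << static_cast<OutType>(v) << ' ';
    std::cout << '\n';
  }
};

int main() {
  std::vector<float> src = {1.5f, 2.5f};
  // Mirrors cast(): one switch on the source type selects CastVisitor<float>,
  // and VisitDType then dispatches on the target type.
  VisitDType(DType::kInt64, CastVisitor<float>(src));  // prints: 1 2
  return 0;
}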