From 5c73a6eaa2be74eaed7e974b433a4c44f6da58b6 Mon Sep 17 00:00:00 2001
From: Zhanlue Yang
Date: Mon, 10 Jan 2022 11:13:59 +0800
Subject: [PATCH] [Unify Tensors PR #5] framework::Tensor inherits from DenseTensor,test=allcases (#38632)

* Added shared_ptr member & corresponding interfaces to Storage
* Removed original pten::Allocation from Storage and adjusted the interfaces accordingly
* Fixed issues with storage offset
* Used place to malloc allocation for TensorStorage
* [Unify Tensors PR #3] Ported framework::Tensor interfaces to pten::DenseTensor
* Fixed issues with place
* Added comments
* Moved mutable_data with stream argument to DenseTensor
* Added set_offset interface
* Fixed CI issues,test=allcases
* [Unify Tensors PR #4] Port LoDTensor interfaces to DenseTensor
* Removed friend class EigenTensor/EigenMatrix/EigenVector from Tensor
* Modified framework::Tensor to inherit from DenseTensor
* Reverted changes to pten_layout() interface
* Removed friend classes
* Rearranged function calls from tensor.data<void>() to tensor.data()
* Fixed CI issues
* Fixed lite issues
* Fixed data() interface issues,test=allcases
* Resolved IsInitialized() issues
* Fixed ResetHolder() issues
* Fixed MKLDNN & Storage issues
* Resolved ShareBufferWith() issues
* Fixed LoD issues
---
 .../fluid/distributed/service/brpc_utils.cc | 24 +-
 .../fluid/distributed/service/heter_client.cc | 11 +-
 .../fluid/distributed/service/heter_server.h | 2 +-
 paddle/fluid/framework/CMakeLists.txt | 2 +-
 .../fluid/framework/copy_same_tensor_test.cc | 4 +-
 paddle/fluid/framework/data_transform.cc | 1 -
 .../framework/details/all_reduce_op_handle.cc | 8 +-
 .../framework/details/broadcast_op_handle.cc | 4 +-
 .../details/fused_all_reduce_op_handle.cc | 8 +-
 .../framework/details/reduce_op_handle.cc | 6 +-
 .../details/sparse_all_reduce_op_handle.cc | 4 +-
 paddle/fluid/framework/dlpack_tensor.cc | 2 +-
 paddle/fluid/framework/fleet/ascend_wrapper.h | 3 +-
 paddle/fluid/framework/fleet/heter_wrapper.cc | 9 +-
 paddle/fluid/framework/heterxpu_trainer.cc | 12 +-
 paddle/fluid/framework/lod_tensor.h | 4 +-
 .../new_executor/standalone_executor_test.cc | 1 -
 paddle/fluid/framework/parallel_executor.cc | 4 +-
 paddle/fluid/framework/program_desc.cc | 5 +
 paddle/fluid/framework/save_load_util.cc | 4 +-
 paddle/fluid/framework/tensor.cc | 186 +++-----------
 paddle/fluid/framework/tensor.h | 236 +-----------------
 paddle/fluid/framework/tensor_impl.h | 55 ----
 paddle/fluid/framework/tensor_test.cc | 3 -
 paddle/fluid/framework/tensor_util.cc | 2 +-
 paddle/fluid/framework/variable.h | 6 +-
 paddle/fluid/imperative/all_reduce.cc | 4 +-
 paddle/fluid/imperative/bkcl_context.cc | 4 +-
 paddle/fluid/imperative/hccl_context.cc | 4 +-
 paddle/fluid/imperative/nccl_context.cc | 2 +-
 .../tests/test_gradient_accmulator.cc | 4 +-
 paddle/fluid/inference/api/api_impl_tester.cc | 6 +-
 paddle/fluid/inference/lite/tensor_utils.cc | 4 +-
 .../amp/update_loss_scaling_op_npu.cc | 2 +-
 paddle/fluid/operators/coalesce_tensor_op.cc | 8 +-
 .../fluid/operators/collective/allreduce_op.h | 2 +-
 .../operators/collective/barrier_op.cu.cc | 2 +-
 .../operators/collective/broadcast_op.cu.cc | 2 +-
 .../operators/collective/broadcast_op_xpu.cc | 2 +-
 .../fluid/operators/collective/c_reduce_op.h | 4 +-
 paddle/fluid/operators/detection/bbox_util.h | 4 +-
 paddle/fluid/operators/layer_norm_op.cu | 4 +-
 .../fluid/operators/math/matrix_inverse.cu.cc | 2 +-
 paddle/fluid/operators/optimizers/lamb_op.h | 4 +-
 .../fluid/operators/reader/buffered_reader.cc | 6 +-
 paddle/fluid/operators/spectral_op.cc |
12 +- paddle/fluid/operators/spectral_op.cu | 18 +- .../operators/tensorrt/tensorrt_engine_op.h | 2 +- .../fluid/platform/device/ipu/ipu_compiler.cc | 2 +- paddle/fluid/platform/device/ipu/ipu_utils.cc | 2 +- paddle/fluid/platform/device/ipu/ipu_utils.h | 2 +- .../platform/device/npu/npu_op_runner.cc | 2 +- .../fluid/platform/device/npu/npu_op_runner.h | 4 +- paddle/fluid/pybind/imperative.cc | 6 +- paddle/fluid/pybind/tensor_py.h | 2 +- paddle/pten/api/lib/utils/storage.h | 17 +- paddle/pten/core/dense_tensor.cc | 195 +++++---------- paddle/pten/core/dense_tensor.h | 31 +-- paddle/pten/core/storage.h | 2 +- paddle/pten/core/utils/intrusive_ptr.h | 4 +- paddle/pten/tests/core/test_dense_tensor.cc | 5 +- 61 files changed, 251 insertions(+), 731 deletions(-) diff --git a/paddle/fluid/distributed/service/brpc_utils.cc b/paddle/fluid/distributed/service/brpc_utils.cc index 92dcde99cccb0..6eb8462977b60 100644 --- a/paddle/fluid/distributed/service/brpc_utils.cc +++ b/paddle/fluid/distributed/service/brpc_utils.cc @@ -103,19 +103,17 @@ void SerializeLodTensor(framework::Variable* var, if (platform::is_cpu_place(tensor->place())) { auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); iobuf->append(reinterpret_cast(&data_len), 8); - iobuf->append(reinterpret_cast(tensor->data()), - data_len); + iobuf->append(reinterpret_cast(tensor->data()), data_len); } else { #ifdef PADDLE_WITH_CUDA char* temp_ptr = new char[tensor->numel() * framework::SizeOfType(tensor->type())]; auto stream = reinterpret_cast(ctx).stream(); - memory::Copy(platform::CPUPlace(), temp_ptr, - BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), - tensor->data(), - tensor->numel() * framework::SizeOfType(tensor->type()), - stream); + memory::Copy( + platform::CPUPlace(), temp_ptr, + BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), tensor->data(), + tensor->numel() * framework::SizeOfType(tensor->type()), stream); auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); iobuf->append(reinterpret_cast(&data_len), 8); iobuf->append(reinterpret_cast(temp_ptr), data_len); @@ -147,19 +145,17 @@ void SerializeSelectedRows(framework::Variable* var, if (platform::is_cpu_place(tensor->place())) { auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); iobuf->append(reinterpret_cast(&data_len), 8); - iobuf->append(reinterpret_cast(tensor->data()), - data_len); + iobuf->append(reinterpret_cast(tensor->data()), data_len); } else { #ifdef PADDLE_WITH_CUDA char* temp_ptr = new char[tensor->numel() * framework::SizeOfType(tensor->type())]; auto stream = reinterpret_cast(ctx).stream(); - memory::Copy(platform::CPUPlace(), temp_ptr, - BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), - tensor->data(), - tensor->numel() * framework::SizeOfType(tensor->type()), - stream); + memory::Copy( + platform::CPUPlace(), temp_ptr, + BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), tensor->data(), + tensor->numel() * framework::SizeOfType(tensor->type()), stream); auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); iobuf->append(reinterpret_cast(&data_len), 8); iobuf->append(reinterpret_cast(temp_ptr), data_len); diff --git a/paddle/fluid/distributed/service/heter_client.cc b/paddle/fluid/distributed/service/heter_client.cc index d9ec6b21fd377..13016d60515dd 100644 --- a/paddle/fluid/distributed/service/heter_client.cc +++ b/paddle/fluid/distributed/service/heter_client.cc @@ -34,7 +34,7 @@ int GetMicroId(const platform::DeviceContext& ctx, auto micro_id = -1; auto* 
tensor = var->GetMutable(); if (platform::is_cpu_place(tensor->place())) { - auto data = reinterpret_cast(tensor->data()); + auto data = reinterpret_cast(tensor->data()); micro_id = static_cast(data[0]); } else { #ifdef PADDLE_WITH_CUDA @@ -43,11 +43,10 @@ int GetMicroId(const platform::DeviceContext& ctx, char* temp_ptr = temp.data(); auto stream = reinterpret_cast(ctx).stream(); - memory::Copy(platform::CPUPlace(), temp_ptr, - BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), - tensor->data(), - tensor->numel() * framework::SizeOfType(tensor->type()), - stream); + memory::Copy( + platform::CPUPlace(), temp_ptr, + BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), tensor->data(), + tensor->numel() * framework::SizeOfType(tensor->type()), stream); float* temp_ptr_float = reinterpret_cast(temp_ptr); micro_id = static_cast(temp_ptr_float[0]); #endif diff --git a/paddle/fluid/distributed/service/heter_server.h b/paddle/fluid/distributed/service/heter_server.h index 5f062755c9242..201074810cf31 100644 --- a/paddle/fluid/distributed/service/heter_server.h +++ b/paddle/fluid/distributed/service/heter_server.h @@ -240,7 +240,7 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler { platform::errors::InvalidArgument( "Not find variable microbatch_id in scope.")); auto* tensor = var->GetMutable(); - auto data = reinterpret_cast(tensor->data()); + auto data = reinterpret_cast(tensor->data()); auto micro_id = static_cast(data[0]); int minibatch_index = micro_id / 10; diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index bd096f41ccc49..902943d14ff9d 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -91,7 +91,7 @@ endif() cc_test(copy_same_tensor_test SRCS copy_same_tensor_test.cc DEPS tensor) cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) -cc_library(mixed_vector SRCS mixed_vector.cc DEPS device_context) +cc_library(mixed_vector SRCS mixed_vector.cc DEPS device_context place memory) if(WITH_GPU) nv_test(mixed_vector_test SRCS mixed_vector_test.cc mixed_vector_test.cu DEPS mixed_vector place memory device_context tensor) diff --git a/paddle/fluid/framework/copy_same_tensor_test.cc b/paddle/fluid/framework/copy_same_tensor_test.cc index 0b1fdc3944689..14bef7fe023f6 100644 --- a/paddle/fluid/framework/copy_same_tensor_test.cc +++ b/paddle/fluid/framework/copy_same_tensor_test.cc @@ -77,8 +77,8 @@ static bool CopySameTensorTestMain(const DDim &dims, TensorCopySync(src_tensor, platform::CPUPlace(), &dst_cpu_tensor); } - const void *ground_truth_ptr = src_cpu_tensor.data(); - const void *result_ptr = dst_cpu_tensor.data(); + const void *ground_truth_ptr = src_cpu_tensor.data(); + const void *result_ptr = dst_cpu_tensor.data(); size_t byte_num = product(dims) * sizeof(T); return std::memcmp(ground_truth_ptr, result_ptr, byte_num) == 0; } diff --git a/paddle/fluid/framework/data_transform.cc b/paddle/fluid/framework/data_transform.cc index 16c1923ce1815..d8c372becf1b4 100644 --- a/paddle/fluid/framework/data_transform.cc +++ b/paddle/fluid/framework/data_transform.cc @@ -45,7 +45,6 @@ void TransformData(const OpKernelType &expected_kernel_type, Tensor out; const DataLayout lin = kernel_type_for_var.data_layout_; const DataLayout lout = expected_kernel_type.data_layout_; - // do layout transform if (NeedTransformLayout(lout, lin)) { #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc index 
b1573093ec333..f93202769dbd0 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -153,7 +153,7 @@ void AllReduceOpHandle::AllReduceImpl( "The place type of tensors of the same variable " "in different local scopes should be equal.")); - lod_tensor_data.emplace_back(lod_tensor.data()); + lod_tensor_data.emplace_back(lod_tensor.data()); places.emplace_back(lod_tensor.place()); VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name() @@ -225,7 +225,7 @@ void AllReduceOpHandle::AllReduceFunc( ->GetMutable(); // Reduce All Tensor to trg in CPU - ReduceBufferData func(lod_tensor_data, trg.data(), numel); + ReduceBufferData func(lod_tensor_data, trg.data(), numel); VisitDataType(trg.type(), func); for (size_t i = 1; i < local_exec_scopes_.size(); ++i) { @@ -235,9 +235,9 @@ void AllReduceOpHandle::AllReduceFunc( size_t size = numel * SizeOfType(trg.type()); RunAndRecordEvent(p, [&trg, var, p, size] { - auto dst_ptr = var->GetMutable()->data(); + auto dst_ptr = var->GetMutable()->data(); platform::CPUPlace cpu_place; - memory::Copy(cpu_place, dst_ptr, cpu_place, trg.data(), size); + memory::Copy(cpu_place, dst_ptr, cpu_place, trg.data(), size); }); } } diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index a11a244214d4f..01dc5a45146f1 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -101,7 +101,7 @@ void BroadcastOpHandle::BroadcastOneVar( void *send_recv_buffer = nullptr; if (root_id == dst_id) { - send_recv_buffer = const_cast(in_tensor.data()); + send_recv_buffer = const_cast(in_tensor.data()); out_handle = out_var_handle; } else { send_recv_buffer = VariableVisitor::GetMutableTensor(out_var) @@ -162,7 +162,7 @@ void BroadcastOpHandle::BroadcastOneVar( void *send_recv_buffer = nullptr; if (root_id == dst_id) { - send_recv_buffer = const_cast(in_tensor.data()); + send_recv_buffer = const_cast(in_tensor.data()); out_handle = out_var_handle; } else { send_recv_buffer = VariableVisitor::GetMutableTensor(out_var) diff --git a/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc b/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc index bd153f24fa318..b65d4e4fcd55a 100644 --- a/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc @@ -220,17 +220,17 @@ void FusedAllReduceOpHandle::FusedAllReduceFunc( g_tensor.begin(), g_tensor.end(), [](const std::pair &grad1, const std::pair &grad2) -> bool { - return grad1.second->data() < grad2.second->data(); + return grad1.second->data() < grad2.second->data(); }); size_t size_of_dtype = framework::SizeOfType(dtype); for (size_t k = 1; k < g_tensor.size(); ++k) { - const void *cur_address = g_tensor.at(k - 1).second->data(); + const void *cur_address = g_tensor.at(k - 1).second->data(); int64_t len = g_tensor.at(k - 1).second->numel(); auto offset = platform::Alignment(len * size_of_dtype, places_[0]); void *infer_next_address = reinterpret_cast( reinterpret_cast(cur_address) + offset); - const void *next_address = g_tensor.at(k).second->data(); + const void *next_address = g_tensor.at(k).second->data(); VLOG(10) << string::Sprintf( "Input[%d](%s) address: 0X%02x, Input[%d](%s) address: 0X%02x, Infer " @@ -267,7 +267,7 @@ void FusedAllReduceOpHandle::FusedAllReduceFunc( std::vector lod_tensor_data; 
lod_tensor_data.reserve(place_num); for (size_t scope_idx = 0; scope_idx < place_num; ++scope_idx) { - auto data = grads_tensor.at(scope_idx).at(0).second->data(); + auto data = grads_tensor.at(scope_idx).at(0).second->data(); lod_tensor_data.emplace_back(data); } std::vector grad_var_names; diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc index bbc458804a195..196f7a3d4a4bf 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.cc +++ b/paddle/fluid/framework/details/reduce_op_handle.cc @@ -159,7 +159,7 @@ void ReduceOpHandle::RunImpl() { VisitDataType(lod_tensors[0]->type(), func); auto trg = out_var->GetMutable(); - if (reduce_sum_trg.data() != trg->data()) { + if (reduce_sum_trg.data() != trg->data()) { TensorCopy(reduce_sum_trg, platform::CPUPlace(), trg); } } @@ -181,7 +181,7 @@ void ReduceOpHandle::RunImpl() { int dev_id = BOOST_GET_CONST(platform::CUDAPlace, p).device; auto &nccl_ctx = nccl_ctxs_->at(dev_id); - void *buffer = const_cast(lod_tensor.data()); + void *buffer = const_cast(lod_tensor.data()); void *recvbuffer = nullptr; if (root_id == dev_id) { recvbuffer = @@ -227,7 +227,7 @@ void ReduceOpHandle::RunImpl() { int dev_id = BOOST_GET_CONST(platform::XPUPlace, p).device; auto &bkcl_ctx = bkcl_ctxs_->at(dev_id); - void *buffer = const_cast(lod_tensor.data()); + void *buffer = const_cast(lod_tensor.data()); void *recvbuffer = nullptr; if (root_id == dev_id) { recvbuffer = diff --git a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc index d916b9bc26276..ed485ed587c0b 100644 --- a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc @@ -146,7 +146,7 @@ void SparseAllReduceOpHandle::RunImplEncoded() { for (size_t i = 0; i < local_scopes_.size(); ++i) { auto &place = places_[i]; auto &in = *ins[i]; - void *in_tensor_buf = const_cast(in.data()); + void *in_tensor_buf = const_cast(in.data()); auto &out = *outs[i]; float *out_tensor_buf = out.data(); @@ -175,7 +175,7 @@ void SparseAllReduceOpHandle::RunImplEncoded() { // dgc use ncclAllGather to get all the encoded data // so the buffer need nranks. 
int buf_size = nranks_ * encode_size; - void *gather_buff = gathers[i]->data(); + void *gather_buff = gathers[i]->data(); VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel << ", nranks:" << nranks_ << ", gather_buf size:" << buf_size diff --git a/paddle/fluid/framework/dlpack_tensor.cc b/paddle/fluid/framework/dlpack_tensor.cc index cef1016aa5340..95913664961b3 100644 --- a/paddle/fluid/framework/dlpack_tensor.cc +++ b/paddle/fluid/framework/dlpack_tensor.cc @@ -134,7 +134,7 @@ struct DLDeviceVisitor : public boost::static_visitor<::DLDevice> { DLPackTensor::DLPackTensor(const Tensor &tensor, LaneType lanes) { // init data, data buffer - t_.data = const_cast(tensor.data()); + t_.data = const_cast(tensor.data()); // init device, DLDevice type with device_type and device_id auto place = tensor.place(); diff --git a/paddle/fluid/framework/fleet/ascend_wrapper.h b/paddle/fluid/framework/fleet/ascend_wrapper.h index 82ce3b28776f1..4127adf1bfe27 100644 --- a/paddle/fluid/framework/fleet/ascend_wrapper.h +++ b/paddle/fluid/framework/fleet/ascend_wrapper.h @@ -150,8 +150,7 @@ class AscendInstance { VarTypeToGeType(tensor->type())); tensor_desc.SetRealDimCnt(vec_dim.size()); - const uint8_t *data = - reinterpret_cast(tensor->data()); + const uint8_t *data = reinterpret_cast(tensor->data()); std::vector dst(numel * GeTypeSize(tensor->type())); memcpy(dst.data(), data, GeTypeSize(tensor->type()) * numel); ge::Tensor ge_tensor(tensor_desc, dst); diff --git a/paddle/fluid/framework/fleet/heter_wrapper.cc b/paddle/fluid/framework/fleet/heter_wrapper.cc index a67f9a5e2c733..66f0d116f2412 100644 --- a/paddle/fluid/framework/fleet/heter_wrapper.cc +++ b/paddle/fluid/framework/fleet/heter_wrapper.cc @@ -112,20 +112,19 @@ void HeterWrapper::SerializeToReq(const std::string& varname, Scope* scope, char* data_ptr = const_cast(req_data->data()); if (platform::is_cpu_place(tensor->place())) { - memcpy(data_ptr, tensor->data(), + memcpy(data_ptr, tensor->data(), tensor->numel() * SizeOfType(tensor->type())); } else { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) memory::Copy(platform::CPUPlace(), data_ptr, BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), - tensor->data(), - tensor->numel() * SizeOfType(tensor->type()), nullptr); + tensor->data(), tensor->numel() * SizeOfType(tensor->type()), + nullptr); #endif #ifdef PADDLE_WITH_XPU memory::Copy(platform::CPUPlace(), data_ptr, BOOST_GET_CONST(platform::XPUPlace, tensor->place()), - tensor->data(), - tensor->numel() * SizeOfType(tensor->type())); + tensor->data(), tensor->numel() * SizeOfType(tensor->type())); #endif } } diff --git a/paddle/fluid/framework/heterxpu_trainer.cc b/paddle/fluid/framework/heterxpu_trainer.cc index ebd737c2d5794..3ed886e874db0 100644 --- a/paddle/fluid/framework/heterxpu_trainer.cc +++ b/paddle/fluid/framework/heterxpu_trainer.cc @@ -339,7 +339,7 @@ int HeterXpuTrainer::EndPass(const HeterRequest* request, auto dev_id = BOOST_GET_CONST(platform::CUDAPlace, thread_tensor->place()).device; platform::CUDADeviceGuard guard(dev_id); - cudaMemset(thread_tensor->data(), 0, + cudaMemset(thread_tensor->data(), 0, thread_tensor->numel() * SizeOfType(thread_tensor->type())); #endif #ifdef PADDLE_WITH_XPU @@ -351,11 +351,11 @@ int HeterXpuTrainer::EndPass(const HeterRequest* request, platform::DeviceContext* dev_ctx = pool.Get(place); const platform::XPUDeviceContext* xpu_ctx = reinterpret_cast(dev_ctx); - xpu::memset(xpu_ctx->x_context(), thread_tensor->data(), 0, + xpu::memset(xpu_ctx->x_context(), 
thread_tensor->data(), 0, thread_tensor->numel() * SizeOfType(thread_tensor->type())); #endif } else { - memset(thread_tensor->data(), 0, + memset(thread_tensor->data(), 0, thread_tensor->numel() * SizeOfType(thread_tensor->type())); } } @@ -367,7 +367,7 @@ int HeterXpuTrainer::EndPass(const HeterRequest* request, auto dev_id = BOOST_GET_CONST(platform::CUDAPlace, root_tensor->place()).device; platform::CUDADeviceGuard guard(dev_id); - cudaMemset(root_tensor->data(), 0, + cudaMemset(root_tensor->data(), 0, root_tensor->numel() * SizeOfType(root_tensor->type())); #endif #ifdef PADDLE_WITH_XPU @@ -379,11 +379,11 @@ int HeterXpuTrainer::EndPass(const HeterRequest* request, platform::DeviceContext* dev_ctx = pool.Get(place); const platform::XPUDeviceContext* xpu_ctx = reinterpret_cast(dev_ctx); - xpu::memset(xpu_ctx->x_context(), root_tensor->data(), 0, + xpu::memset(xpu_ctx->x_context(), root_tensor->data(), 0, root_tensor->numel() * SizeOfType(root_tensor->type())); #endif } else { - memset(root_tensor->data(), 0, + memset(root_tensor->data(), 0, root_tensor->numel() * SizeOfType(root_tensor->type())); } } diff --git a/paddle/fluid/framework/lod_tensor.h b/paddle/fluid/framework/lod_tensor.h index 7dee0f44e384d..dff6d0e01839a 100644 --- a/paddle/fluid/framework/lod_tensor.h +++ b/paddle/fluid/framework/lod_tensor.h @@ -144,8 +144,8 @@ class LoDTensor : public Tensor { */ size_t NumLevels() const { return lod_.size(); } /* - * Number of elements in a level. - */ + * Number of elements in a level. + */ size_t NumElements(size_t level = 0) const { PADDLE_ENFORCE_LT( level, NumLevels(), diff --git a/paddle/fluid/framework/new_executor/standalone_executor_test.cc b/paddle/fluid/framework/new_executor/standalone_executor_test.cc index 6876f219c92b9..b42f2da2a4d78 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor_test.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor_test.cc @@ -71,7 +71,6 @@ ProgramDesc load_from_file(const std::string& file_name) { fin.seekg(0, std::ios::beg); fin.read(&buffer[0], buffer.size()); fin.close(); - ProgramDesc program_desc(buffer); return program_desc; } diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index ca86cda52fb61..9a38a2d5d6fe8 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -788,7 +788,7 @@ void ParallelExecutor::BCastParamsToDevices( void *buffer; if (i == 0 && trainer_id == 0) { - buffer = const_cast(main_tensor.data()); + buffer = const_cast(main_tensor.data()); } else { auto local_scope = member_->local_scopes_[i]; auto *t = local_scope->Var(var)->GetMutable(); @@ -831,7 +831,7 @@ void ParallelExecutor::BCastParamsToDevices( void *buffer; if (i == 0 && trainer_id == 0) { - buffer = const_cast(main_tensor.data()); + buffer = const_cast(main_tensor.data()); } else { auto local_scope = member_->local_scopes_[i]; auto *t = local_scope->Var(var)->GetMutable(); diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index 4a31adcca65ec..60b93f4a71664 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -101,20 +101,25 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) { PADDLE_ENFORCE_EQ(desc_.ParseFromString(binary_str), true, platform::errors::InvalidArgument( "Failed to parse program_desc from binary string.")); + VLOG(1) << 3333; InitFromProto(); } void ProgramDesc::InitFromProto() { + VLOG(1) << 4444; for (auto 
&block_desc : *desc_.mutable_blocks()) { blocks_.emplace_back(new BlockDesc(this, &block_desc)); } + VLOG(1) << 5555; for (auto &block : blocks_) { for (auto *op : block->AllOps()) { for (const auto &attr : op->Proto()->attrs()) { if (attr.type() == proto::AttrType::BLOCK) { + VLOG(1) << 6666; size_t blk_idx = attr.block_idx(); op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); } else if (attr.type() == proto::AttrType::BLOCKS) { + VLOG(1) << 7777; auto blks_idx = attr.blocks_idx(); std::vector block_descs; for (int blk_idx : blks_idx) { diff --git a/paddle/fluid/framework/save_load_util.cc b/paddle/fluid/framework/save_load_util.cc index 1731a974b71d8..0f1a8e2a9ed5f 100644 --- a/paddle/fluid/framework/save_load_util.cc +++ b/paddle/fluid/framework/save_load_util.cc @@ -295,12 +295,12 @@ bool SaveTensorToDisk(const std::string& file_name, // save tensor uint64_t data_size = tensor->numel() * framework::SizeOfType(tensor->type()); - auto* data_ptr = tensor->data(); + auto* data_ptr = tensor->data(); if (platform::is_gpu_place(tensor->place())) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) framework::Tensor temp; TensorCopySync(*tensor, platform::CPUPlace(), &temp); - data_ptr = temp.data(); + data_ptr = temp.data(); #else PADDLE_THROW(platform::errors::Unavailable( "Tensor is in CUDA device, but paddle not compiled with CUDA.")); diff --git a/paddle/fluid/framework/tensor.cc b/paddle/fluid/framework/tensor.cc index 2ccd295577039..e5dfe28be7a3c 100644 --- a/paddle/fluid/framework/tensor.cc +++ b/paddle/fluid/framework/tensor.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/tensor.h" +#include "paddle/pten/api/lib/utils/storage.h" DECLARE_bool(use_stream_safe_cuda_allocator); @@ -26,148 +27,55 @@ class Allocation; namespace paddle { namespace framework { -extern size_t SizeOfType(proto::VarType::Type type); -void Tensor::check_memory_size() const { - PADDLE_ENFORCE_NOT_NULL(holder_, platform::errors::PreconditionNotMet( - "Tensor holds no memory. " - "Call Tensor::mutable_data firstly.")); - size_t size = numel() * SizeOfType(type()); - - PADDLE_ENFORCE_LE( - size, memory_size(), - platform::errors::PreconditionNotMet( - "Tensor's dimension is out of bound." - "Tensor's dimension must be equal or less than the size of its " - "memory." - "But received Tensor's dimension is d%, memory's size is %d.", - size, memory_size())); -} - -Tensor::Tensor(const proto::VarType::Type& dtype) - : type_(dtype), - offset_(0), - inplace_version_counter_(std::make_shared(0)) {} - -size_t Tensor::memory_size() const { - return holder_ == nullptr ? 0UL : holder_->size() - offset_; -} - -void* Tensor::mutable_data(const platform::Place& place, - proto::VarType::Type type, size_t requested_size) { - type_ = type; - PADDLE_ENFORCE_GE( - numel(), 0, - platform::errors::PreconditionNotMet( - "The Tensor's element number must be equal or greater than zero. 
" - "The Tensor's shape is [", - dims(), "] now")); - size_t size = numel() * SizeOfType(type); - if (requested_size && (requested_size > size)) { - size = requested_size; - } - /* some versions of boost::variant don't have operator!= */ - if (holder_ == nullptr || !(holder_->place() == place) || - holder_->size() < size + offset_) { - // Reset holder first before re-allocate to save memory - holder_.reset(); - holder_ = memory::AllocShared(place, size); - offset_ = 0; - } - return reinterpret_cast(reinterpret_cast(holder_->ptr()) + - offset_); -} - -void* Tensor::mutable_data(const platform::Place& place, - size_t requested_size) { - PADDLE_ENFORCE_NOT_NULL(this->holder_, platform::errors::PreconditionNotMet( - "The tensor is not initialized.")); - return mutable_data(place, type_, requested_size); -} - -void* Tensor::mutable_data(const platform::Place& place, - proto::VarType::Type type, - const platform::Stream& stream) { - type_ = type; - PADDLE_ENFORCE_GE( - numel(), 0, - platform::errors::PreconditionNotMet( - "The Tensor's element number must be equal or greater than zero. " - "The Tensor's shape is [", - dims(), "] now")); - size_t size = numel() * SizeOfType(type); - - /* some versions of boost::variant don't have operator!= */ - if (holder_ == nullptr || !(holder_->place() == place) || - holder_->size() < size + offset_ || - !(platform::is_gpu_place(place) && - memory::InSameStream(holder_, stream))) { - holder_.reset(); - holder_ = memory::AllocShared(place, size, stream); - offset_ = 0; - } - return reinterpret_cast(reinterpret_cast(holder_->ptr()) + - offset_); -} - -Tensor& Tensor::ShareDataWith(const Tensor& src) { - src.check_memory_size(); - *this = src; - return *this; -} -Tensor& Tensor::ShareInplaceVersionCounterWith(const Tensor& src) { - PADDLE_ENFORCE_NOT_NULL( - inplace_version_counter_, - platform::errors::PreconditionNotMet( - "Tensor does not hold inplace_version_counter_.")); - - inplace_version_counter_ = src.inplace_version_counter_; - return *this; -} Tensor Tensor::Slice(int64_t begin_idx, int64_t end_idx) const { check_memory_size(); - PADDLE_ENFORCE_GE( - begin_idx, 0, - platform::errors::OutOfRange("The start row index must be greater than 0." - "But received the start index is d%.", - begin_idx)); - PADDLE_ENFORCE_LE( - end_idx, dims_[0], - platform::errors::OutOfRange("The end row index is out of bound.")); + PADDLE_ENFORCE_GE(begin_idx, 0, + paddle::platform::errors::OutOfRange( + "The start row index must be greater than 0." + "But received the start index is d%.", + begin_idx)); + PADDLE_ENFORCE_LE(end_idx, meta_.dims[0], + paddle::platform::errors::OutOfRange( + "The end row index is out of bound.")); PADDLE_ENFORCE_LT( begin_idx, end_idx, - platform::errors::InvalidArgument( + paddle::platform::errors::InvalidArgument( "The start row index must be less than the end row index." 
"But received the start index = %d, the end index = %d.", begin_idx, end_idx)); - if (dims_[0] == 1) { + if (meta_.dims[0] == 1) { return *this; } else { - size_t base = numel() / dims_[0]; + size_t base = numel() / meta_.dims[0]; Tensor dst; - dst.holder_ = holder_; - dst.set_layout(layout_); - dst.type_ = type_; - DDim dst_dims = dims_; + dst.storage_ = pten::make_intrusive( + storage_->data_shared()); + dst.meta_.layout = meta_.layout; + dst.meta_.dtype = meta_.dtype; + DDim dst_dims = meta_.dims; dst_dims[0] = end_idx - begin_idx; dst.Resize(dst_dims); - dst.offset_ = offset_ + begin_idx * base * SizeOfType(type()); + dst.meta_.offset = meta_.offset + begin_idx * base * SizeOf(dtype()); return dst; } } std::vector Tensor::Split(int64_t split_size, int64_t axis) const { check_memory_size(); - PADDLE_ENFORCE_GE(dims_.size(), 0, - platform::errors::OutOfRange( + + PADDLE_ENFORCE_GE(meta_.dims.size(), 0, + paddle::platform::errors::OutOfRange( "split expects at least a 1-dimensional tensor")); + PADDLE_ENFORCE_GE( split_size, 0, - platform::errors::OutOfRange( + paddle::platform::errors::OutOfRange( "split expects split_size be non-negative, but got split_size is %d", split_size)); - int64_t numel_size = dims_[axis]; + + int64_t numel_size = meta_.dims[axis]; int64_t num_splits = 1; if (split_size != 0) { @@ -187,49 +95,33 @@ std::vector Tensor::Split(int64_t split_size, int64_t axis) const { std::vector Tensor::Chunk(int64_t chunks, int64_t axis) const { check_memory_size(); - PADDLE_ENFORCE_GE(dims_.size(), 0, - platform::errors::OutOfRange( + PADDLE_ENFORCE_GE(meta_.dims.size(), 0, + paddle::platform::errors::OutOfRange( "split expects at least a 1-dimensional tensor")); PADDLE_ENFORCE_GE( chunks, 0, - platform::errors::OutOfRange( + paddle::platform::errors::OutOfRange( "chunks expects to be greater than 0, but got chunks is %d", chunks)); - int64_t numel_size = dims_[axis]; + int64_t numel_size = meta_.dims[axis]; int64_t split_size = (numel_size + chunks - 1) / chunks; return Split(split_size, axis); } -Tensor& Tensor::Resize(const DDim& dims) { - dims_ = dims; +Tensor& Tensor::ShareDataWith(const Tensor& src) { + src.check_memory_size(); + *this = src; return *this; } +Tensor& Tensor::ShareInplaceVersionCounterWith(const Tensor& src) { + PADDLE_ENFORCE_NOT_NULL( + inplace_version_counter_, + platform::errors::PreconditionNotMet( + "Tensor does not hold inplace_version_counter_.")); -const DDim& Tensor::dims() const { return dims_; } - -int64_t Tensor::numel() const { return product(dims_); } - -void Tensor::ResetHolder(std::shared_ptr holder) { - PADDLE_ENFORCE_EQ( - offset_, 0, - platform::errors::Fatal( - "Only the offset is supported to zero when the holder is reset.")); - if (holder_) { - PADDLE_ENFORCE_LE( - numel() * SizeOfType(type()) + offset_, holder->size(), - paddle::platform::errors::InvalidArgument( - "The size of Holder is not enough to store the Tensor.")); - } - holder_ = holder; -} - -void Tensor::ResetHolderWithType(std::shared_ptr holder, - const proto::VarType::Type& type) { - type_ = type; - ResetHolder(holder); + inplace_version_counter_ = src.inplace_version_counter_; + return *this; } -void Tensor::set_type(const proto::VarType::Type& type) { type_ = type; } - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 7eebd97d06523..e86009e9aafea 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -30,6 +30,8 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/stream/stream.h" +#include "paddle/pten/core/dense_tensor.h" + namespace paddle { namespace memory { namespace allocation { @@ -75,98 +77,10 @@ class LoDTensor; Variable object but not a pointer. */ -class TensorInplaceVersion { +class Tensor : public pten::DenseTensor { public: - explicit TensorInplaceVersion(uint32_t inplace_version = 0) - : inplace_version_(inplace_version) {} - bool IsUnique() const { return inplace_version_ == 0; } - void Bump() { ++inplace_version_; } - uint32_t CurrentVersion() const { return inplace_version_; } - void SetInplaceVersionToZero() { inplace_version_ = 0; } - - private: - uint32_t inplace_version_; -}; - -class Tensor { -#ifdef PADDLE_WITH_MKLDNN - - public: - inline dnnl::memory::format_tag format() const { return format_; } - - inline void set_format(const dnnl::memory::format_tag format) { - format_ = format; - } - - protected: - /** - * @brief the detail format of memory block which have layout as kMKLDNN - * - * @note MKLDNN lib support various memory format like nchw, nhwc, nChw8C, - * nChw16c, etc. For a MKLDNN memory block, layout will be set as - * DataLayout::kMKLDNN meanwhile detail memory format will be kept in - * this field. - */ - - dnnl::memory::format_tag format_ = dnnl::memory::format_tag::undef; -#endif - - public: - Tensor() - : type_(proto::VarType::FP32), - offset_(0), - inplace_version_counter_(std::make_shared(0)) {} - - explicit Tensor(const proto::VarType::Type&); - - /*! Return a pointer to mutable memory block. */ - const void* data() const; - - template - T* data(); - - /*! Return a pointer to constant memory block. */ - template - const T* data() const; - - inline bool IsInitialized() const; - - /** - * @brief Return a pointer to mutable memory block. - * @note If not exist, then allocation. - */ - template - T* mutable_data(const platform::Place& place, size_t requested_size = 0); - - void* mutable_data(const platform::Place& place, proto::VarType::Type type, - size_t requested_size = 0); - - void* mutable_data(const platform::Place& place, size_t requested_size = 0); - - void* mutable_data(const platform::Place& place, proto::VarType::Type type, - const platform::Stream& stream); - - /** - * @brief Return a pointer to mutable memory block. - * - * @param[in] dims The dimensions of the memory block. - * @param[in] place The place of the memory block. - * @param[in] requested_size The size of the block in bytes. - * - * @note If not exist, then allocation. - */ - template - T* mutable_data(const DDim& dims, const platform::Place& place, - size_t requested_size = 0); - - /*! Return the dimensions of the memory block. */ - const DDim& dims() const; - - /*! Return the numel of the memory block. */ - int64_t numel() const; - - /*! Resize the dimensions of the memory block. */ - Tensor& Resize(const DDim& dims); + using DenseTensor = pten::DenseTensor; + using DenseTensor::DenseTensor; /*! The internal of two tensors share the same memory block. */ Tensor& ShareDataWith(const Tensor& src); @@ -174,150 +88,16 @@ class Tensor { /*! The internal of two tensors share the same inplace version counter. */ Tensor& ShareInplaceVersionCounterWith(const Tensor& src); - /** - * @brief Return a sub-tensor of the given tensor. - * - * @param[in] begin_idx The index of the start row(inclusive) to slice. - * The index number begins from 0. - * @param[in] end_idx The index of the end row(exclusive) to slice. - * The index number begins from 0. 
- */ Tensor Slice(int64_t begin_idx, int64_t end_idx) const; - /** - * @brief Return a tensor list of the given tensor. - * - * @param[in] split_size The size of tensor to be split along axis. - * @param[in] axis The axis along which to split. - */ std::vector Split(int64_t split_size, int64_t axis) const; - /** - * @brief Return a tensor list of the given tensor. - * - * @param[in] chunks The number of tensor to be split along axis. - * @param[in] axis The axis along which to split. - */ std::vector Chunk(int64_t chunks, int64_t axis) const; - const platform::Place& place() const { - PADDLE_ENFORCE_NOT_NULL( - holder_, - platform::errors::PreconditionNotMet( - "Tensor not initialized yet when Tensor::place() is called.")); - return holder_->place(); + Tensor& Resize(const DDim& dims) { + meta_.dims = dims; + return *this; } - - proto::VarType::Type type() const { - PADDLE_ENFORCE_NOT_NULL( - holder_, - platform::errors::PreconditionNotMet( - "Tensor not initialized yet when Tensor::type() is called.")); - return type_; - } - - /** - * [Add method get the saved type of tensor] - * - * After the introduction of complex number calculations, Ops that support - * complex number calculations generally support type promotion, such as - * x(float32) + y(complex64) = out(complex64), then the type of the grad - * tensor should be dout(complex64), dx(float32), dy (complex64), but the - * type of dx to be recognized to be float32 by the grad Op relay on the type - * of forward tensor x. But many of our ops have registered InplaceInferer, - * covering the tensor memory of x with out, so as to save storage. - * - * In this case, the dim and type information recorded by x still exist, - * but because x becomes an uninitialized tensor, The type of x record cannot - * be obtained with x.type(), but the type is still valid here, so we - * add saved_type(), This method SHOULD NOT be called by general scenarios. - */ - proto::VarType::Type saved_type() const { return type_; } - - // memory size returns the holding memory size in byte. - size_t memory_size() const; - - void check_memory_size() const; - - DataLayout layout() const { return layout_; } - - void set_layout(const DataLayout layout) { layout_ = layout; } - - void clear() { - holder_ = nullptr; - offset_ = 0; - } - - void ShareBufferWith(const Tensor& tensor) { - holder_ = tensor.holder_; - offset_ = tensor.offset_; - // NOTE(chenfeiyu): when sharing buffer, by definition only holder - // to the memory allocation and offset should be shared. Shape, - // data type, layout, and other metadata associated with a Tensor - // should not be copied. - } - - void ShareDataTypeWith(const Tensor& tensor) { type_ = tensor.type_; } - - bool IsSharedBufferWith(const Tensor& src) const { - return holder_ && holder_ == src.Holder(); - } - - const std::shared_ptr& Holder() const { return holder_; } - size_t offset() const { return offset_; } - void set_offset(size_t offset) { offset_ = offset; } - - std::shared_ptr MoveMemoryHolder() { - return std::move(holder_); - } - - void ResetHolder(std::shared_ptr holder); - - void ResetHolderWithType(std::shared_ptr holder, - const proto::VarType::Type& type); - - void set_type(const proto::VarType::Type& type); - - TensorInplaceVersion& InplaceVersionCounter() { - return *inplace_version_counter_; - } - - private: - /*! holds the memory block if allocated. */ - std::shared_ptr holder_; - proto::VarType::Type type_; - /** - * @brief points to elements dimensions. - * - * @note dims_ do not indicate the memory block size. 
- */ - - DDim dims_; - - /** - * @brief the layout of memory block, default is NHWC. - * - * @note the memory allocation order, describe how weight/data is stored - * For example, in 4-D Tensor(rank=4), there are three commonly - * used layout. They are - * NCHW, NHWC, CHWN. - * N,C,H,W for respectively the batch size, the number of - * feature maps, the height. - */ - // Fix me: here just change the default layout to kNCHW - // it doesn't fix the real issue, i.e. feeder should set up tensor layout - // according to actual input data - DataLayout layout_ = DataLayout::kNCHW; - - /** - * @brief A PlaceHolder may be shared by more than one tensor. - * - * @note Some of them may be slices of the others. So the offset_ - * is introduced here to indicate the byte offset between - * PlaceHolder::ptr_ and where the tensor data really begins. - */ - size_t offset_; - std::shared_ptr inplace_version_counter_; }; } // namespace framework diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index a83b5d0662bb9..98ad9a629b5ab 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -20,61 +20,6 @@ limitations under the License. */ namespace paddle { namespace framework { -template -inline const T* Tensor::data() const { - check_memory_size(); - bool valid = - std::is_same::value || type_ == DataTypeTrait::DataType(); - PADDLE_ENFORCE_EQ( - valid, true, - platform::errors::InvalidArgument( - "Tensor holds the wrong type, it holds %s, but desires to be %s.", - DataTypeToString(type_), - DataTypeToString(DataTypeTrait::DataType()))); - - return reinterpret_cast( - reinterpret_cast(holder_->ptr()) + offset_); -} - -inline bool Tensor::IsInitialized() const { return holder_ != nullptr; } - -template -inline T* Tensor::data() { - check_memory_size(); - bool valid = - std::is_same::value || type_ == DataTypeTrait::DataType(); - PADDLE_ENFORCE_EQ( - valid, true, - platform::errors::InvalidArgument( - "Tensor holds the wrong type, it holds %s, but desires to be %s", - DataTypeToString(type_), - DataTypeToString(DataTypeTrait::DataType()))); - - return reinterpret_cast(reinterpret_cast(holder_->ptr()) + - offset_); -} - -inline const void* Tensor::data() const { - check_memory_size(); - return reinterpret_cast(reinterpret_cast(holder_->ptr()) + - offset_); -} - -template -inline T* Tensor::mutable_data(const DDim& dims, const platform::Place& place, - size_t requested_size) { - static_assert(std::is_pod::value, "T must be POD"); - Resize(dims); - return mutable_data(place, requested_size); -} - -template -inline T* Tensor::mutable_data(const platform::Place& place, - size_t requested_size) { - static_assert(std::is_pod::value, "T must be POD"); - return reinterpret_cast( - mutable_data(place, DataTypeTrait::DataType(), requested_size)); -} inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { int rank = src.dims().size(); diff --git a/paddle/fluid/framework/tensor_test.cc b/paddle/fluid/framework/tensor_test.cc index 71ff50c92ca59..a58f4a6b5f4c1 100644 --- a/paddle/fluid/framework/tensor_test.cc +++ b/paddle/fluid/framework/tensor_test.cc @@ -45,7 +45,6 @@ TEST(Tensor, DataAssert) { } catch (platform::EnforceNotMet& err) { caught = true; std::string ex_msg = err.what(); - EXPECT_TRUE(ex_msg.find("holder_ should not be null") != std::string::npos); EXPECT_TRUE(ex_msg.find("Tensor holds no memory. 
Call " "Tensor::mutable_data firstly.") != std::string::npos); @@ -189,8 +188,6 @@ TEST(Tensor, ShareDataWith) { } catch (paddle::platform::EnforceNotMet& err) { caught = true; std::string ex_msg = err.what(); - EXPECT_TRUE(ex_msg.find("holder_ should not be null") != - std::string::npos); EXPECT_TRUE(ex_msg.find("Tensor holds no memory. Call " "Tensor::mutable_data firstly.") != std::string::npos); diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 6394e84c81a2b..f2323f6e2c6ee 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -45,7 +45,6 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place, VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to " << dst_place; src.check_memory_size(); - dst->Resize(src.dims()); dst->set_layout(src.layout()); auto src_place = src.place(); @@ -442,6 +441,7 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, auto src_place = src.place(); auto src_ptr = src.data(); auto dst_ptr = dst->mutable_data(dst_place, src.type()); + VLOG(4) << "src:" << src_ptr << ", dst:" << dst_ptr; if (src_ptr == dst_ptr && src_place == dst_place) { VLOG(3) << "Skip copy the same data from " << src_place << " to " diff --git a/paddle/fluid/framework/variable.h b/paddle/fluid/framework/variable.h index f8ad990a668ce..2fa48150903ad 100644 --- a/paddle/fluid/framework/variable.h +++ b/paddle/fluid/framework/variable.h @@ -72,7 +72,7 @@ class Variable { private: // This method hides type T, so it doesn't appear as a template parameter of // Variable. - framework::TensorInplaceVersion* InplaceVersionCounter(); + pten::TensorInplaceVersion* InplaceVersionCounter(); public: void SetInplaceVersionToZero(); @@ -114,8 +114,8 @@ class Variable { std::shared_ptr holder_; }; -inline framework::TensorInplaceVersion* Variable::InplaceVersionCounter() { - framework::TensorInplaceVersion* version_counter_ptr(nullptr); +inline pten::TensorInplaceVersion* Variable::InplaceVersionCounter() { + pten::TensorInplaceVersion* version_counter_ptr(nullptr); if (IsType()) { version_counter_ptr = &GetMutable()->InplaceVersionCounter(); diff --git a/paddle/fluid/imperative/all_reduce.cc b/paddle/fluid/imperative/all_reduce.cc index 31da214fbc39a..78855cc5c9e2e 100644 --- a/paddle/fluid/imperative/all_reduce.cc +++ b/paddle/fluid/imperative/all_reduce.cc @@ -60,7 +60,7 @@ static void AllReduce(const framework::Tensor &src, framework::Tensor *dst, platform::errors::Unimplemented( "Imperative mode does not support multi-CPU training yet.")); - const void *src_ptr = src.data(); + const void *src_ptr = src.data(); dst->Resize(src.dims()); auto *dst_ptr = dst->mutable_data(src.place(), src.type()); auto nccl_dtype = platform::ToNCCLDataType(src.type()); @@ -129,7 +129,7 @@ static void AllReduce(const framework::SelectedRows &src, auto feature_size = framework::product(dims) / dims[0]; dst_tensor->Resize(dims); auto *dst_tensor_ptr = dst_tensor->mutable_data(place, dtype); - const auto *src_tensor_ptr = src_tensor.data(); + const auto *src_tensor_ptr = src_tensor.data(); auto sizeof_dtype = framework::SizeOfType(dtype); int64_t row_offset = 0; diff --git a/paddle/fluid/imperative/bkcl_context.cc b/paddle/fluid/imperative/bkcl_context.cc index 6569929d6f5d7..2072c41673aaf 100644 --- a/paddle/fluid/imperative/bkcl_context.cc +++ b/paddle/fluid/imperative/bkcl_context.cc @@ -39,7 +39,7 @@ static void AllReduce(const framework::Tensor &src, framework::Tensor *dst, 
platform::errors::Unimplemented( "Dynamic graph mode does not support multi-CPU training yet.")); - const void *src_ptr = src.data(); + const void *src_ptr = src.data(); dst->Resize(src.dims()); auto *dst_ptr = dst->mutable_data(src.place(), src.type()); auto bkcl_dtype = platform::ToBKCLDataType(src.type()); @@ -158,7 +158,7 @@ void BKCLParallelContext::Broadcast(framework::Variable *src, int ring_id) { platform::BKCLCommContext::Instance().Get(ring_id, place); XPUStream stream = comm->stream(); - void *src_ptr = src_tensor->data(); + void *src_ptr = src_tensor->data(); auto data_type = platform::ToBKCLDataType(src_tensor->type()); PADDLE_ENFORCE_EQ(bkcl_broadcast(comm->comm(), src_ptr, src_ptr, diff --git a/paddle/fluid/imperative/hccl_context.cc b/paddle/fluid/imperative/hccl_context.cc index 55c52ae6c11de..818b2f424b6af 100644 --- a/paddle/fluid/imperative/hccl_context.cc +++ b/paddle/fluid/imperative/hccl_context.cc @@ -42,7 +42,7 @@ static void AllReduce(const framework::Tensor &src, framework::Tensor *dst, platform::errors::Unimplemented( "Imperative mode does not support multi-CPU training yet.")); - void *src_ptr = const_cast(src.data()); + void *src_ptr = const_cast(src.data()); dst->Resize(src.dims()); void *dst_ptr = dst->mutable_data(src.place(), src.type()); HcclDataType hccl_dtype = platform::ToHCCLDataType(src.type()); @@ -168,7 +168,7 @@ void HCCLParallelContext::Broadcast(framework::Variable *src, int ring_id) { aclrtStream stream = comm->stream(); void *src_ptr = - reinterpret_cast(const_cast(src_tensor->data())); + reinterpret_cast(const_cast(src_tensor->data())); auto hccl_dtype = platform::ToHCCLDataType(src_tensor->type()); PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclBroadcast( src_ptr, src_tensor->numel(), hccl_dtype, 0, comm->comm(), diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc index 15146f6c1204e..1b50c515635d2 100644 --- a/paddle/fluid/imperative/nccl_context.cc +++ b/paddle/fluid/imperative/nccl_context.cc @@ -143,7 +143,7 @@ void NCCLParallelContext::Broadcast(framework::Variable *src, int ring_id) { platform::NCCLCommContext::Instance().Get(ring_id, place); gpuStream_t stream = comm->stream(); - void *src_ptr = src_tensor->data(); + void *src_ptr = src_tensor->data(); auto nccl_dtype = platform::ToNCCLDataType(src_tensor->type()); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast( src_ptr, src_tensor->numel(), nccl_dtype, 0, comm->comm(), stream)); diff --git a/paddle/fluid/imperative/tests/test_gradient_accmulator.cc b/paddle/fluid/imperative/tests/test_gradient_accmulator.cc index cb4ab2e79cb99..0a7df9953ad45 100644 --- a/paddle/fluid/imperative/tests/test_gradient_accmulator.cc +++ b/paddle/fluid/imperative/tests/test_gradient_accmulator.cc @@ -176,8 +176,8 @@ static bool IsEqualVar(const framework::Variable& var1, return false; } - auto* t1_p = t1.data(); - auto* t2_p = t2.data(); + auto* t1_p = t1.data(); + auto* t2_p = t2.data(); return std::memcmp(t1_p, t2_p, t1.numel() * framework::SizeOfType(t1.type())) == 0; } diff --git a/paddle/fluid/inference/api/api_impl_tester.cc b/paddle/fluid/inference/api/api_impl_tester.cc index 89aec34110b85..124279d246093 100644 --- a/paddle/fluid/inference/api/api_impl_tester.cc +++ b/paddle/fluid/inference/api/api_impl_tester.cc @@ -37,13 +37,13 @@ PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) { PaddleTensor pt; if (t->type() == framework::proto::VarType::INT64) { - pt.data.Reset(t->data(), t->numel() * sizeof(int64_t)); + 
pt.data.Reset(t->data(), t->numel() * sizeof(int64_t)); pt.dtype = PaddleDType::INT64; } else if (t->type() == framework::proto::VarType::FP32) { - pt.data.Reset(t->data(), t->numel() * sizeof(float)); + pt.data.Reset(t->data(), t->numel() * sizeof(float)); pt.dtype = PaddleDType::FLOAT32; } else if (t->type() == framework::proto::VarType::INT32) { - pt.data.Reset(t->data(), t->numel() * sizeof(int32_t)); + pt.data.Reset(t->data(), t->numel() * sizeof(int32_t)); pt.dtype = PaddleDType::INT32; } else { PADDLE_THROW(platform::errors::Unimplemented( diff --git a/paddle/fluid/inference/lite/tensor_utils.cc b/paddle/fluid/inference/lite/tensor_utils.cc index cbc947ea6436a..b1e0eb5ef16ab 100644 --- a/paddle/fluid/inference/lite/tensor_utils.cc +++ b/paddle/fluid/inference/lite/tensor_utils.cc @@ -210,7 +210,7 @@ void TensorCopyAsync(paddle::lite_api::Tensor* dst, const size_t bytes = static_cast(src.numel()) * framework::SizeOfType(src.type()); dst->Resize(framework::vectorize(src.dims())); - const void* src_data = src.data(); + const void* src_data = src.data(); void* dst_data{nullptr}; dst_data = GetLiteTensorDataPtr(dst, GetLitePrecisionType(src.type()), GetLiteTargetType(src.place())); @@ -242,7 +242,7 @@ void TensorCopyAsync(framework::LoDTensor* dst, template <> void TensorDataShare(paddle::lite_api::Tensor* dst, framework::LoDTensor* src) { dst->Resize(framework::vectorize(src->dims())); - dst->ShareExternalMemory(src->data(), src->memory_size(), + dst->ShareExternalMemory(src->data(), src->memory_size(), GetLiteTargetType(src->place())); dst->SetPrecision(GetLitePrecisionType(src->type())); paddle::lite::LoD lite_lod; diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc b/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc index 021f3a13ce7cf..8160368d72ad1 100644 --- a/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc +++ b/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc @@ -176,7 +176,7 @@ class LazyZerosNPU { NpuOpRunner("ZerosLike", {*zero_tensor}, {*zero_tensor}); runner_zeros.Run(stream); zero_tensor->check_memory_size(); - zero_ptr = zero_tensor->data(); + zero_ptr = zero_tensor->data(); } for (size_t i = 0; i < xs.size(); ++i) { diff --git a/paddle/fluid/operators/coalesce_tensor_op.cc b/paddle/fluid/operators/coalesce_tensor_op.cc index 752e5dc4a8772..5655fd25ec24b 100644 --- a/paddle/fluid/operators/coalesce_tensor_op.cc +++ b/paddle/fluid/operators/coalesce_tensor_op.cc @@ -260,8 +260,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel { size_of_dtype : len; ss << "output(" << out_var_names[i] << ") dim:(" << dim << ")" - << " address: " << out_tensors[i]->data() << " len: " << len - << ", "; + << " address: " << out_tensors[i]->data() << " len: " << len << ", "; offset += len; } PADDLE_ENFORCE_EQ( @@ -300,9 +299,8 @@ class CoalesceTensorOpKernel : public framework::OpKernel { place, align_size) / size_of_dtype : static_cast(size); - const void *ptr = lod_tensors[i]->IsInitialized() - ? lod_tensors[i]->data() - : nullptr; + const void *ptr = + lod_tensors[i]->IsInitialized() ? 
lod_tensors[i]->data() : nullptr; VLOG(4) << size << " " << len; ss << "input(" << var_names[i] << ") dim:(" << lod_tensors[i]->dims() << ") " diff --git a/paddle/fluid/operators/collective/allreduce_op.h b/paddle/fluid/operators/collective/allreduce_op.h index 4e6d86d49e863..226b2c5132318 100644 --- a/paddle/fluid/operators/collective/allreduce_op.h +++ b/paddle/fluid/operators/collective/allreduce_op.h @@ -43,7 +43,7 @@ class AllReduceOpKernel : public framework::OpKernel { int dtype = platform::ToNCCLDataType(in->type()); int64_t numel = in->numel(); - auto* sendbuff = in->data(); + auto* sendbuff = in->data(); out->Resize(in->dims()); void* recvbuff = out->mutable_data(place); diff --git a/paddle/fluid/operators/collective/barrier_op.cu.cc b/paddle/fluid/operators/collective/barrier_op.cu.cc index c9aef237699f3..a98a0bf6ab4a9 100644 --- a/paddle/fluid/operators/collective/barrier_op.cu.cc +++ b/paddle/fluid/operators/collective/barrier_op.cu.cc @@ -33,7 +33,7 @@ class BarrierOpCUDAKernel : public framework::OpKernel { auto place = ctx.GetPlace(); ncclDataType_t dtype = platform::ToNCCLDataType(in->type()); int64_t numel = in->numel(); - const void* sendbuff = in->data(); + const void* sendbuff = in->data(); void* recvbuff = out->mutable_data(place); int rid = ctx.Attr("ring_id"); diff --git a/paddle/fluid/operators/collective/broadcast_op.cu.cc b/paddle/fluid/operators/collective/broadcast_op.cu.cc index daaaf8b7a2e41..229d42e64e4e5 100644 --- a/paddle/fluid/operators/collective/broadcast_op.cu.cc +++ b/paddle/fluid/operators/collective/broadcast_op.cu.cc @@ -46,7 +46,7 @@ class NCCLBroadcastOpKernel : public framework::OpKernel { "because this op can only be an In-Place operation.")); void* send_recv_buffer = out->mutable_data(ctx.GetPlace()); PADDLE_ENFORCE_EQ( - send_recv_buffer, in->data(), + send_recv_buffer, in->data(), platform::errors::PreconditionNotMet("Currently, the broadcast op can " "only be an In-Place operation.")); diff --git a/paddle/fluid/operators/collective/broadcast_op_xpu.cc b/paddle/fluid/operators/collective/broadcast_op_xpu.cc index 9cd5c5fd22cc3..e8566803aecfa 100644 --- a/paddle/fluid/operators/collective/broadcast_op_xpu.cc +++ b/paddle/fluid/operators/collective/broadcast_op_xpu.cc @@ -52,7 +52,7 @@ class BKCLBroadcastOpKernel : public framework::OpKernel { "because this op can only be an In-Place operation.")); void* send_recv_buffer = out->mutable_data(ctx.GetPlace()); PADDLE_ENFORCE_EQ( - send_recv_buffer, in->data(), + send_recv_buffer, in->data(), platform::errors::PreconditionNotMet("Currently, the broadcast op can " "only be an In-Place operation.")); diff --git a/paddle/fluid/operators/collective/c_reduce_op.h b/paddle/fluid/operators/collective/c_reduce_op.h index b950339bd22be..c06b2683a6bbe 100644 --- a/paddle/fluid/operators/collective/c_reduce_op.h +++ b/paddle/fluid/operators/collective/c_reduce_op.h @@ -213,7 +213,7 @@ class CReduceOpXPUKernel : public framework::OpKernel { auto place = ctx.GetPlace(); BKCLDataType dtype = platform::ToBKCLDataType(in->type()); int64_t numel = in->numel(); - const void* sendbuff = in->data(); + const void* sendbuff = in->data(); out->Resize(in->dims()); void* recvbuff = out->mutable_data(place); @@ -276,7 +276,7 @@ class CReduceOpCUDAKernel : public framework::OpKernel { auto place = ctx.GetPlace(); ncclDataType_t dtype = platform::ToNCCLDataType(in->type()); int64_t numel = in->numel(); - const void* sendbuff = in->data(); + const void* sendbuff = in->data(); out->Resize(in->dims()); void* recvbuff = 
out->mutable_data(place); diff --git a/paddle/fluid/operators/detection/bbox_util.h b/paddle/fluid/operators/detection/bbox_util.h index b262f05d6b187..18c45a1a4c6c1 100644 --- a/paddle/fluid/operators/detection/bbox_util.h +++ b/paddle/fluid/operators/detection/bbox_util.h @@ -144,8 +144,8 @@ void MaxIoU(const framework::Tensor& iou, framework::Tensor* max_iou) { static void AppendProposals(framework::Tensor* dst, int64_t offset, const framework::Tensor& src) { - auto* out_data = dst->data(); - auto* to_add_data = src.data(); + auto* out_data = dst->data(); + auto* to_add_data = src.data(); size_t size_of_t = framework::SizeOfType(src.type()); offset *= size_of_t; std::memcpy( diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu index 3fe453bda2d9e..7725f336416db 100644 --- a/paddle/fluid/operators/layer_norm_op.cu +++ b/paddle/fluid/operators/layer_norm_op.cu @@ -64,8 +64,8 @@ class LayerNormKernel auto *mean_data = mean->mutable_data(ctx.GetPlace()); auto *var_data = var->mutable_data(ctx.GetPlace()); - auto *void_scale_data = (scale == nullptr ? nullptr : scale->data()); - auto *void_bias_data = (bias == nullptr ? nullptr : bias->data()); + auto *void_scale_data = (scale == nullptr ? nullptr : scale->data()); + auto *void_bias_data = (bias == nullptr ? nullptr : bias->data()); framework::proto::VarType::Type x_dtype = x->type(); framework::proto::VarType::Type scale_bias_dtype; diff --git a/paddle/fluid/operators/math/matrix_inverse.cu.cc b/paddle/fluid/operators/math/matrix_inverse.cu.cc index 5deedf084c697..7d03f9590357e 100644 --- a/paddle/fluid/operators/math/matrix_inverse.cu.cc +++ b/paddle/fluid/operators/math/matrix_inverse.cu.cc @@ -48,7 +48,7 @@ class MatrixInverseFunctor { memory::Copy(boost::get(context.GetPlace()), tmp_gpu_mat_data->ptr(), boost::get(context.GetPlace()), - a.data(), a.numel() * sizeof(T), context.stream()); + a.data(), a.numel() * sizeof(T), context.stream()); gpu_mat = reinterpret_cast(tmp_gpu_mat_data->ptr()); } diff --git a/paddle/fluid/operators/optimizers/lamb_op.h b/paddle/fluid/operators/optimizers/lamb_op.h index df17b5e5f40bc..e3798b49dcbb1 100644 --- a/paddle/fluid/operators/optimizers/lamb_op.h +++ b/paddle/fluid/operators/optimizers/lamb_op.h @@ -492,9 +492,9 @@ class LambOpKernel : public framework::OpKernel { auto trust_ratio_div = ctx.AllocateTmpTensor(param.dims(), dev_ctx); - const void* param_ptr = param.template data(); + const void* param_ptr = param.data(); const void* master_param_ptr = - master_param ? master_param->template data() : nullptr; + master_param ? 
master_param->data() : nullptr; void* param_out_ptr = param_out.template mutable_data(ctx.GetPlace()); void* master_param_out_ptr = master_param_out diff --git a/paddle/fluid/operators/reader/buffered_reader.cc b/paddle/fluid/operators/reader/buffered_reader.cc index 6c28daa7eac72..3c0c8ad1cafce 100644 --- a/paddle/fluid/operators/reader/buffered_reader.cc +++ b/paddle/fluid/operators/reader/buffered_reader.cc @@ -132,7 +132,7 @@ void BufferedReader::ReadAsync(size_t i) { memory::Copy(cuda_pinned_place, cuda_pinned_ptrs[i], BOOST_GET_CONST(platform::CPUPlace, cpu[i].place()), - cpu[i].data(), size); + cpu[i].data(), size); cuda[i].set_lod(cpu[i].lod()); } else { @@ -175,7 +175,7 @@ void BufferedReader::ReadAsync(size_t i) { platform::RecordEvent record_event("BufferedReader:MemoryCopy"); for (size_t i = 0; i < cpu.size(); ++i) { auto cpu_place = cpu[i].place(); - auto cpu_ptr = cpu[i].data(); + auto cpu_ptr = cpu[i].data(); auto gpu_ptr = gpu_ptrs[i]; auto size = cpu[i].numel() * paddle::framework::SizeOfType(cpu[i].type()); @@ -239,7 +239,7 @@ void BufferedReader::ReadAsync(size_t i) { platform::RecordEvent record_event("BufferedReader:MemoryCopy"); for (size_t i = 0; i < cpu.size(); ++i) { auto cpu_place = cpu[i].place(); - auto cpu_ptr = cpu[i].data(); + auto cpu_ptr = cpu[i].data(); auto npu_ptr = npu_ptrs[i]; auto size = cpu[i].numel() * paddle::framework::SizeOfType(cpu[i].type()); diff --git a/paddle/fluid/operators/spectral_op.cc b/paddle/fluid/operators/spectral_op.cc index b5edc1dda533b..64751a21c837d 100644 --- a/paddle/fluid/operators/spectral_op.cc +++ b/paddle/fluid/operators/spectral_op.cc @@ -587,15 +587,13 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out, collapsed_input_conj.data()); for_range(functor); MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward( - desc.get(), collapsed_input_conj.data(), - collapsed_output.data())); + desc.get(), collapsed_input_conj.data(), collapsed_output.data())); } else if (fft_type == FFTTransformType::R2C && !forward) { framework::Tensor collapsed_output_conj(collapsed_output.type()); collapsed_output_conj.mutable_data(collapsed_output.dims(), ctx.GetPlace()); MKL_DFTI_CHECK(platform::dynload::DftiComputeForward( - desc.get(), collapsed_input.data(), - collapsed_output_conj.data())); + desc.get(), collapsed_input.data(), collapsed_output_conj.data())); // conjugate the output platform::ForRange for_range(ctx, collapsed_output.numel()); math::ConjFunctor functor(collapsed_output_conj.data(), @@ -605,12 +603,10 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out, } else { if (forward) { MKL_DFTI_CHECK(platform::dynload::DftiComputeForward( - desc.get(), collapsed_input.data(), - collapsed_output.data())); + desc.get(), collapsed_input.data(), collapsed_output.data())); } else { MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward( - desc.get(), collapsed_input.data(), - collapsed_output.data())); + desc.get(), collapsed_input.data(), collapsed_output.data())); } } diff --git a/paddle/fluid/operators/spectral_op.cu b/paddle/fluid/operators/spectral_op.cu index 2066ce955cafe..d6a775dd55de8 100644 --- a/paddle/fluid/operators/spectral_op.cu +++ b/paddle/fluid/operators/spectral_op.cu @@ -115,22 +115,19 @@ void exec_cufft_plan(const DeviceContext& ctx, const FFTConfig& config, math::ConjFunctor functor(input->data(), input->numel(), input_conj.data()); for_range(functor); - exec_cufft_plan_raw(config, input_conj.data(), output->data(), - forward); + exec_cufft_plan_raw(config, input_conj.data(), 
output->data(), forward); } else if (fft_type == FFTTransformType::R2C && !forward) { forward = true; framework::Tensor out_conj(output->type()); out_conj.mutable_data(output->dims(), ctx.GetPlace()); - exec_cufft_plan_raw(config, input->data(), out_conj.data(), - forward); + exec_cufft_plan_raw(config, input->data(), out_conj.data(), forward); platform::ForRange for_range(ctx, output->numel()); math::ConjFunctor functor(out_conj.data(), output->numel(), output->data()); for_range(functor); } else { - exec_cufft_plan_raw(config, input->data(), output->data(), - forward); + exec_cufft_plan_raw(config, input->data(), output->data(), forward); } } @@ -227,22 +224,19 @@ void exec_hipfft_plan(const DeviceContext& ctx, const FFTConfig& config, math::ConjFunctor functor(input->data(), input->numel(), input_conj.data()); for_range(functor); - exec_hipfft_plan_raw(config, input_conj.data(), output->data(), - forward); + exec_hipfft_plan_raw(config, input_conj.data(), output->data(), forward); } else if (fft_type == FFTTransformType::R2C && !forward) { forward = true; framework::Tensor out_conj(output->type()); out_conj.mutable_data(output->dims(), ctx.GetPlace()); - exec_hipfft_plan_raw(config, input->data(), out_conj.data(), - forward); + exec_hipfft_plan_raw(config, input->data(), out_conj.data(), forward); platform::ForRange for_range(ctx, output->numel()); math::ConjFunctor functor(out_conj.data(), output->numel(), output->data()); for_range(functor); } else { - exec_hipfft_plan_raw(config, input->data(), output->data(), - forward); + exec_hipfft_plan_raw(config, input->data(), output->data(), forward); } } diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 1c1f63331d056..5ebf67587f3cb 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -405,7 +405,7 @@ class TensorRTEngineOp : public framework::OperatorBase { if (param_names_.count(x)) continue; auto &t = inference::analysis::GetFromScope(scope, x); - calib_data.emplace(x, t.data()); + calib_data.emplace(x, t.data()); } temp_calibrator->setBatch(calib_data); RunNativeImpl(scope, dev_place); diff --git a/paddle/fluid/platform/device/ipu/ipu_compiler.cc b/paddle/fluid/platform/device/ipu/ipu_compiler.cc index a1c5ed4fefbf3..58f784fdbc972 100644 --- a/paddle/fluid/platform/device/ipu/ipu_compiler.cc +++ b/paddle/fluid/platform/device/ipu/ipu_compiler.cc @@ -271,7 +271,7 @@ void Compiler::LowerWeights(const framework::ir::Graph* graph, shape.push_back(tensor.dims().at(i)); } popart::TensorInfo tensor_info(dtype, shape); - popart::ConstVoidData const_data{tensor.data(), tensor_info}; + popart::ConstVoidData const_data{tensor.data(), tensor_info}; popart::TensorId result = builder_->addInitializedInputTensor(const_data, var_name); tensors_.emplace(var_name, result); diff --git a/paddle/fluid/platform/device/ipu/ipu_utils.cc b/paddle/fluid/platform/device/ipu/ipu_utils.cc index 08ba50415dd5f..4dfe8c4efbeb9 100644 --- a/paddle/fluid/platform/device/ipu/ipu_utils.cc +++ b/paddle/fluid/platform/device/ipu/ipu_utils.cc @@ -18,7 +18,7 @@ namespace paddle { namespace platform { namespace ipu { -void* PaddleIArray::data() { return tensor_->data(); } +void* PaddleIArray::data() { return tensor_->data(); } popart::DataType PaddleIArray::dataType() const { return VarType2PopartType(tensor_->type()); diff --git a/paddle/fluid/platform/device/ipu/ipu_utils.h b/paddle/fluid/platform/device/ipu/ipu_utils.h index 
670427128b870..3a3b9c8ccc238 100644 --- a/paddle/fluid/platform/device/ipu/ipu_utils.h +++ b/paddle/fluid/platform/device/ipu/ipu_utils.h @@ -84,7 +84,7 @@ std::unique_ptr> Tensor2IArray( popart::TensorInfo tensor_info(dtype, shape); return std::make_unique>( - reinterpret_cast(tensor.data()), tensor_info); + reinterpret_cast(tensor.data()), tensor_info); } template diff --git a/paddle/fluid/platform/device/npu/npu_op_runner.cc b/paddle/fluid/platform/device/npu/npu_op_runner.cc index ed74a94c09502..78e5cb0ab106e 100644 --- a/paddle/fluid/platform/device/npu/npu_op_runner.cc +++ b/paddle/fluid/platform/device/npu/npu_op_runner.cc @@ -401,7 +401,7 @@ aclTensorDesc *NpuOpRunner::CreateTensorDesc(Tensor tensor, } aclDataBuffer *NpuOpRunner::CreateDataBuffer(Tensor tensor) { - void *ptr = tensor.data(); + void *ptr = tensor.data(); VLOG(4) << "NPU ptr: " << ptr << ", size: " << tensor.memory_size(); auto *buffer = aclCreateDataBuffer(ptr, tensor.memory_size()); PADDLE_ENFORCE_NOT_NULL( diff --git a/paddle/fluid/platform/device/npu/npu_op_runner.h b/paddle/fluid/platform/device/npu/npu_op_runner.h index 39c1fc9d041ea..e83057e682fef 100644 --- a/paddle/fluid/platform/device/npu/npu_op_runner.h +++ b/paddle/fluid/platform/device/npu/npu_op_runner.h @@ -150,8 +150,8 @@ void FillNpuTensorWithConstant(Tensor *tensor, T val) { *npu_pinned_ptr = val; memory::Copy(BOOST_GET_CONST(platform::NPUPlace, tensor->place()), - tensor->data(), npu_pinned_place, npu_pinned_ptr, - sizeof(T), GetCurrentNPUStream()); + tensor->data(), npu_pinned_place, npu_pinned_ptr, sizeof(T), + GetCurrentNPUStream()); auto npu_pinned_allocator = static_cast( diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 00e97bc2db420..4f22e83ac626f 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -792,7 +792,7 @@ void BindImperative(py::module *m_ptr) { SetTensorFromPyArray(&t, array, platform::CPUPlace(), true); // 3. allocate shared memory - void *data_ptr = t.data(); + void *data_ptr = t.data(); size_t data_size = t.numel() * framework::SizeOfType(t.type()); auto shared_writer_holder = memory::allocation::AllocateMemoryMapWriterAllocation(data_size); @@ -827,7 +827,7 @@ void BindImperative(py::module *m_ptr) { SetTensorFromPyArray(&t, array, platform::CPUPlace(), true); // 3. allocate shared memory - void *data_ptr = t.data(); + void *data_ptr = t.data(); size_t data_size = t.numel() * framework::SizeOfType(t.type()); auto shared_writer_holder = memory::allocation::AllocateMemoryMapWriterAllocation(data_size); @@ -1857,7 +1857,7 @@ void BindImperative(py::module *m_ptr) { // 1. get LoDTensor auto *t = self->MutableVar()->GetMutable(); // 2. 
allocate shared memory - void *data_ptr = t->data(); + void *data_ptr = t->data(); size_t data_size = t->numel() * framework::SizeOfType(t->type()); auto shared_writer_holder = memory::allocation::AllocateMemoryMapWriterAllocation( diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 9d3a858d1bdbf..b31b7456ebca7 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -729,7 +729,7 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor, numel *= py_dims[i]; } - const void *tensor_buf_ptr = tensor.data(); + const void *tensor_buf_ptr = tensor.data(); std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(tensor.type()); diff --git a/paddle/pten/api/lib/utils/storage.h b/paddle/pten/api/lib/utils/storage.h index 41b0f4744d12a..e102ecbc5de7d 100644 --- a/paddle/pten/api/lib/utils/storage.h +++ b/paddle/pten/api/lib/utils/storage.h @@ -83,8 +83,21 @@ class SharedStorage : public pten::Storage { size_ = 0; } - size_t size() const noexcept override { return size_; } - const paddle::platform::Place& place() const override { return place_; } + void set_data_shared( + const std::shared_ptr& holder) override { + data_ = holder; + if (holder) { + size_ = holder->size(); + place_ = holder->place(); + } + } + + size_t size() const noexcept override { + return data_ ? data_->size() : size_; + } + const paddle::platform::Place& place() const override { + return data_ ? data_->place() : place_; + } bool OwnsMemory() const noexcept override { return false; } const std::shared_ptr& GetAllocation() { diff --git a/paddle/pten/core/dense_tensor.cc b/paddle/pten/core/dense_tensor.cc index d8d83c575c4cf..1b4254ad2c103 100644 --- a/paddle/pten/core/dense_tensor.cc +++ b/paddle/pten/core/dense_tensor.cc @@ -41,12 +41,32 @@ DenseTensor::DenseTensor(intrusive_ptr storage, DenseTensor::DenseTensor(intrusive_ptr storage, DenseTensorMeta&& meta) : meta_(std::move(meta)), storage_(std::move(storage)) {} -DenseTensor::DenseTensor(const DenseTensor& other) - : meta_(other.meta()), storage_(copy_intrusive(other.storage_)) {} +DenseTensor::DenseTensor(const DenseTensor& other) : meta_(other.meta()) { + if (storage_ == nullptr) { + storage_ = make_intrusive( + paddle::platform::CPUPlace()); + } + if (other.storage_ != nullptr && other.storage_->data_shared()) { + storage_->set_data_shared(other.storage_->data_shared()); + } + +#ifdef PADDLE_WITH_MKLDNN + format_ = other.format_; +#endif +} DenseTensor& DenseTensor::operator=(const DenseTensor& other) { meta_ = other.meta(); - storage_ = std::move(copy_intrusive(other.storage_)); + if (storage_ == nullptr) { + storage_ = make_intrusive( + paddle::platform::CPUPlace()); + } + if (other.storage_ != nullptr && other.storage_->data_shared()) { + storage_->set_data_shared(other.storage_->data_shared()); + } +#ifdef PADDLE_WITH_MKLDNN + format_ = other.format_; +#endif return *this; } @@ -138,22 +158,22 @@ T* DenseTensor::data() { return reinterpret_cast(data()); } -const void* DenseTensor::data() const { +void* DenseTensor::data() { PADDLE_ENFORCE_NOT_NULL( storage_, paddle::platform::errors::PreconditionNotMet( "The storage must be valid when call the mutable data function.")); - return reinterpret_cast( - reinterpret_cast(storage_->data()) + meta_.offset); + return reinterpret_cast(reinterpret_cast(storage_->data()) + + meta_.offset); } -void* DenseTensor::data() { +const void* DenseTensor::data() const { PADDLE_ENFORCE_NOT_NULL( storage_, paddle::platform::errors::PreconditionNotMet( "The storage 
must be valid when call the mutable data function.")); - return reinterpret_cast(reinterpret_cast(storage_->data()) + - meta_.offset); + return reinterpret_cast( + reinterpret_cast(storage_->data()) + meta_.offset); } void DenseTensor::set_meta(DenseTensorMeta&& meta) { @@ -174,12 +194,11 @@ void DenseTensor::set_meta(DenseTensorMeta&& meta) { storage_ won't be initialized until the first call to mutable_data(place) */ -DenseTensor& DenseTensor::Resize(const DDim& dims) { +void DenseTensor::Resize(const DDim& dims) { meta_.dims = dims; if (storage_ != nullptr) { mutable_data(); } - return *this; } void DenseTensor::ResetLoD(const LoD& lod) { meta_.lod = lod; } @@ -211,36 +230,21 @@ DATA_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128); /* From framework::Tensor */ /* --------------------------- */ DenseTensor::DenseTensor() { + storage_ = make_intrusive( + paddle::platform::CPUPlace()); inplace_version_counter_ = std::make_shared(0); - meta_ = DenseTensorMeta(); meta_.dtype = paddle::experimental::DataType::FLOAT32; meta_.offset = 0; } DenseTensor::DenseTensor(const paddle::framework::proto::VarType::Type& dtype) { + storage_ = make_intrusive( + paddle::platform::CPUPlace()); inplace_version_counter_ = std::make_shared(0); - meta_ = DenseTensorMeta(); meta_.dtype = TransToPtenDataType(dtype); meta_.offset = 0; } -DenseTensor& DenseTensor::ShareDataWith(const DenseTensor& src) { - src.check_memory_size(); - *this = src; - return *this; -} - -DenseTensor& DenseTensor::ShareInplaceVersionCounterWith( - const DenseTensor& src) { - PADDLE_ENFORCE_NOT_NULL( - inplace_version_counter_, - paddle::platform::errors::PreconditionNotMet( - "Tensor does not hold inplace_version_counter_.")); - - inplace_version_counter_ = src.inplace_version_counter_; - return *this; -} - size_t DenseTensor::memory_size() const { if (storage_ == nullptr || storage_->data_shared() == nullptr) { return 0UL; @@ -304,16 +308,15 @@ void DenseTensor::ResetHolder( paddle::platform::errors::Fatal( "Only the offset is supported to zero when the holder is reset.")); - if (storage_ == nullptr) { - PADDLE_THROW( - paddle::platform::errors::Fatal("storage_ has to be initialized before " - "calling ResetHolder() interface.")); - } + PADDLE_ENFORCE_NOT_NULL( + storage_, + paddle::platform::errors::PreconditionNotMet( + "The storage must be valid when call the mutable data function.")); if (storage_->data_shared()) { PADDLE_ENFORCE_LE( numel() * SizeOf(dtype()) + meta_.offset, - storage_->data_shared()->size(), + holder->size(), paddle::platform::errors::InvalidArgument( "The size of Holder is not enough to store the Tensor.")); } @@ -333,95 +336,6 @@ void DenseTensor::set_type( meta_.dtype = TransToPtenDataType(type); } -DenseTensor DenseTensor::Slice(int64_t begin_idx, int64_t end_idx) const { - check_memory_size(); - PADDLE_ENFORCE_GE(begin_idx, - 0, - paddle::platform::errors::OutOfRange( - "The start row index must be greater than 0." - "But received the start index is d%.", - begin_idx)); - PADDLE_ENFORCE_LE(end_idx, - meta_.dims[0], - paddle::platform::errors::OutOfRange( - "The end row index is out of bound.")); - PADDLE_ENFORCE_LT( - begin_idx, - end_idx, - paddle::platform::errors::InvalidArgument( - "The start row index must be less than the end row index." 
- "But received the start index = %d, the end index = %d.", - begin_idx, - end_idx)); - - if (meta_.dims[0] == 1) { - return *this; - } else { - size_t base = numel() / meta_.dims[0]; - DenseTensor dst; - dst.storage_ = std::move(copy_intrusive(storage_)); - dst.meta_.layout = meta_.layout; - dst.meta_.dtype = meta_.dtype; - DDim dst_dims = meta_.dims; - dst_dims[0] = end_idx - begin_idx; - dst.Resize(dst_dims); - dst.meta_.offset = meta_.offset + begin_idx * base * SizeOf(dtype()); - return dst; - } -} - -std::vector DenseTensor::Split(int64_t split_size, - int64_t axis) const { - check_memory_size(); - - PADDLE_ENFORCE_GE(meta_.dims.size(), - 0, - paddle::platform::errors::OutOfRange( - "split expects at least a 1-dimensional tensor")); - - PADDLE_ENFORCE_GE( - split_size, - 0, - paddle::platform::errors::OutOfRange( - "split expects split_size be non-negative, but got split_size is %d", - split_size)); - - int64_t numel_size = meta_.dims[axis]; - - int64_t num_splits = 1; - if (split_size != 0) { - num_splits = - std::max((numel_size + split_size - 1) / split_size, 1); - } - - std::vector splits(num_splits); - int64_t last_split_size = split_size - (split_size * num_splits - numel_size); - - for (int64_t i = 0; i < num_splits; ++i) { - int64_t length = i < num_splits - 1 ? split_size : last_split_size; - splits[i] = Slice(i * split_size, i * split_size + length); - } - return splits; -} - -std::vector DenseTensor::Chunk(int64_t chunks, - int64_t axis) const { - check_memory_size(); - PADDLE_ENFORCE_GE(meta_.dims.size(), - 0, - paddle::platform::errors::OutOfRange( - "split expects at least a 1-dimensional tensor")); - PADDLE_ENFORCE_GE( - chunks, - 0, - paddle::platform::errors::OutOfRange( - "chunks expects to be greater than 0, but got chunks is %d", chunks)); - - int64_t numel_size = meta_.dims[axis]; - int64_t split_size = (numel_size + chunks - 1) / chunks; - return Split(split_size, axis); -} - void* DenseTensor::mutable_data(const paddle::platform::Place& place, paddle::framework::proto::VarType::Type type, size_t requested_size) { @@ -447,23 +361,16 @@ void* DenseTensor::mutable_data(const paddle::platform::Place& place, if (storage_->data_shared() == nullptr || !(storage_->data_shared()->place() == place) || storage_->data_shared()->size() < size + meta_.offset) { - // Reset holder first before re-allocate to save memory storage_->Clear(); storage_->set_data_shared(paddle::memory::AllocShared(place, size)); meta_.offset = 0; } - return reinterpret_cast( - reinterpret_cast(storage_->data_shared()->ptr()) + - meta_.offset); + return reinterpret_cast(reinterpret_cast(storage_->data()) + + meta_.offset); } void* DenseTensor::mutable_data(const paddle::platform::Place& place, size_t requested_size) { - if (storage_ == nullptr) { - PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( - "The tensor is not initialized.")); - } - return mutable_data(place, type(), requested_size); } @@ -481,8 +388,12 @@ void* DenseTensor::mutable_data(const paddle::platform::Place& place, "] now")); size_t size = numel() * SizeOf(dtype()); + if (storage_ == nullptr) { + storage_ = make_intrusive(place); + } + /* some versions of boost::variant don't have operator!= */ - if (storage_ == nullptr || storage_->data_shared() == nullptr || + if (storage_->data_shared() == nullptr || !(storage_->data_shared()->place() == place) || storage_->data_shared()->size() < size + meta_.offset || !(paddle::platform::is_gpu_place(place) && @@ -491,9 +402,8 @@ void* DenseTensor::mutable_data(const 
paddle::platform::Place& place, storage_->set_data_shared(paddle::memory::AllocShared(place, size, stream)); meta_.offset = 0; } - return reinterpret_cast( - reinterpret_cast(storage_->data_shared()->ptr()) + - meta_.offset); + return reinterpret_cast(reinterpret_cast(storage_->data()) + + meta_.offset); } /* @jim19930609: The following "mutable_data" only supports specific dtypes @@ -506,7 +416,7 @@ inline T* DenseTensor::mutable_data(const DDim& dims, const paddle::platform::Place& place, size_t requested_size) { static_assert(std::is_pod::value, "T must be POD"); - Resize(dims); + meta_.dims = dims; return mutable_data(place, requested_size); } @@ -518,6 +428,13 @@ inline T* DenseTensor::mutable_data(const paddle::platform::Place& place, place, paddle::framework::DataTypeTrait::DataType(), requested_size)); } +void DenseTensor::ShareBufferWith(const DenseTensor& tensor) { + if (storage_ != nullptr && tensor.storage_ != nullptr) { + storage_->set_data_shared(tensor.storage_->data_shared()); + } + meta_.offset = tensor.meta().offset; +} + #define LEGACY_DATA_MEMBER_FUNC_INSTANTIATION(dtype) \ template dtype* DenseTensor::mutable_data( \ const DDim& dims, \ diff --git a/paddle/pten/core/dense_tensor.h b/paddle/pten/core/dense_tensor.h index eb149220f942d..fc92e84f52cea 100644 --- a/paddle/pten/core/dense_tensor.h +++ b/paddle/pten/core/dense_tensor.h @@ -157,7 +157,7 @@ class DenseTensor : public TensorBase, /// \param dims The new dims of the dense tensor. /// \param lod The new lod of the dense tensor. // void Resize(const DDim& dims); - DenseTensor& Resize(const DDim& dims); + void Resize(const DDim& dims); /// \brief Change the lod information in the metadata. /// \param lod The new lod of the dense tensor. @@ -204,7 +204,7 @@ class DenseTensor : public TensorBase, private: friend class CompatibleDenseTensorUtils; - private: + protected: DenseTensorMeta meta_; intrusive_ptr storage_; @@ -228,7 +228,7 @@ class DenseTensor : public TensorBase, explicit DenseTensor(const paddle::framework::proto::VarType::Type& dtype); inline bool IsInitialized() const { - return storage_ != nullptr && storage_->data() != nullptr; + return storage_ != nullptr && storage_->data_shared() != nullptr; } template @@ -256,18 +256,6 @@ class DenseTensor : public TensorBase, paddle::framework::proto::VarType::Type type, const paddle::platform::Stream& stream); - /*! The internal of two tensors share the same memory block. */ - DenseTensor& ShareDataWith(const DenseTensor& src); - - /*! The internal of two tensors share the same inplace version counter. */ - DenseTensor& ShareInplaceVersionCounterWith(const DenseTensor& src); - - DenseTensor Slice(int64_t begin_idx, int64_t end_idx) const; - - std::vector Split(int64_t split_size, int64_t axis) const; - - std::vector Chunk(int64_t chunks, int64_t axis) const; - /* @jim19930609: Remove dependency on protobuf after Tensor Unification. 
*/ paddle::framework::proto::VarType::Type type() const; @@ -288,17 +276,17 @@ class DenseTensor : public TensorBase, meta_.offset = 0; } - void ShareBufferWith(const DenseTensor& tensor) { - storage_ = std::move(copy_intrusive(tensor.storage_)); - meta_.offset = tensor.meta().offset; - } + void ShareBufferWith(const DenseTensor& tensor); void ShareDataTypeWith(const DenseTensor& tensor) { meta_.dtype = tensor.meta().dtype; } bool IsSharedBufferWith(const DenseTensor& src) const { - return IsSharedWith(src); + if (storage_ == nullptr || src.storage_ == nullptr) return false; + if (storage_->data_shared() == src.storage_->data_shared()) return true; + + return false; } const std::shared_ptr Holder() const { @@ -325,7 +313,7 @@ class DenseTensor : public TensorBase, return *inplace_version_counter_; } - private: + protected: std::shared_ptr inplace_version_counter_; /* @jim19930609: This is a hack @@ -365,6 +353,7 @@ class DenseTensor : public TensorBase, Will be adjusted/removed/moved in the near future */ + public: explicit DenseTensor(const LoD& lod); void set_lod(const LoD& lod); diff --git a/paddle/pten/core/storage.h b/paddle/pten/core/storage.h index 74c303697755a..fc56935eeaf19 100644 --- a/paddle/pten/core/storage.h +++ b/paddle/pten/core/storage.h @@ -60,7 +60,7 @@ class Storage : public intrusive_ref_counter { return data_; } - void set_data_shared( + virtual void set_data_shared( const std::shared_ptr& holder) { data_ = holder; } diff --git a/paddle/pten/core/utils/intrusive_ptr.h b/paddle/pten/core/utils/intrusive_ptr.h index 51546bbc5056f..ed9a21e7f3a8a 100644 --- a/paddle/pten/core/utils/intrusive_ptr.h +++ b/paddle/pten/core/utils/intrusive_ptr.h @@ -40,8 +40,8 @@ class intrusive_ptr { rhs.reset(); } - intrusive_ptr& operator=(intrusive_ptr&& rhs) { - px = std::move(rhs.px); + intrusive_ptr& operator=(intrusive_ptr&& rhs) { + swap(rhs); return *this; } diff --git a/paddle/pten/tests/core/test_dense_tensor.cc b/paddle/pten/tests/core/test_dense_tensor.cc index 814f85fde3e40..c6db228c2b757 100644 --- a/paddle/pten/tests/core/test_dense_tensor.cc +++ b/paddle/pten/tests/core/test_dense_tensor.cc @@ -133,7 +133,10 @@ TEST(dense_tensor, shallow_copy) { DenseTensor tensor_1(tensor_0); CHECK(tensor_0.meta() == tensor_1.meta()); - CHECK(tensor_0.release() == tensor_1.release()); + + // Copy constructor: Now shares the underlying shared_ptr instead + // of Storage + CHECK(tensor_0.release() != tensor_1.release()); } } // namespace tests