diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake
index 3342d8bbd8fac..cb946fb85c09a 100644
--- a/cmake/external/xpu.cmake
+++ b/cmake/external/xpu.cmake
@@ -35,7 +35,7 @@ ELSE ()
 ENDIF()

 SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210818")
+SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210830")
 SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
 SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
 SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index a200b948dea45..7730550e061f1 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -183,6 +183,7 @@ function(op_library TARGET)
     list(REMOVE_ITEM miopen_cu_cc_srcs "affine_grid_cudnn_op.cu.cc")
     list(REMOVE_ITEM miopen_cu_cc_srcs "grid_sampler_cudnn_op.cu.cc")
     list(REMOVE_ITEM hip_srcs "cholesky_op.cu")
+    list(REMOVE_ITEM hip_srcs "svd_op.cu")
     list(REMOVE_ITEM hip_srcs "multinomial_op.cu")
     list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu")
     hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS}
diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 58ae35f268979..3627a8cf71c1e 100644
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -158,6 +158,7 @@ message PipelineConfig {
   optional int32 micro_batch_size = 1 [ default = 1 ];
   optional int32 accumulate_steps = 2 [ default = 1 ];
   optional string schedule_mode = 3 [ default = '1F1B' ];
+  optional bool p2p_cache_shape = 4 [ default = true ];
 }

 message TensorParallelConfig {
diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h
index 362877aa1604e..374984ecdb6b6 100644
--- a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h
+++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h
@@ -44,7 +44,7 @@ class Optimizer {
     if (w < optimizer_config::min_bound) w = optimizer_config::min_bound;
     if (w > optimizer_config::max_bound) w = optimizer_config::max_bound;

-    add_g2sum = scaled_grad * scaled_grad;
+    add_g2sum += scaled_grad * scaled_grad;

     g2sum += add_g2sum;
   }
@@ -64,7 +64,7 @@ class Optimizer {
     if (w[i] < optimizer_config::mf_min_bound) w[i] = optimizer_config::mf_min_bound;
     if (w[i] > optimizer_config::mf_max_bound) w[i] = optimizer_config::mf_max_bound;
-    add_g2sum = scaled_grad * scaled_grad;
+    add_g2sum += scaled_grad * scaled_grad;
   }

   g2sum += add_g2sum / n;
diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc
index d31627efed7f5..3035f1bcffa47 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore.cc
@@ -143,8 +143,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
       main_program_(main_prog),
       global_scope_(global_scope),
       d2h_ctx_pool_({place}),
-      h2d_ctx_pool_({place}),
-      fetch_context_pool_({place}) {
+      h2d_ctx_pool_({place}) {
   is_build_ = false;

   garbages_.reset(new GarbageQueue());
@@ -339,9 +338,6 @@ void InterpreterCore::BuildInstructionCtx(Instruction* instr_node,
       new RuntimeInferShapeContext(*op_base, *instr_node->runtime_ctx_.get()));

   auto* dev_ctx = instr_node->dev_ctx_;
-  if (instr_node->kernel_func_.operator_base_->Type() == "fetch_v2") {
-    dev_ctx = fetch_context_pool_.Get(place);
-  }
   Scope scope;

   instr_node->execution_ctx_.reset(new ExecutionContext(
       *op_base, scope, *dev_ctx, *instr_node->runtime_ctx_.get(),
@@ -356,12 +352,6 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
           instr_node.kernel_func_.operator_base_)
       ->InferShape(instr_node.infershape_ctx_.get());

-  if (instr_node.kernel_func_.operator_base_->Type() == "fetch_v2") {
-    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-    auto* dev_ctx = pool.Get(place_);
-    dev_ctx->Wait();  // TODO(wanghuancoder)
-  }
-
   instr_node.kernel_func_.compute_func_(*instr_node.execution_ctx_.get());
 }
@@ -411,8 +401,6 @@ void InterpreterCore::ExecuteInstructionList(
                            working_var_ref);
   }

-  fetch_context_pool_.Get(place)->Wait();
-
   for (size_t i = 0; i < working_var_ref.size(); ++i) {
     if (working_var_ref[i].var_ref_count_ != 0) {
       std::cerr << " var ref is not zero " << i << std::endl;
@@ -671,6 +659,9 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place,
                                      expected_kernel_key);
       if (!platform::is_same_place(kernel_type_for_var.place_,
                                    expected_kernel_key.place_)) {
+        if (op_base->Type() == "fetch_v2") {
+          op_base->SetAttr("deepcopy", false);
+        }
         // need trans place
         // 1. add var in scope
         // 2. add copy op
diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h
index ebb81dc9a09c3..200492ee27ef3 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.h
+++ b/paddle/fluid/framework/new_executor/interpretercore.h
@@ -114,8 +114,6 @@ class InterpreterCore {
   size_t max_memory_size_;
   size_t cur_memory_size_;
   std::unique_ptr gc_queue_;
-
-  platform::DeviceContextPool fetch_context_pool_;
 };
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 6a9f557770533..57f9d094ac80d 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -1254,10 +1254,10 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
   }
 #endif
 #ifdef PADDLE_WITH_XPU
-  if ((kernel_iter == kernels.end() &&
-       is_xpu_place(expected_kernel_key.place_) &&
-       !paddle::platform::is_xpu_support_op(type_, expected_kernel_key)) ||
-      paddle::platform::is_in_xpu_black_list(type_)) {
+  if (is_xpu_place(expected_kernel_key.place_) &&
+      (kernel_iter == kernels.end() ||
+       !paddle::platform::is_xpu_support_op(type_, expected_kernel_key) ||
+       paddle::platform::is_in_xpu_black_list(type_))) {
     VLOG(3) << "missing XPU kernel: " << type_
             << ", expected_kernel_key:" << expected_kernel_key
             << ", fallbacking to CPU one!";
diff --git a/paddle/fluid/framework/ps_gpu_trainer.cc b/paddle/fluid/framework/ps_gpu_trainer.cc
index 39bc3f040639b..8b16b6a5d007f 100644
--- a/paddle/fluid/framework/ps_gpu_trainer.cc
+++ b/paddle/fluid/framework/ps_gpu_trainer.cc
@@ -57,8 +57,6 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
         trainer_desc.downpour_param().stat_var_names(i));
   }
   VLOG(3) << "going to initialize pull dense worker";
-  pull_dense_worker_ = PullDenseWorker::GetInstance();
-  pull_dense_worker_->Initialize(trainer_desc);
   SetDebug(trainer_desc.debug());
   trainer_desc_ = trainer_desc;
   workers_.resize(place_num);
@@ -112,15 +110,21 @@ void PSGPUTrainer::InitTrainerEnv(const ProgramDesc& main_program,
       }
     }
   }
+  for (auto& var : main_program.Block(0).AllVars()) {
+    if (var->Persistable()) {
+      auto it = std::find(need_merge_var_names_.begin(),
+                          need_merge_var_names_.end(),
var->Name()); + if (it == need_merge_var_names_.end()) { + VLOG(2) << "train param: " << var->Name(); + trainable_param_.push_back(var->Name()); + } + } + } place_ = place; return; } void PSGPUTrainer::InitOtherEnv(const ProgramDesc& main_program) { - pull_dense_worker_->SetRootScope(root_scope_); - for (size_t i = 0; i < places_.size(); ++i) { - pull_dense_worker_->AddThreadScope(workers_[i]->GetThreadScope()); - } VLOG(3) << "init other env done."; } @@ -141,15 +145,27 @@ Scope* PSGPUTrainer::GetWorkerScope(int thread_id) { return nullptr; } template void PSGPUTrainer::MergeToRootScope(LoDTensor* root_tensor, LoDTensor* tensor) { LoDTensor tmp_root; - TensorCopy(*root_tensor, platform::CPUPlace(), &tmp_root); + TensorCopySync(*root_tensor, platform::CPUPlace(), &tmp_root); T* tmp_root_data = tmp_root.data(); LoDTensor tmp_tensor; - TensorCopy(*tensor, platform::CPUPlace(), &tmp_tensor); + TensorCopySync(*tensor, platform::CPUPlace(), &tmp_tensor); T* data = tmp_tensor.data(); for (int i = 0; i < tmp_tensor.numel(); i++) { tmp_root_data[i] += data[i]; } - TensorCopy(tmp_root, platform::CPUPlace(), root_tensor); + TensorCopySync(tmp_root, platform::CPUPlace(), root_tensor); +} + +void PSGPUTrainer::MergeDenseParam() { + auto thread_scope = workers_[0]->GetThreadScope(); + for (auto& name : trainable_param_) { + VLOG(2) << "merge var " << name << " to root scope"; + Variable* root_var = root_scope_->FindVar(name); + LoDTensor* root_tensor = root_var->GetMutable(); + Variable* var = thread_scope->FindVar(name); + LoDTensor* tensor = var->GetMutable(); + TensorCopySync((*tensor), root_tensor->place(), root_tensor); + } } void PSGPUTrainer::Finalize() { @@ -187,7 +203,7 @@ void PSGPUTrainer::Finalize() { _ForEachDataType_(MergeCallback); } } - pull_dense_worker_->MergeDenseParam(); + MergeDenseParam(); root_scope_->DropKids(); } } // namespace framework diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index fc8fb9327d5bb..0f34c84549f2b 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -265,6 +265,7 @@ class PSGPUTrainer : public TrainerBase { } virtual std::string GetDumpPath(int tid) { return ""; } virtual void InitDumpEnv() {} + virtual void MergeDenseParam(); template void MergeToRootScope(LoDTensor* root_tensor, LoDTensor* thread_tensor); @@ -274,6 +275,7 @@ class PSGPUTrainer : public TrainerBase { DownpourWorkerParameter param_; std::map> dense_grad_names_; std::vector need_merge_var_names_; + std::vector trainable_param_; float scale_datanorm_; paddle::platform::Place place_; ProgramDesc program_; diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 93f2fd38a7306..8f45cd0fa6ea1 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -131,10 +131,10 @@ PreparedOp PrepareImpl(const NameVarMap& ins, auto& kernels = kernels_iter->second; auto kernel_iter = kernels.find(expected_kernel_key); #ifdef PADDLE_WITH_XPU - if ((kernel_iter == kernels.end() && - is_xpu_place(expected_kernel_key.place_) && - !paddle::platform::is_xpu_support_op(op.Type(), expected_kernel_key)) || - paddle::platform::is_in_xpu_black_list(op.Type())) { + if (is_xpu_place(expected_kernel_key.place_) && + (kernel_iter == kernels.end() || + !paddle::platform::is_xpu_support_op(op.Type(), expected_kernel_key) || + paddle::platform::is_in_xpu_black_list(op.Type()))) { VLOG(3) << "missing XPU kernel: " << op.Type() << ", expected_kernel_key:" 
<< expected_kernel_key << ", fallbacking to CPU one!"; diff --git a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc index 6c2fb82cb7cbe..9cd35ad8ad9da 100644 --- a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc @@ -225,6 +225,7 @@ size_t Used(const platform::XPUPlace &place) { // For Ascend NPU #ifdef PADDLE_WITH_ASCEND_CL +constexpr int EXTRA_PADDING_SIZE = 32; class NPUBuddyAllocatorList { private: NPUBuddyAllocatorList() : devices_(platform::GetSelectedNPUDevices()) { @@ -257,10 +258,11 @@ class NPUBuddyAllocatorList { std::call_once(*init_flags_[pos], [this, pos] { platform::SetNPUDeviceId(devices_[pos]); - allocators_[pos].reset(new BuddyAllocator( - std::unique_ptr( - new detail::NPUAllocator(devices_[pos])), - platform::NPUMinChunkSize(), platform::NPUMaxChunkSize())); + allocators_[pos].reset( + new BuddyAllocator(std::unique_ptr( + new detail::NPUAllocator(devices_[pos])), + platform::NPUMinChunkSize(), + platform::NPUMaxChunkSize(), EXTRA_PADDING_SIZE)); VLOG(10) << "\n\nNOTE:\n" << "You can set GFlags environment variable " << "'FLAGS_fraction_of_gpu_memory_to_use' " diff --git a/paddle/fluid/memory/detail/buddy_allocator.cc b/paddle/fluid/memory/detail/buddy_allocator.cc index 55436f451a41f..e714a020165d1 100644 --- a/paddle/fluid/memory/detail/buddy_allocator.cc +++ b/paddle/fluid/memory/detail/buddy_allocator.cc @@ -31,9 +31,10 @@ namespace detail { BuddyAllocator::BuddyAllocator( std::unique_ptr system_allocator, size_t min_chunk_size, - size_t max_chunk_size) + size_t max_chunk_size, size_t extra_padding_size) : min_chunk_size_(min_chunk_size), max_chunk_size_(max_chunk_size), + extra_padding_size_(extra_padding_size), cache_(system_allocator->UseGpu()), system_allocator_(std::move(system_allocator)) {} @@ -59,9 +60,14 @@ inline size_t align(size_t size, size_t alignment) { void* BuddyAllocator::Alloc(size_t unaligned_size) { // adjust allocation alignment - size_t size = - align(unaligned_size + sizeof(MemoryBlock::Desc), min_chunk_size_); + size_t size = + align(unaligned_size + sizeof(MemoryBlock::Desc) + extra_padding_size_, + min_chunk_size_); + VLOG(10) << "alloc: " << unaligned_size + << ", padding for desc: " << sizeof(MemoryBlock::Desc) + << ", extra padding: " << extra_padding_size_ + << ", alignment: " << min_chunk_size_; // acquire the allocator lock std::lock_guard lock(mutex_); diff --git a/paddle/fluid/memory/detail/buddy_allocator.h b/paddle/fluid/memory/detail/buddy_allocator.h index 135c3b6d04f34..2ded5dccf6ee0 100644 --- a/paddle/fluid/memory/detail/buddy_allocator.h +++ b/paddle/fluid/memory/detail/buddy_allocator.h @@ -35,7 +35,8 @@ namespace detail { class BuddyAllocator { public: BuddyAllocator(std::unique_ptr system_allocator, - size_t min_chunk_size, size_t max_chunk_size); + size_t min_chunk_size, size_t max_chunk_size, + size_t extra_padding_size = 0); ~BuddyAllocator(); @@ -86,7 +87,9 @@ class BuddyAllocator { size_t min_chunk_size_; // the minimum size of each chunk size_t max_chunk_size_; // the maximum size of each chunk - size_t realloc_size_ = 0; // the size of re-allocated chunk + size_t realloc_size_ = 0; // the size of re-allocated chunk + size_t extra_padding_size_ = 0; // the size of padding to the size of memory + // to alloc, especially used in NPU private: /** diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc b/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc 
index 6db18c46a09b8..0046440429fe2 100644 --- a/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc +++ b/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc @@ -19,6 +19,8 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/npu_op_runner.h" +DECLARE_int32(min_loss_scaling); + namespace paddle { namespace operators { @@ -49,7 +51,7 @@ void Update(const platform::NPUDeviceContext& ctx, std::vector bad_out_data; TensorToVector(*bad_out_tensor, ctx, &bad_out_data); - if (bad_out_data[0] == decr_every_n_nan_or_inf) { + if (bad_out_data[0] >= decr_every_n_nan_or_inf) { const auto& runner_p3 = NpuOpRunner("Power", {*pre_loss_scaling_tensor}, {*updated_loss_scaling_tensor}, {{"power", static_cast(1)}, @@ -60,13 +62,18 @@ void Update(const platform::NPUDeviceContext& ctx, std::vector new_loss_scaling; TensorToVector(*updated_loss_scaling_tensor, ctx, &new_loss_scaling); - if (new_loss_scaling[0] < static_cast(1)) { + float min_value = 1.0; + if (FLAGS_min_loss_scaling > 1) { + min_value = static_cast(FLAGS_min_loss_scaling); + } + + if (new_loss_scaling[0] < min_value) { // updated_loss_scaling_data = 1 - const auto& runner_p4 = NpuOpRunner("Power", {*pre_loss_scaling_tensor}, - {*updated_loss_scaling_tensor}, - {{"power", static_cast(1)}, - {"scale", static_cast(0)}, - {"shift", static_cast(1)}}); + const auto& runner_p4 = NpuOpRunner( + "Power", {*pre_loss_scaling_tensor}, {*updated_loss_scaling_tensor}, + {{"power", static_cast(1)}, + {"scale", static_cast(0)}, + {"shift", static_cast(min_value)}}); runner_p4.Run(stream); } @@ -93,7 +100,7 @@ void Update(const platform::NPUDeviceContext& ctx, std::vector good_out_data; TensorToVector(*good_out_tensor, ctx, &good_out_data); - if (good_out_data[0] == incr_every_n_steps) { + if (good_out_data[0] >= incr_every_n_steps) { const auto& runner_p3 = NpuOpRunner("Power", {*pre_loss_scaling_tensor}, {*updated_loss_scaling_tensor}, {{"power", static_cast(1)}, diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 4f22d28a450c1..3467658e894d5 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -848,7 +848,8 @@ void BatchNormGradMaker::Apply(GradOpPtr op) const { } // used when setting use_global_stats True during training - if (BOOST_GET_CONST(bool, this->GetAttr("use_global_stats"))) { + if (BOOST_GET_CONST(bool, this->GetAttr("use_global_stats")) || + BOOST_GET_CONST(bool, this->GetAttr("is_test"))) { op->SetInput("Mean", this->Output("MeanOut")); op->SetInput("Variance", this->Output("VarianceOut")); } diff --git a/paddle/fluid/operators/batch_norm_op_xpu.cc b/paddle/fluid/operators/batch_norm_op_xpu.cc index 526fc7364cdd8..8499d1cdcd646 100644 --- a/paddle/fluid/operators/batch_norm_op_xpu.cc +++ b/paddle/fluid/operators/batch_norm_op_xpu.cc @@ -76,26 +76,25 @@ class BatchNormXPUKernel : public framework::OpKernel { W, epsilon, momentum, scale_data, bias_data, saved_mean_data, saved_variance_data, mean_out_data, variance_out_data, true); - PADDLE_ENFORCE_EQ( - r, XPU_SUCCESS, - platform::errors::External("XPU API(batch_norm_train_forward) return " - "wrong value[%d], please check whether " - "Baidu Kunlun Card is properly installed.", - r)); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The batch_norm XPU API return wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); } else { const auto* mean = ctx.Input("Mean"); const auto* variance = 
ctx.Input("Variance"); - const auto* mean_data = mean->data(); - const auto* variance_data = variance->data(); - int r = xpu::batch_norm_infer_forward( - dev_ctx.x_context(), epsilon, N, C, H, W, x_data, y_data, scale_data, - bias_data, mean_data, variance_data); + const auto* mean_data = mean->data(); + const auto* variance_data = variance->data(); + const auto* x_data = x->data(); + auto* y_data = y->mutable_data(ctx.GetPlace()); + int r = xpu::batch_norm_infer(dev_ctx.x_context(), x_data, y_data, N, C, + H, W, epsilon, scale_data, bias_data, + mean_data, variance_data, true); PADDLE_ENFORCE_EQ( - r, XPU_SUCCESS, - platform::errors::External("XPU API(batch_norm_infer_forward) return " - "wrong value[%d], please check whether " - "Baidu Kunlun Card is properly installed.", - r)); + r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The batch_norm_infer XPU API return wrong value[%d %s]", r, + XPUAPIErrorMsg[r])); } } }; diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu index 8b2998e52a172..06300817e0a12 100644 --- a/paddle/fluid/operators/cast_op.cu +++ b/paddle/fluid/operators/cast_op.cu @@ -13,47 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/cast_op.h" +#include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/gpu_launch_config.h" namespace paddle { namespace operators { -// aligned vector generates vectorized load/store on CUDA -template -struct alignas(sizeof(T) * Size) AlignedVector { - T val[Size]; -}; - -template -inline int VectorizedSize(const T* pointer) { - uint64_t address = reinterpret_cast(pointer); - constexpr int vec4 = std::alignment_of>::value; // NOLINT - if (address % vec4 == 0) { - return 4; - } - return 1; -} - template __global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) { + using LoadT = platform::AlignedVector; + using StoreT = platform::AlignedVector; + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - using LoadT = AlignedVector; - using StoreT = AlignedVector; for (int64_t i = idx * VecSize; i < N; i += blockDim.x * gridDim.x * VecSize) { - InT in_vec[VecSize]; - LoadT* in_value = reinterpret_cast(&in_vec); - *in_value = *reinterpret_cast(&in[i]); + LoadT in_val; + platform::Load(&in[i], &in_val); - OutT out_vec[VecSize]; + StoreT out_val; #pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - out_vec[ii] = static_cast(in_vec[ii]); + for (int j = 0; j < VecSize; j++) { + out_val[j] = static_cast(in_val[j]); } - *(reinterpret_cast(&out[i])) = - *reinterpret_cast(&out_vec[0]); + platform::Store(out_val, &out[i]); } } @@ -78,7 +62,7 @@ struct CastOpFunctor { auto* out = out_->mutable_data(ctx_.GetPlace()); platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(ctx_, size); - int vec_size = VectorizedSize(out); + int vec_size = platform::GetVectorizedSize(out); if (!std::is_same::value && vec_size == 4 && size % 4 == 0) { VecCastCUDAKernel<<< config.block_per_grid, config.thread_per_block, 0, ctx_.stream()>>>( diff --git a/paddle/fluid/operators/coalesce_tensor_op.cc b/paddle/fluid/operators/coalesce_tensor_op.cc index c9cc01b8b17dc..d2addb32bca00 100644 --- a/paddle/fluid/operators/coalesce_tensor_op.cc +++ b/paddle/fluid/operators/coalesce_tensor_op.cc @@ -167,6 +167,12 @@ class CoalesceTensorOpKernel : public framework::OpKernel { auto out_tensors = context.MultiOutput("Output"); size_t offset = 0; if 
(context.Attr("copy_data")) { +#ifdef PADDLE_WITH_ASCEND_CL + framework::VisitDataType( + dtype, + FillConstantVisitor( + dev_ctx, fused_tensor, static_cast(0.0), dtype, context)); +#endif for (size_t i = 0; i < in_var_names.size(); ++i) { size_t len = static_cast(in_tensors[i]->numel()); auto sub_tensor = fused_tensor->Slice( diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 6095516f92fa5..4783aa3a86fb3 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -111,7 +111,8 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr( "use_mkldnn", "(bool, default false) Indicates if MKL-DNN kernel will be used") - .SetDefault(false); + .SetDefault(false) + .AsExtra(); AddAttr("axis", "The axis along which the input tensors will be concatenated." "The axis could also be negative numbers. Negative axis is " @@ -128,12 +129,14 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { "use_quantizer", "(bool, default false) " "This parameter is no longer used. Use 'mkldnn_data_type' instead.") - .SetDefault(false); + .SetDefault(false) + .AsExtra(); AddAttr( "mkldnn_data_type", "(string, default \"float32\"). Data type of mkldnn kernel") .SetDefault("float32") - .InEnum({"float32", "int8", "bfloat16"}); + .InEnum({"float32", "int8", "bfloat16"}) + .AsExtra(); AddComment(R"DOC( Concat Operator. diff --git a/paddle/fluid/operators/controlflow/fetch_v2_op.cc b/paddle/fluid/operators/controlflow/fetch_v2_op.cc index 382f412742e61..355e52b9436e6 100644 --- a/paddle/fluid/operators/controlflow/fetch_v2_op.cc +++ b/paddle/fluid/operators/controlflow/fetch_v2_op.cc @@ -36,10 +36,9 @@ struct float16; namespace paddle { namespace operators { -static void DataCopy(const framework::LoDTensor &src_item, +static void DeepCopy(const framework::LoDTensor &src_item, const std::string &fetch_var_name, - framework::LoDTensor *dst_item, - const platform::DeviceContext &dev_ctx) { + framework::LoDTensor *dst_item) { if (src_item.IsInitialized() && src_item.numel() > 0) { #ifdef PADDLE_WITH_MKLDNN // Conversion from MKL-DNN to Paddle @@ -53,26 +52,13 @@ static void DataCopy(const framework::LoDTensor &src_item, : paddle::platform::MKLDNNDeviceContext::tls() .get_cur_paddle_data_layout(), src_item, &out, platform::CPUPlace()); - TensorCopy(src_item, platform::CPUPlace(), dev_ctx, dst_item); + TensorCopySync(out, platform::CPUPlace(), dst_item); } else { - if (platform::is_gpu_place(src_item.place())) { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - TensorCopy(src_item, platform::CUDAPinnedPlace(), dev_ctx, dst_item); -#endif - } else { - TensorCopy(src_item, platform::CPUPlace(), dst_item); - } + TensorCopySync(src_item, platform::CPUPlace(), dst_item); } #else - if (platform::is_gpu_place(src_item.place())) { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - TensorCopy(src_item, platform::CUDAPinnedPlace(), dev_ctx, dst_item); + TensorCopySync(src_item, platform::CPUPlace(), dst_item); #endif - } else { - TensorCopy(src_item, platform::CPUPlace(), dst_item); - } -#endif - } else { // Not copy, if the src tensor is empty. 
dst_item->clear(); @@ -92,15 +78,14 @@ class FetchV2Op : public framework::OperatorWithKernel { const std::string &var_name, const framework::Tensor &tensor, const framework::OpKernelType &expected_kernel_type) const override { return framework::OpKernelType(expected_kernel_type.data_type_, - expected_kernel_type.place_, - tensor.layout()); + tensor.place(), tensor.layout()); } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + platform::CPUPlace()); } }; @@ -119,12 +104,10 @@ class FetchV2Kernel { if (fetch_var == nullptr) { return; } - PADDLE_ENFORCE_EQ(ctx.HasOutput("Out"), true, - platform::errors::NotFound( - "Output(Out) of memcpy_d2h_op is not found.")); + PADDLE_ENFORCE_EQ( + ctx.HasOutput("Out"), true, + platform::errors::NotFound("Output(Out) of fetch_v2_op is not found.")); auto *out_var = ctx.OutputVar("Out"); - // Get dev_ctx from ExecutionContext, it's D2H stream - auto &dev_ctx = ctx.device_context(); int col = ctx.Attr("col"); PADDLE_ENFORCE_GE( @@ -140,10 +123,19 @@ class FetchV2Kernel { fetch_list->resize(col + 1); } + bool deepcopy = ctx.Attr("deepcopy"); + if (fetch_var->IsType()) { auto &src_item = fetch_var->Get(); auto *dst_item = &(BOOST_GET(framework::LoDTensor, fetch_list->at(col))); - DataCopy(src_item, fetch_var_name, dst_item, dev_ctx); + PADDLE_ENFORCE_EQ(platform::is_cpu_place(src_item.place()), true, + platform::errors::InvalidArgument( + "Tensor's place of input(X) must be CPUPlace.")); + if (deepcopy) { + DeepCopy(src_item, fetch_var_name, dst_item); + } else { + dst_item->ShareDataWith(src_item); + } } else { auto &src_item = fetch_var->Get(); framework::LoDTensorArray tmp(src_item.size()); @@ -151,7 +143,14 @@ class FetchV2Kernel { auto &dst_item = BOOST_GET(framework::LoDTensorArray, fetch_list->at(col)); for (size_t i = 0; i < src_item.size(); ++i) { - DataCopy(src_item[i], fetch_var_name, &dst_item[i], dev_ctx); + PADDLE_ENFORCE_EQ(platform::is_cpu_place(src_item[i].place()), true, + platform::errors::InvalidArgument( + "Tensor's place of input(X) must be CPUPlace.")); + if (deepcopy) { + DeepCopy(src_item[i], fetch_var_name, &dst_item[i]); + } else { + dst_item[i].ShareDataWith(src_item[i]); + } } } } @@ -167,6 +166,8 @@ class FetchV2OpProtoMaker : public framework::OpProtoAndCheckerMaker { "(vector) A fetching list of LoDTensor which may have " "different dimension, shape and data type."); AddAttr("col", "(int) The column index of fetching object."); + AddAttr("deepcopy", "(bool) Whether deep copy is required.") + .SetDefault(true); AddComment(R"DOC( FetchV2 Operator. 
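Reviewer note on the fetch_v2 changes above: the kernel is now pinned to CPUPlace and asserts that its input already lives on the CPU, so the old D2H device-context plumbing can be removed. Whether the fetch list owns its own buffer is controlled by the new `deepcopy` attribute, which the new executor flips to false once it has inserted an explicit transfer op. A minimal sketch of that branch, using the same names as the kernel above:

    if (deepcopy) {
      // Own the memory: the fetched tensor stays valid even if the
      // original variable is overwritten by a later step.
      DeepCopy(src_item, fetch_var_name, dst_item);
    } else {
      // Zero-copy: the fetch list aliases the variable's CPU buffer,
      // which is safe because the inserted copy op already produced a
      // CPU-side tensor dedicated to this fetch.
      dst_item->ShareDataWith(src_item);
    }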
@@ -192,19 +193,3 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double, int64_t, ops::FetchV2Kernel, bool, ops::FetchV2Kernel, plat::float16, ops::FetchV2Kernel); - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_ROCM) -REGISTER_OP_CUDA_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double, - ops::FetchV2Kernel, int, ops::FetchV2Kernel, - int64_t, ops::FetchV2Kernel, bool, - ops::FetchV2Kernel, plat::float16, - ops::FetchV2Kernel); -#endif - -#ifdef PADDLE_WITH_ASCEND_CL -REGISTER_OP_NPU_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double, - ops::FetchV2Kernel, int, ops::FetchV2Kernel, - int64_t, ops::FetchV2Kernel, bool, - ops::FetchV2Kernel, plat::float16, - ops::FetchV2Kernel); -#endif diff --git a/paddle/fluid/operators/conv_transpose_op_npu.cc b/paddle/fluid/operators/conv_transpose_op_npu.cc new file mode 100644 index 0000000000000..6cb431b873a35 --- /dev/null +++ b/paddle/fluid/operators/conv_transpose_op_npu.cc @@ -0,0 +1,114 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/conv_transpose_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +template +class Conv2DTransposeNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + // input + const Tensor* input = context.Input("Input"); + const Tensor* filter = context.Input("Filter"); + // output + Tensor* output = context.Output("Output"); + output->mutable_data(context.GetPlace()); + // attr + std::vector output_padding = + context.Attr>("output_padding"); + const std::vector stride = context.Attr>("strides"); + std::vector padding = context.Attr>("paddings"); + std::vector dilation = context.Attr>("dilations"); + const std::string data_format = context.Attr("data_format"); + int groups = context.Attr("groups"); + const std::string padding_algorithm = + context.Attr("padding_algorithm"); + + // npu stream + auto stream = + context.template device_context().stream(); + + // check dimension + const bool channel_last = data_format == "NHWC"; + + // update padding and dilation + auto in_dims = input->dims(); + auto filter_dims = filter->dims(); + framework::DDim in_data_dims; + framework::DDim filter_data_dims; + + if (channel_last) { + in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); + } else { + in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); + } + filter_data_dims = framework::slice_ddim(filter_dims, 2, in_dims.size()); + + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&padding, &dilation, padding_algorithm, + in_data_dims, stride, ksize); + + // construct NPU attr + std::vector strides(4, 1); + std::vector dilations(4, 1); + + Tensor input_tensor, output_tensor; + input_tensor.ShareDataWith(*input); + output_tensor.ShareDataWith(*output); + + if (channel_last) { + input_tensor.set_layout(DataLayout::kNHWC); + 
output_tensor.set_layout(DataLayout::kNHWC); + strides[1] = stride[0]; + strides[2] = stride[1]; + dilations[1] = dilation[0]; + dilations[2] = dilation[1]; + } else { + strides[2] = stride[0]; + strides[3] = stride[1]; + dilations[2] = dilation[0]; + dilations[3] = dilation[1]; + } + + for (auto i = output_padding.size(); i < 4; ++i) { + output_padding.insert(output_padding.begin(), 0); + } + auto output_dim_vec = framework::vectorize(output_tensor.dims()); + // CANN OP + const auto& runner = + NpuOpRunner("Conv2DTransposeD", {input_tensor, *filter}, + {output_tensor}, {{"input_size", output_dim_vec}, + {"strides", strides}, + {"dilations", dilations}, + {"output_padding", output_padding}, + {"groups", groups}, + {"pads", padding}, + {"data_format", data_format}}); + runner.Run(stream); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +// conv2d +REGISTER_OP_NPU_KERNEL(conv2d_transpose, ops::Conv2DTransposeNPUKernel, + ops::Conv2DTransposeNPUKernel); diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index fbc145d3123d5..958f037a04f3b 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -38,7 +38,7 @@ namespace operators { template __global__ void RandomGenerator(const size_t n, uint64_t seed, const float dropout_prob, const T* src, - MaskType* mask_data, T* dst, + MaskType* mask, T* dst, bool is_upscale_in_train, uint64_t increment) { int idx = blockDim.x * blockIdx.x + threadIdx.x; #ifdef PADDLE_WITH_HIP @@ -49,36 +49,36 @@ __global__ void RandomGenerator(const size_t n, uint64_t seed, curand_init(seed, idx, increment, &state); #endif - MaskType mask; - T dest; + MaskType mask_val; + T dst_val; + T factor = static_cast(1.0f / (1.0f - dropout_prob)); for (; idx < n; idx += blockDim.x * gridDim.x) { - T s = src[idx]; + T src_val = src[idx]; #ifdef PADDLE_WITH_HIP if (hiprand_uniform(&state) < dropout_prob) { #else if (curand_uniform(&state) < dropout_prob) { #endif - mask = 0; - dest = 0; + mask_val = 0; + dst_val = 0; } else { - mask = 1; - if (is_upscale_in_train) { - dest = s / static_cast(1.0f - dropout_prob); - } else { - dest = s; - } + mask_val = 1; + dst_val = is_upscale_in_train ? 
src_val * factor : src_val; } - mask_data[idx] = mask; - dst[idx] = dest; + mask[idx] = mask_val; + dst[idx] = dst_val; } } template __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed, const float dropout_prob, - const T* src, MaskType* mask_data, - T* dst, bool is_upscale_in_train, + const T* src, MaskType* mask, T* dst, + bool is_upscale_in_train, uint64_t increment) { + using LoadT = platform::AlignedVector; + using MaskLoadT = platform::AlignedVector; + #ifdef PADDLE_WITH_HIP int64_t idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x; hiprandStatePhilox4_32_10_t state; @@ -89,43 +89,33 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed, curand_init(seed, idx, increment, &state); #endif - MaskType mask; - T dest; - using LoadT = AlignedVector; - using MaskLoadT = AlignedVector; T factor = static_cast(1.0f / (1.0f - dropout_prob)); for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) { - T src_vec[VecSize]; - LoadT* value = reinterpret_cast(&src_vec); - *value = *reinterpret_cast(&src[i]); + LoadT src_val; + platform::Load(&src[i], &src_val); + #ifdef PADDLE_WITH_HIP float4 rand = hiprand_uniform4(&state); #else float4 rand = curand_uniform4(&state); #endif - T dest_vec[VecSize]; - MaskType mask_vec[VecSize]; + LoadT dst_val; + MaskLoadT mask_val; #pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - if ((&rand.x)[ii] < dropout_prob) { - dest_vec[ii] = 0; - mask_vec[ii] = 0; + for (int j = 0; j < VecSize; j++) { + if ((&rand.x)[j] < dropout_prob) { + dst_val[j] = 0; + mask_val[j] = 0; } else { - if (is_upscale_in_train) { - dest_vec[ii] = src_vec[ii] * factor; - } else { - dest_vec[ii] = src_vec[ii]; - } - mask_vec[ii] = 1; + dst_val[j] = is_upscale_in_train ? src_val[j] * factor : src_val[j]; + mask_val[j] = 1; } } - *(reinterpret_cast(&dst[i])) = - *reinterpret_cast(&dest_vec[0]); - *(reinterpret_cast(&mask_data[i])) = - *reinterpret_cast(&mask_vec[0]); + platform::Store(dst_val, &dst[i]); + platform::Store(mask_val, &mask[i]); } } @@ -185,7 +175,7 @@ class GPUDropoutKernel : public framework::OpKernel { // same as the previous calls. uint64_t seed_data; uint64_t increment; - int vec_size = VectorizedSize(x_data); + int vec_size = platform::GetVectorizedSize(x_data); auto offset = ((x_numel - 1) / (config.block_per_grid.x * config.thread_per_block.x * vec_size) + 1) * diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h index 997a7d835aa37..96e6725212cc6 100644 --- a/paddle/fluid/operators/dropout_op.h +++ b/paddle/fluid/operators/dropout_op.h @@ -21,54 +21,36 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/gpu_launch_config.h" namespace paddle { namespace operators { -// aligned vector generates vectorized load/store on CUDA -template -struct alignas(sizeof(T) * Size) AlignedVector { - T val[Size]; -}; - -template -inline int VectorizedSize(const T* pointer) { - uint64_t address = reinterpret_cast(pointer); - constexpr int vec4 = std::alignment_of>::value; // NOLINT - if (address % vec4 == 0) { - return 4; - } - return 1; -} - #if defined(__NVCC__) || defined(__HIPCC__) template __global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask, const T factor, const int64_t size, T* dx) { - int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - - using LoadT = AlignedVector; - using MaskLoadT = AlignedVector; + using LoadT = platform::AlignedVector; + using MaskLoadT = platform::AlignedVector; + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) { - T dout_vec[VecSize]; - LoadT* dout_value = reinterpret_cast(&dout_vec); - *dout_value = *reinterpret_cast(&dout[i]); + LoadT dout_val; + platform::Load(&dout[i], &dout_val); - MaskType mask_vec[VecSize]; - MaskLoadT* mask_value = reinterpret_cast(&mask_vec); - *mask_value = *reinterpret_cast(&mask[i]); + MaskLoadT mask_val; + platform::Load(&mask[i], &mask_val); - T dx_vec[VecSize]; + LoadT dx_val; #pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - dx_vec[ii] = dout_vec[ii] * static_cast(mask_vec[ii]) * factor; + for (int j = 0; j < VecSize; j++) { + dx_val[j] = dout_val[j] * static_cast(mask_val[j]) * factor; } - *(reinterpret_cast(&dx[i])) = *reinterpret_cast(&dx_vec[0]); + platform::Store(dx_val, &dx[i]); } } #endif @@ -187,7 +169,7 @@ class DropoutGradKernel : public framework::OpKernel { if (dropout_prob == 1.0f) { dX.device(place) = static_cast(0) * dY; } else { - int vec_size = VectorizedSize(grad_y->data()); + int vec_size = platform::GetVectorizedSize(grad_y->data()); if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 && size % 4 == 0) { #if defined(__NVCC__) || defined(__HIPCC__) diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index e09f94a6c0fee..d6cf58f7a157f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -85,6 +85,14 @@ class ElementwiseOp : public framework::OperatorWithKernel { auto y_dims = ctx->GetInputDim("Y"); int max_dim = std::max(x_dims.size(), y_dims.size()); int axis = ctx->Attrs().Get("axis"); + if (x_dims.size() == y_dims.size()) { + PADDLE_ENFORCE_EQ((axis == -1) || (axis == 0), true, + platform::errors::InvalidArgument( + "axis should be -1 or 0 while the dimension of " + "tensor X (%s) is equal to the dimension of " + "tensor Y (%s), but received axis: %s", + x_dims.size(), y_dims.size(), axis)); + } PADDLE_ENFORCE_EQ((axis >= (-1 * max_dim)) && (axis < max_dim), true, platform::errors::InvalidArgument( "The axis range must be [%s, %s), but axis is %s. 
" diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h index 95dc6ed342ffc..17cf7c762def2 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h @@ -199,8 +199,8 @@ struct StridesCalculation { template struct BroadcastArgsWrapper { - using InVecType = platform::CudaAlignedVector; - using OutVecType = platform::CudaAlignedVector; + using InVecType = platform::AlignedVector; + using OutVecType = platform::AlignedVector; OutT *out_data; OutVecType *vec_out_data; @@ -320,7 +320,7 @@ template __device__ inline void VectorizedBroadcastKernelImpl( BroadcastArgsWrapper broadcast_wrapper, int tid) { - using OutVecType = platform::CudaAlignedVector; + using OutVecType = platform::AlignedVector; OutVecType args_out; InT ins[ET]; InT args[ET][VecSize]; diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h index 3bd746ace0610..1b680cfc995a5 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h @@ -69,8 +69,8 @@ int GetVectorizedSizeForIO(const std::vector &ins, template struct ElementwiseDataWrapper { - using InVecType = platform::CudaAlignedVector; - using OutVecType = platform::CudaAlignedVector; + using InVecType = platform::AlignedVector; + using OutVecType = platform::AlignedVector; const InT *__restrict__ in_data[ET]; OutT *out_data; @@ -117,8 +117,8 @@ template __device__ inline void VectorizedKernelImpl(ElementwiseWrapper data, Functor func, int tid) { - using InVecType = platform::CudaAlignedVector; - using OutVecType = platform::CudaAlignedVector; + using InVecType = platform::AlignedVector; + using OutVecType = platform::AlignedVector; InVecType ins_vec[ET]; OutVecType out_vec; InT *ins_ptr[ET]; diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_pow_op_npu.cc index e0763d769f047..85b247781a40d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_pow_op_npu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op_npu.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/operators/elementwise/elementwise_npu.h" #include "paddle/fluid/operators/elementwise/elementwise_pow_op.h" #include "paddle/fluid/operators/npu_op_runner.h" @@ -27,21 +28,198 @@ template class ElementwisePowNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = + ctx.template device_context(); + auto* x = ctx.Input("X"); auto* y = ctx.Input("Y"); - auto* out = ctx.Output("Out"); auto place = ctx.GetPlace(); + int axis = ctx.Attr("axis"); out->mutable_data(place); - auto stream = - ctx.template device_context() - .stream(); + bool direct_compute = false; + auto x_dims = x->dims(); + auto y_dims = y->dims(); + axis = + (axis < 0 ? 
std::abs(x_dims.size() - y_dims.size()) + axis + 1 : axis); + if (x_dims.size() >= y_dims.size()) { + direct_compute = + y_dims == framework::slice_ddim(x_dims, axis, x_dims.size()); + } else { + direct_compute = + x_dims == framework::slice_ddim(y_dims, axis, y_dims.size()); + } + + auto stream = dev_ctx.stream(); + + if (direct_compute) { + const auto& runner = NpuOpRunner("Pow", {*x, *y}, {*out}, {}); + runner.Run(stream); + } else { + Tensor transformed_x, transformed_y; + NpuElementWiseOpBroadcast(dev_ctx, x, y, axis, &transformed_x, + &transformed_y); + const auto& runner = + NpuOpRunner("Pow", {transformed_x, transformed_y}, {*out}, {}); + runner.Run(stream); + } + } +}; - const auto& runner = NpuOpRunner("Pow", {*x, *y}, {*out}, {}); - runner.Run(stream); +template +class ElementwisePowGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = + ctx.template device_context(); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + int axis = ctx.Attr("axis"); + auto place = ctx.GetPlace(); + + auto x_dims = x->dims(); + auto y_dims = y->dims(); + axis = + (axis < 0 ? std::abs(x_dims.size() - y_dims.size()) + axis + 1 : axis); + Tensor transformed_x, transformed_y; + NpuElementWiseOpBroadcast(dev_ctx, x, y, axis, &transformed_x, + &transformed_y); + + auto dout_dims = dout->dims(); + auto stream = dev_ctx.stream(); + // Reshape info vector. + std::vector reduce_axes; + if (dx) { + Tensor zero_tensor(dout->type()); + zero_tensor.mutable_data(dout_dims, place); + FillNpuTensorWithConstant(&zero_tensor, static_cast(0)); + + dx->mutable_data(place); + Tensor tmp_dx; + tmp_dx.mutable_data(dout_dims, place); + + // dx = dout * y * pow(x, y - 1); + Tensor PowGrad_dx_temp1(dout->type()); + PowGrad_dx_temp1.mutable_data(dout->dims(), place); + const auto& runner_PowGrad_dx_temp1 = + NpuOpRunner("Mul", {*dout, transformed_y}, {PowGrad_dx_temp1}, {}); + runner_PowGrad_dx_temp1.Run(stream); + + Tensor one_dx(transformed_y.type()); + one_dx.mutable_data(transformed_y.dims(), place); + const auto& runner_one_dx = + NpuOpRunner("OnesLike", {transformed_y}, {one_dx}, {}); + runner_one_dx.Run(stream); + + Tensor sub_dx(transformed_y.type()); + sub_dx.mutable_data(transformed_y.dims(), place); + const auto& runner_sub_dx = + NpuOpRunner("Sub", {transformed_y, one_dx}, {sub_dx}, {}); + runner_sub_dx.Run(stream); + + Tensor PowGrad_dx_temp2(transformed_x.type()); + PowGrad_dx_temp2.mutable_data(transformed_x.dims(), place); + const auto& runner_PowGrad_dx_temp2 = + NpuOpRunner("Pow", {transformed_x, sub_dx}, {PowGrad_dx_temp2}, {}); + runner_PowGrad_dx_temp2.Run(stream); + + const auto& runner_dx = NpuOpRunner( + "Mul", {PowGrad_dx_temp1, PowGrad_dx_temp2}, {tmp_dx}, {}); + runner_dx.Run(stream); + + if (x_dims != dout_dims) { + reduce_axes.clear(); + + int src_axis = (x_dims.size() < dout_dims.size() ? 
axis : 0); + for (int ax = 0; ax < dout_dims.size(); ++ax) { + if ((ax < src_axis || ax >= src_axis + x_dims.size()) || + (dout_dims[ax] > 1 && x_dims[ax - src_axis] == 1)) { + reduce_axes.push_back(ax); + } + } + if (!reduce_axes.empty()) { + const auto& runner = + NpuOpRunner("ReduceSumD", {tmp_dx}, {*dx}, + {{"axes", reduce_axes}, {"keep_dims", false}}); + runner.Run(stream); + } + } else { + framework::TensorCopy(tmp_dx, place, dev_ctx, dx); + } + } + if (dy) { + Tensor zero_tensor(dout->type()); + zero_tensor.mutable_data(dout_dims, place); + FillNpuTensorWithConstant(&zero_tensor, static_cast(0)); + + dy->mutable_data(place); + Tensor tmp_dy; + tmp_dy.mutable_data(dout_dims, place); + + // dy = dout * log(x) * pow(x, y) + Tensor PowGrad_dy_temp1(transformed_x.type()); + PowGrad_dy_temp1.mutable_data(transformed_x.dims(), place); + const auto& runner_PowGrad_dy_temp1 = NpuOpRunner( + "Pow", {transformed_x, transformed_y}, {PowGrad_dy_temp1}, {}); + runner_PowGrad_dy_temp1.Run(stream); + + Tensor one_dy(transformed_x.type()); + one_dy.mutable_data(transformed_x.dims(), place); + const auto& runner_one_dy = + NpuOpRunner("OnesLike", {transformed_x}, {one_dy}, {}); + runner_one_dy.Run(stream); + + Tensor sub_dy(transformed_x.type()); + sub_dy.mutable_data(transformed_x.dims(), place); + const auto& runner_sub_dy = + NpuOpRunner("Sub", {transformed_x, one_dy}, {sub_dy}, {}); + runner_sub_dy.Run(stream); + + Tensor log_dy(transformed_x.type()); + log_dy.mutable_data(transformed_x.dims(), place); + const auto& runner_log_dy = NpuOpRunner("Log1p", {sub_dy}, {log_dy}, {}); + runner_log_dy.Run(stream); + + Tensor PowGrad_dy_temp2(transformed_x.type()); + PowGrad_dy_temp2.mutable_data(transformed_x.dims(), place); + const auto& runner_PowGrad_dy_temp2 = NpuOpRunner( + "Mul", {log_dy, PowGrad_dy_temp1}, {PowGrad_dy_temp2}, {}); + runner_PowGrad_dy_temp2.Run(stream); + + const auto& runner_dy = + NpuOpRunner("Mul", {*dout, PowGrad_dy_temp2}, {tmp_dy}, {}); + runner_dy.Run(stream); + + if (y_dims != dout_dims) { + reduce_axes.clear(); + + int src_axis = (y_dims.size() < dout_dims.size() ? 
axis : 0); + for (int ax = 0; ax < dout_dims.size(); ++ax) { + if ((ax < src_axis || ax >= src_axis + y_dims.size()) || + (dout_dims[ax] > 1 && y_dims[ax - src_axis] == 1)) { + reduce_axes.push_back(ax); + } + } + if (!reduce_axes.empty()) { + const auto& runner = + NpuOpRunner("ReduceSumD", {tmp_dy}, {*dy}, + {{"axes", reduce_axes}, {"keep_dims", false}}); + runner.Run(stream); + } + } else { + framework::TensorCopy(tmp_dy, place, dev_ctx, dy); + } + } + if (!dx && !dy) { + PADDLE_THROW(platform::errors::Unavailable( + "Not support all outputs to be empty.")); + } } }; @@ -49,9 +227,18 @@ class ElementwisePowNPUKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_NPU_KERNEL( elementwise_pow, - ops::ElementwisePowNPUKernel, - ops::ElementwisePowNPUKernel); + ops::ElementwisePowNPUKernel, + ops::ElementwisePowNPUKernel, + ops::ElementwisePowNPUKernel, + ops::ElementwisePowNPUKernel); + +REGISTER_OP_NPU_KERNEL( + elementwise_pow_grad, + ops::ElementwisePowGradNPUKernel, + ops::ElementwisePowGradNPUKernel, + ops::ElementwisePowGradNPUKernel, + ops::ElementwisePowGradNPUKernel); diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index c94ce4174f2be..778bab9f4dd26 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -55,9 +55,17 @@ class FlattenOp : public framework::OperatorWithKernel { int64_t outer = 1, inner = 1; for (int i = 0; i < in_dims.size(); ++i) { if (i < axis) { - outer *= in_dims[i]; + if (in_dims[i] == -1 || outer == -1) { + outer = -1; + } else { + outer *= in_dims[i]; + } } else { - inner *= in_dims[i]; + if (in_dims[i] == -1 || inner == -1) { + inner = -1; + } else { + inner *= in_dims[i]; + } } } std::vector out_shape(2); @@ -296,7 +304,11 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel { out_shape.push_back(in_dims[i]); } for (int i = start_axis; i <= stop_axis; i++) { - outer *= in_dims[i]; + if (in_dims[i] == -1 || outer == -1) { + outer = -1; + } else { + outer *= in_dims[i]; + } } out_shape.push_back(outer); for (int i = stop_axis + 1; i < in_dims_size; i++) { diff --git a/paddle/fluid/operators/fused/attn_bias_add.cu.h b/paddle/fluid/operators/fused/attn_bias_add.cu.h index 2e98a7f332d16..ddefd8964af97 100644 --- a/paddle/fluid/operators/fused/attn_bias_add.cu.h +++ b/paddle/fluid/operators/fused/attn_bias_add.cu.h @@ -96,36 +96,13 @@ __global__ void BroadcastKernelBinary( kernel_primitives::WriteData(out + fix, result, num); } -template -int GetVectorizedSizeImpl(const T* pointer) { - constexpr int max_load_bits = 128; - int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T); - uint64_t address = reinterpret_cast(pointer); - constexpr int vec8 = - std::alignment_of>::value; // NOLINT - constexpr int vec4 = - std::alignment_of>::value; // NOLINT - constexpr int vec2 = - std::alignment_of>::value; // NOLINT - if (address % vec8 == 0) { - // Note: this line can change from 4 to 8 if it can improve the performance. 
- return std::min(4, valid_vec_size); - } else if (address % vec4 == 0) { - return std::min(4, valid_vec_size); - } else if (address % vec2 == 0) { - return std::min(2, valid_vec_size); - } else { - return 1; - } -} - // bias add forward impl for "[m, n] + [n] = [m, n]" template void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n, const T* in0, const T* in1, T* out) { - int in_vec_size = - std::min(GetVectorizedSizeImpl(in0), GetVectorizedSizeImpl(in1)); - int out_vec_size = std::min(4, GetVectorizedSizeImpl(out)); + int in_vec_size = std::min(platform::GetVectorizedSize(in0), + platform::GetVectorizedSize(in1)); + int out_vec_size = std::min(4, platform::GetVectorizedSize(out)); int vec_size = std::min(out_vec_size, in_vec_size); int numel = m * n; diff --git a/paddle/fluid/operators/gather_nd_op_npu.cc b/paddle/fluid/operators/gather_nd_op_npu.cc new file mode 100644 index 0000000000000..d04e0bce36fab --- /dev/null +++ b/paddle/fluid/operators/gather_nd_op_npu.cc @@ -0,0 +1,120 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/gather_nd_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +template +class GatherNdNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *x = ctx.Input("X"); + auto *index = ctx.Input("Index"); + auto *out = ctx.Output("Out"); + + out->template mutable_data(ctx.GetPlace()); + + if (x->numel() == 0) return; + + if (index->numel() == 0) { + framework::TensorCopy(*x, ctx.GetPlace(), ctx.device_context(), out); + return; + } + + const auto &index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT32 || + index_type == framework::proto::VarType::INT64; + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Index holds the wrong type, it holds [%s]," + "but desires to be [%s] or [%s]", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + + const auto &runner = NpuOpRunner("GatherNd", {*x, *index}, {*out}, {}); + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + } +}; + +template +class GatherNdGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *index = ctx.Input("Index"); + auto *x = ctx.Input("X"); + auto *dout = ctx.Input(framework::GradVarName("Out")); + auto *dx = ctx.Output(framework::GradVarName("X")); + auto *p = dx->mutable_data(ctx.GetPlace()); + + if (dx->numel() == 0) return; + + if (index->numel() == 0) { + framework::TensorCopy(*dout, ctx.GetPlace(), ctx.device_context(), dx); + return; + } + + framework::Tensor tmp_tensor(index->type()); + framework::Tensor tmp_tensor2(dout->type()); + 
const auto index_dims = index->dims(); + if (index_dims.size() == 1) { + tmp_tensor.ShareDataWith(*index); + std::vector new_dim = {1, index_dims[0]}; + tmp_tensor.Resize(framework::make_ddim(new_dim)); + index = &tmp_tensor; + + tmp_tensor2.ShareDataWith(*dout); + std::vector new_dim2{1}; + for (int i = index->numel(); i < x->dims().size(); i++) { + new_dim2.push_back(x->dims()[i]); + } + tmp_tensor2.Resize(framework::make_ddim(new_dim2)); + dout = &tmp_tensor2; + } + + auto stream = + ctx.template device_context() + .stream(); + + platform::NPUMemsetAsync(static_cast(p), 0, dx->numel() * sizeof(T), + stream); + + const auto &runner_scatter = NpuOpRunner( + "ScatterNdAdd", {*dx, *index, *dout}, {*dx}, {{"use_locking", false}}); + runner_scatter.Run(stream); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_NPU_KERNEL( + gather_nd, ops::GatherNdNPUKernel, + ops::GatherNdNPUKernel); + +REGISTER_OP_NPU_KERNEL( + gather_nd_grad, + ops::GatherNdGradNPUKernel, + ops::GatherNdGradNPUKernel); diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index ea28c204ec9cf..d35b066be85e7 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -138,7 +138,8 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker { "In backward process, calc the grad when has same index," "If true, update the grad using the overwrite mode in same index," "If false, using the accumulate mode in same index.") - .SetDefault(true); + .SetDefault(true) + .AsExtra(); AddAttr( "axis", "The Tensor which contains the axis that we do gather operation.") diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc index a75ea538f2556..0b410f07fcb57 100644 --- a/paddle/fluid/operators/grid_sampler_op.cc +++ b/paddle/fluid/operators/grid_sampler_op.cc @@ -100,7 +100,8 @@ class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr( "use_cudnn", "(bool, default true) Only used in cudnn kernel, need install cudnn") - .SetDefault(true); + .SetDefault(true) + .AsExtra(); AddAttr( "align_corners", diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index dc82e4fa754eb..a2d61695649dc 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -33,13 +33,15 @@ class GRUOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "GRU"); OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "GRU"); - OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU"); - OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), "Output", - "BatchResetHiddenPrev", "GRU"); - OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden", - "GRU"); OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "GRU"); - + bool is_test = ctx->Attrs().Get("is_test"); + if (!is_test) { + OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "GRU"); + OP_INOUT_CHECK(ctx->HasOutput("BatchResetHiddenPrev"), "Output", + "BatchResetHiddenPrev", "GRU"); + OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden", + "GRU"); + } auto input_dims = ctx->GetInputDim("Input"); auto weight_dims = ctx->GetInputDim("Weight"); int input_size = input_dims[1]; @@ -84,9 +86,11 @@ class GRUOp : public framework::OperatorWithKernel { "[%d, %d] (Bias) vs [1, %d] (frame_size * 3).", 
bias_height, bias_width, frame_size * 3)); } - ctx->SetOutputDim("BatchGate", input_dims); - ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size}); - ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size}); + if (!is_test) { + ctx->SetOutputDim("BatchGate", input_dims); + ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size}); + ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size}); + } ctx->SetOutputDim("Hidden", {input_dims[0], frame_size}); ctx->ShareLoD("Input", "Hidden"); } @@ -124,19 +128,22 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker { "organized in batches. The LoD size is 2. The first LoD contains " "the batch offsets and the second LoD contains the indexes in " "the raw sequence data.") - .AsIntermediate(); + .AsIntermediate() + .AsExtra(); AddOutput( "BatchResetHiddenPrev", "(LoDTensor) The reset hidden state LoDTensor organized in batches. " "This LoDTensor is a matrix with shape (T X D) and has the same LoD " "with `BatchGate`.") - .AsIntermediate(); + .AsIntermediate() + .AsExtra(); AddOutput( "BatchHidden", "(LoDTensor) The hidden state LoDTensor organized in batches. " "This LoDTensor is a matrix with shape (T X D) and has the same LoD " "with `BatchGate`.") - .AsIntermediate(); + .AsIntermediate() + .AsExtra(); AddOutput( "Hidden", "(LoDTensor) the hidden state LoDTensor organized in sequences. " @@ -155,6 +162,9 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default: False) " "whether to compute reversed GRU.") .SetDefault(false); + AddAttr("is_test", "True if in test phase.") + .SetDefault(false) + .AsExtra(); AddAttr("origin_mode", "bool" "use origin mode in article https://arxiv.org/abs/1412.3555") @@ -269,24 +279,42 @@ class GRUCPUKernel : public framework::OpKernel { public: void BatchCompute(const framework::ExecutionContext& context) const { using DeviceContext = paddle::platform::CPUDeviceContext; + using LodTensorPtr = LoDTensor*; + bool is_test = context.Attr("is_test"); + bool origin_mode = context.Attr("origin_mode"); auto* input = context.Input("Input"); auto* h0 = context.Input("H0"); auto* weight = context.Input("Weight"); const T* weight_data = weight->data(); auto* bias = context.Input("Bias"); - auto* batch_gate = context.Output("BatchGate"); - batch_gate->mutable_data(context.GetPlace()); - auto* batch_reset_hidden_prev = - context.Output("BatchResetHiddenPrev"); - batch_reset_hidden_prev->mutable_data(context.GetPlace()); - auto* batch_hidden = context.Output("BatchHidden"); - batch_hidden->mutable_data(context.GetPlace()); auto* hidden = context.Output("Hidden"); hidden->mutable_data(context.GetPlace()); + auto input_dims = input->dims(); auto hidden_dims = hidden->dims(); + LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden; + LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp; + if (is_test) { + batch_gate = &batch_gate_tmp; + batch_gate->Resize(input_dims); + + batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp; + batch_reset_hidden_prev->Resize(hidden_dims); + + batch_hidden = &batch_hidden_tmp; + batch_hidden->Resize(hidden_dims); + } else { + batch_gate = context.Output("BatchGate"); + batch_hidden = context.Output("BatchHidden"); + batch_reset_hidden_prev = + context.Output("BatchResetHiddenPrev"); + } + batch_gate->mutable_data(context.GetPlace()); + batch_reset_hidden_prev->mutable_data(context.GetPlace()); + batch_hidden->mutable_data(context.GetPlace()); + bool is_reverse = context.Attr("is_reverse"); 
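+    // When is_test is true, the batch_* tensors above are function-local
+    // workspaces sized from input_dims/hidden_dims, so an inference program
+    // does not need to allocate BatchGate, BatchResetHiddenPrev or
+    // BatchHidden as op outputs; training still writes them as outputs
+    // because the backward kernel reads them.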
math::LoDTensor2BatchFunctor to_batch; auto& dev_ctx = context.template device_context(); diff --git a/paddle/fluid/operators/gru_op.cu.cc b/paddle/fluid/operators/gru_op.cu.cc index bdc5debaea790..edd7f8a7cf553 100644 --- a/paddle/fluid/operators/gru_op.cu.cc +++ b/paddle/fluid/operators/gru_op.cu.cc @@ -28,24 +28,42 @@ template class GRUKernel : public framework::OpKernel { public: void BatchCompute(const framework::ExecutionContext& context) const { + using LodTensorPtr = LoDTensor*; + + bool is_test = context.Attr("is_test"); bool origin_mode = context.Attr("origin_mode"); auto* input = context.Input("Input"); auto* h0 = context.Input("H0"); auto* weight = context.Input("Weight"); const T* weight_data = weight->data(); auto* bias = context.Input("Bias"); - auto* batch_gate = context.Output("BatchGate"); - batch_gate->mutable_data(context.GetPlace()); - auto* batch_reset_hidden_prev = - context.Output("BatchResetHiddenPrev"); - batch_reset_hidden_prev->mutable_data(context.GetPlace()); - auto* batch_hidden = context.Output("BatchHidden"); - batch_hidden->mutable_data(context.GetPlace()); auto* hidden = context.Output("Hidden"); hidden->mutable_data(context.GetPlace()); + auto input_dims = input->dims(); auto hidden_dims = hidden->dims(); + LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden; + LoDTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp; + if (is_test) { + batch_gate = &batch_gate_tmp; + batch_gate->Resize(input_dims); + + batch_reset_hidden_prev = &batch_reset_hidden_prev_tmp; + batch_reset_hidden_prev->Resize(hidden_dims); + + batch_hidden = &batch_hidden_tmp; + batch_hidden->Resize(hidden_dims); + } else { + batch_gate = context.Output("BatchGate"); + batch_hidden = context.Output("BatchHidden"); + batch_reset_hidden_prev = + context.Output("BatchResetHiddenPrev"); + } + batch_gate->mutable_data(context.GetPlace()); + batch_reset_hidden_prev->mutable_data(context.GetPlace()); + batch_hidden->mutable_data(context.GetPlace()); + bool is_reverse = context.Attr("is_reverse"); math::LoDTensor2BatchFunctor to_batch; auto& dev_ctx = context.template device_context(); diff --git a/paddle/fluid/operators/label_smooth_op_npu.cc b/paddle/fluid/operators/label_smooth_op_npu.cc new file mode 100644 index 0000000000000..a20b7f06d794e --- /dev/null +++ b/paddle/fluid/operators/label_smooth_op_npu.cc @@ -0,0 +1,108 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
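+
+// The kernel below builds label smoothing from elementwise CANN ops:
+//   with PriorDist:    Out = (1 - epsilon) * X + epsilon * PriorDist, where
+//                      PriorDist is viewed as {1, label_dim} and added with
+//                      the broadcasting AddV2 op;
+//   without PriorDist: Out = (1 - epsilon) * X + epsilon / label_dim.
+// Muls/Adds apply the scalar scale and shift on the NPU stream, and the grad
+// kernel only needs dX = (1 - epsilon) * dOut.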
+ +#include "paddle/fluid/operators/label_smooth_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +void LabelSmoothMuls(const platform::Place& place, const aclrtStream& stream, + const Tensor* in, float val, Tensor* out) { + out->mutable_data(in->dims(), place); + const auto& runner = NpuOpRunner("Muls", {*in}, {*out}, {{"value", val}}); + runner.Run(stream); +} + +template +void LabelSmoothAdds(const platform::Place& place, const aclrtStream& stream, + const Tensor* in, float val, Tensor* out) { + out->mutable_data(in->dims(), place); + const auto& runner = NpuOpRunner("Adds", {*in}, {*out}, {{"value", val}}); + runner.Run(stream); +} + +template +void LabelSmoothAddBroadCast(const platform::Place& place, + const aclrtStream& stream, const Tensor* in1, + const Tensor* in2, Tensor* out) { + out->mutable_data(place); + const auto& runner = NpuOpRunner("AddV2", {*in1, *in2}, {*out}, {}); + runner.Run(stream); +} + +template +class LabelSmoothNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out_t = ctx.Output("Out"); + auto* in_t = ctx.Input("X"); + auto* dist_t = ctx.Input("PriorDist"); + auto epsilon = ctx.Attr("epsilon"); + + auto label_dim = in_t->dims()[in_t->dims().size() - 1]; + auto place = ctx.GetPlace(); + + auto stream = + ctx.template device_context() + .stream(); + + if (dist_t) { + Tensor tmp; + Tensor dist; + Tensor tmp2; + LabelSmoothMuls(place, stream, in_t, (1 - epsilon), &tmp); + LabelSmoothMuls(place, stream, dist_t, epsilon, &tmp2); + tmp2.Resize({1, label_dim}); + LabelSmoothAddBroadCast(place, stream, &tmp, &tmp2, out_t); + } else { + Tensor tmp; + LabelSmoothMuls(place, stream, in_t, (1 - epsilon), &tmp); + LabelSmoothAdds(place, stream, &tmp, (epsilon / label_dim), out_t); + } + } +}; + +template +class LabelSmoothGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* d_out_t = ctx.Input(framework::GradVarName("Out")); + auto* d_in_t = ctx.Output(framework::GradVarName("X")); + auto epsilon = ctx.Attr("epsilon"); + + auto place = ctx.GetPlace(); + + auto stream = + ctx.template device_context() + .stream(); + + LabelSmoothMuls(place, stream, d_out_t, 1 - epsilon, d_in_t); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_NPU_KERNEL(label_smooth, ops::LabelSmoothNPUKernel, + ops::LabelSmoothNPUKernel); +REGISTER_OP_NPU_KERNEL(label_smooth_grad, ops::LabelSmoothGradNPUKernel, + ops::LabelSmoothGradNPUKernel); diff --git a/paddle/fluid/operators/label_smooth_op_xpu.cc b/paddle/fluid/operators/label_smooth_op_xpu.cc new file mode 100644 index 0000000000000..6b6350753909f --- /dev/null +++ b/paddle/fluid/operators/label_smooth_op_xpu.cc @@ -0,0 +1,57 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/operators/label_smooth_op.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class LabelSmoothXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out_t = ctx.Output("Out"); + auto* in_t = ctx.Input("X"); + auto* dist_t = ctx.Input("PriorDist"); + auto label_dim = in_t->dims()[in_t->dims().size() - 1]; + auto ptr = out_t->mutable_data(ctx.GetPlace()); + + auto epsilon = ctx.Attr("epsilon"); + auto& dev_ctx = ctx.template device_context(); + if (dist_t) { + PADDLE_THROW( + platform::errors::External("XPU doesn't support dist label smooth")); + } else { + int r = xpu::label_smooth(dev_ctx.x_context(), in_t->data(), ptr, + in_t->numel(), epsilon, label_dim); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU API(label_smooth) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_XPU_KERNEL( + label_smooth, + ops::LabelSmoothXPUKernel); +#endif diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc index 42048ff373368..444478c2eadab 100644 --- a/paddle/fluid/operators/layer_norm_op.cc +++ b/paddle/fluid/operators/layer_norm_op.cc @@ -179,16 +179,19 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker { }); AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false); + .SetDefault(false) + .AsExtra(); AddAttr( "mkldnn_data_type", "(string, default \"float32\"). Data type of mkldnn kernel") .SetDefault("float32") - .InEnum({"float32", "bfloat16"}); + .InEnum({"float32", "bfloat16"}) + .AsExtra(); AddAttr("is_test", "(bool, default false) Set to true for inference only, false " "for training. Some layers may run faster when this is true.") - .SetDefault(false); + .SetDefault(false) + .AsExtra(); AddComment(R"DOC( Assume feature vectors exist on dimensions diff --git a/paddle/fluid/operators/log_softmax_op_npu.cc b/paddle/fluid/operators/log_softmax_op_npu.cc new file mode 100644 index 0000000000000..d955bef6ce2ac --- /dev/null +++ b/paddle/fluid/operators/log_softmax_op_npu.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
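+
+// log_softmax is delegated to the single CANN op "LogSoftmaxV2". The Paddle
+// "axis" attribute is first canonicalized to a non-negative index and then
+// passed through the "axes" list attribute that the NPU runner expects.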
+ +#include "paddle/fluid/operators/log_softmax_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" +namespace paddle { +namespace operators { +template +class LogSoftmaxNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* X = ctx.Input("X"); + auto* Out = ctx.Output("Out"); + const int rank = X->dims().size(); + const int axis = CanonicalAxis(ctx.Attr("axis"), rank); + std::vector axes; + axes.push_back(axis); + framework::NPUAttributeMap attr_input = {{"axes", axes}}; + Out->mutable_data(ctx.GetPlace()); + const auto& runner = NpuOpRunner("LogSoftmaxV2", {*X}, {*Out}, attr_input); + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + } +}; +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_NPU_KERNEL( + log_softmax, + ops::LogSoftmaxNPUKernel); diff --git a/paddle/fluid/operators/lookup_table_v2_op.cc b/paddle/fluid/operators/lookup_table_v2_op.cc index f1bb9a985f4c1..2bda2eb889123 100644 --- a/paddle/fluid/operators/lookup_table_v2_op.cc +++ b/paddle/fluid/operators/lookup_table_v2_op.cc @@ -81,10 +81,12 @@ class LookupTableV2OpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("is_sparse", "(boolean, default false) " "Sparse update.") - .SetDefault(false); + .SetDefault(false) + .AsExtra(); AddAttr("is_distributed", "(boolean, default false) distributed lookup table.") - .SetDefault(false); + .SetDefault(false) + .AsExtra(); AddAttr("padding_idx", "(int64, default -1) " "If the value is -1, it makes no effect to lookup. " @@ -93,22 +95,27 @@ class LookupTableV2OpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(kNoPadding); // for parameter prefetch - AddAttr("remote_prefetch", "").SetDefault(false); - AddAttr("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); + AddAttr("remote_prefetch", "").SetDefault(false).AsExtra(); + AddAttr("trainer_id", "trainer id from 0 ~ worker_num.") + .SetDefault(0) + .AsExtra(); AddAttr>("height_sections", "Height for each output SelectedRows.") - .SetDefault(std::vector({})); + .SetDefault(std::vector({})) + .AsExtra(); AddAttr>( "epmap", "(string vector, default 127.0.0.1:6164)" "Server endpoints in the order of input variables for mapping") - .SetDefault({}); + .SetDefault({}) + .AsExtra(); AddAttr>( "table_names", "(string vector, the split table names that will be fetched from " "parameter server)" "in the order of input variables for mapping") - .SetDefault({}); + .SetDefault({}) + .AsExtra(); AddComment(R"DOC( Lookup Table V2 Operator. 
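The .AsExtra() tagging above, and in the other op makers touched by this patch, separates attributes that only steer kernel dispatch or distributed training from the operator's mathematical definition. A minimal sketch of the pattern, using a hypothetical MyOpMaker purely for illustration:

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class MyOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input tensor.");
    AddOutput("Out", "(Tensor) Output tensor.");
    // Dispatch-only hints are tagged AsExtra(): they do not change the
    // operator's math, so they can live outside its core definition.
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false)
        .AsExtra();
    AddComment("Illustrative operator comment.");
  }
};

}  // namespace operators
}  // namespace paddle

The tagging is purely declarative; kernels keep reading such attributes with ctx.Attr as before.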
diff --git a/paddle/fluid/operators/lstm_op.cc b/paddle/fluid/operators/lstm_op.cc index 2c9669cbd6549..0405578f5dc1e 100644 --- a/paddle/fluid/operators/lstm_op.cc +++ b/paddle/fluid/operators/lstm_op.cc @@ -30,10 +30,15 @@ class LSTMOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "LSTM"); OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTM"); - OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTM"); - OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output", - "BatchCellPreAct", "LSTM"); + bool is_test = ctx->Attrs().Get("is_test"); + + if (!is_test) { + OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", + "LSTM"); + OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output", + "BatchCellPreAct", "LSTM"); + } auto in_dims = ctx->GetInputDim("Input"); PADDLE_ENFORCE_EQ( in_dims.size(), 2, @@ -103,8 +108,10 @@ class LSTMOp : public framework::OperatorWithKernel { framework::DDim out_dims({in_dims[0], frame_size}); ctx->SetOutputDim("Hidden", out_dims); ctx->SetOutputDim("Cell", out_dims); - ctx->SetOutputDim("BatchGate", in_dims); - ctx->SetOutputDim("BatchCellPreAct", out_dims); + if (!is_test) { + ctx->SetOutputDim("BatchGate", in_dims); + ctx->SetOutputDim("BatchCellPreAct", out_dims); + } ctx->ShareLoD("Input", "Hidden"); ctx->ShareLoD("Input", "Cell"); } @@ -164,11 +171,13 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { "LoD is the batch offsets and the second LoD contains the " "indexes, which denote the position of reorganized sequence " "in the raw input.") - .AsIntermediate(); + .AsIntermediate() + .AsExtra(); AddOutput("BatchCellPreAct", "(LoDTensor) This LoDTensor is obtained in the forward and used " "in the backward.") - .AsIntermediate(); + .AsIntermediate() + .AsExtra(); AddAttr("use_peepholes", "(bool, default: True) " "whether to enable diagonal/peephole connections.") @@ -177,6 +186,9 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default: False) " "whether to compute reversed LSTM.") .SetDefault(false); + AddAttr("is_test", "True if in test phase.") + .SetDefault(false) + .AsExtra(); AddAttr( "gate_activation", "(string, default: sigmoid)" diff --git a/paddle/fluid/operators/lstm_op.h b/paddle/fluid/operators/lstm_op.h index a4434283abb6f..c6f43b949a736 100644 --- a/paddle/fluid/operators/lstm_op.h +++ b/paddle/fluid/operators/lstm_op.h @@ -40,6 +40,8 @@ template class LSTMKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + bool is_test = ctx.Attr("is_test"); + auto* input = ctx.Input("Input"); auto* weight = ctx.Input("Weight"); auto* bias = ctx.Input("Bias"); @@ -47,7 +49,14 @@ class LSTMKernel : public framework::OpKernel { auto* hidden_t0 = ctx.Input("H0"); auto* cell_t0 = ctx.Input("C0"); - auto* batch_gate = ctx.Output("BatchGate"); + LoDTensor* batch_gate = nullptr; + LoDTensor batch_gate_temp; + if (is_test) { + batch_gate = &batch_gate_temp; + batch_gate->Resize(input->dims()); + } else { + batch_gate = ctx.Output("BatchGate"); + } batch_gate->mutable_data(ctx.GetPlace()); auto* hidden_out = ctx.Output("Hidden"); hidden_out->mutable_data(ctx.GetPlace()); @@ -99,8 +108,13 @@ class LSTMKernel : public framework::OpKernel { } // Use the local variable as here. 
- LoDTensor batch_hidden, batch_cell; - auto* batch_cell_pre_act = ctx.Output("BatchCellPreAct"); + LoDTensor batch_hidden, batch_cell, batch_cell_pre_act_temp; + LoDTensor* batch_cell_pre_act; + if (is_test) { + batch_cell_pre_act = &batch_cell_pre_act_temp; + } else { + batch_cell_pre_act = ctx.Output("BatchCellPreAct"); + } batch_hidden.mutable_data(dims, ctx.GetPlace()); batch_cell.mutable_data(dims, ctx.GetPlace()); batch_cell_pre_act->mutable_data(dims, ctx.GetPlace()); diff --git a/paddle/fluid/operators/math/depthwise_conv.cu b/paddle/fluid/operators/math/depthwise_conv.cu index 7c5f59fab0d28..6da1bfb964f24 100644 --- a/paddle/fluid/operators/math/depthwise_conv.cu +++ b/paddle/fluid/operators/math/depthwise_conv.cu @@ -31,18 +31,43 @@ namespace operators { namespace math { template -__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) { +static __forceinline__ __device__ T WarpReduceSum(T val, int warp_size) { typedef cub::WarpReduce WarpReduce; typename WarpReduce::TempStorage temp_storage; + val = WarpReduce(temp_storage).Sum(val, warp_size); + return val; +} -#ifdef __HIPCC__ - int block_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize); - value = WarpReduce(temp_storage).Sum(value, block_size); -#else - value = WarpReduce(temp_storage).Sum(value); -#endif +template +__forceinline__ __device__ T BlockReduceSum(T val) { + static __shared__ T shared[32]; + int thread_id = threadIdx.x + threadIdx.y * blockDim.x + + threadIdx.z * blockDim.x * blockDim.y; + int warp_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize); + int lane = thread_id % warp_size; + int wid = thread_id / warp_size; + + val = WarpReduceSum(val, warp_size); // Each warp performs partial reduction + + if (lane == 0) shared[wid] = val; // Write reduced value to shared memory + __syncthreads(); // Wait for all partial reductions + + // read from shared memory only if that warp existed + int block_size = blockDim.x * blockDim.y * blockDim.z; + if (thread_id < (block_size - 1) / warp_size + 1) { + val = shared[lane]; + } else { + val = static_cast(0); + } - if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value); + if (wid == 0) { + val = WarpReduceSum(val, warp_size); // Final reduce within first warp + } + __syncthreads(); + if (thread_id != 0) { + val = static_cast(0); + } + return val; } #define ARG_DEFINE_KernelDepthwiseConv \ @@ -665,7 +690,9 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW( } } } - CudaAtomicAddWithWarp(&filter_grad_data[gbid], s); + + T val = BlockReduceSum(s); + platform::CudaAtomicAdd(&filter_grad_data[gbid], val); } template @@ -892,6 +919,7 @@ class DepthwiseConvFunctor 1024 && output_width <= 2048) thread = (output_width - 1) / 2 + 1; @@ -1034,6 +1062,7 @@ class DepthwiseConvInputGradFunctor 1024 && input_width <= 2048) { thread = (input_width - 1) / 2 + 1; diff --git a/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc new file mode 100644 index 0000000000000..e16c41829b1a6 --- /dev/null +++ b/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc @@ -0,0 +1,175 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/mkldnn_reuse.h" + +namespace paddle { +namespace operators { + +using paddle::framework::Tensor; + +template +class SliceMKLDNNKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + this->RunKernel(ctx); + } + + void RunKernel(const framework::ExecutionContext& ctx) const { + const auto& dev_ctx = + ctx.template device_context(); + const auto& onednn_engine = dev_ctx.GetEngine(); + + auto* x = ctx.Input("Input"); + auto* out = ctx.Output("Out"); + + auto x_vec_dims = framework::vectorize(x->dims()); + auto out_vec_dims = framework::vectorize(out->dims()); + + auto axes_int = ctx.Attr>("axes"); + auto starts_int = ctx.Attr>("starts"); + auto ends_int = ctx.Attr>("ends"); + + std::vector axes(ctx.Attr>("axes").begin(), + ctx.Attr>("axes").end()); + std::vector starts(ctx.Attr>("starts").begin(), + ctx.Attr>("starts").end()); + std::vector ends(ctx.Attr>("ends").begin(), + ctx.Attr>("ends").end()); + + auto decrease_axis = ctx.Attr>("decrease_axis"); + + std::vector offsets(x_vec_dims.size(), 0); + std::vector slice_dims(x_vec_dims); + + for (size_t i = 0; i < axes.size(); ++i) { + starts[i] = starts[i] < 0 ? x_vec_dims[axes[i]] + starts[i] : starts[i]; + ends[i] = ends[i] < 0 ? x_vec_dims[axes[i]] + ends[i] + : std::min(ends[i], x_vec_dims[axes[i]]); + offsets[axes[i]] = starts[i]; + slice_dims[axes[i]] = ends[i] - starts[i]; + } + + mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type()); + auto key = platform::CreateKey(dev_ctx, x_vec_dims, axes, starts, ends, + x->format(), x_type); + + platform::ReorderMKLDNNHandler reorder_handler( + x_vec_dims, x->type(), x_type, dev_ctx, onednn_engine, key); + + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + x->format(), platform::to_void_cast(x->data())); + auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets, + reorder_src_memory_p); + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + out, slice_dims, 0, x->format(), ctx.GetPlace()); + + auto reorder_p = + reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p); + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p); + astream.wait(); + + out->set_layout(framework::DataLayout::kMKLDNN); + out->set_format(platform::GetMKLDNNFormat( + reorder_dst_memory_p->get_desc().reshape(out_vec_dims))); + } +}; + +template +class SliceGradMKLDNNKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + this->RunKernel(ctx); + } + + void RunKernel(const framework::ExecutionContext& ctx) const { + const auto& dev_ctx = + ctx.template device_context(); + const auto& onednn_engine = dev_ctx.GetEngine(); + + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("Input")); + + auto dx_vec_dims = framework::vectorize(dx->dims()); + auto dout_vec_dims = framework::vectorize(dout->dims()); + + auto axes_int = ctx.Attr>("axes"); + auto starts_int = 
ctx.Attr>("starts"); + auto ends_int = ctx.Attr>("ends"); + + std::vector axes(ctx.Attr>("axes").begin(), + ctx.Attr>("axes").end()); + std::vector starts(ctx.Attr>("starts").begin(), + ctx.Attr>("starts").end()); + std::vector ends(ctx.Attr>("ends").begin(), + ctx.Attr>("ends").end()); + + auto decrease_axis = ctx.Attr>("decrease_axis"); + + std::vector offsets(dx_vec_dims.size(), 0); + std::vector slice_dims(dx_vec_dims); + + for (size_t i = 0; i < axes.size(); ++i) { + starts[i] = starts[i] < 0 ? dx_vec_dims[axes[i]] + starts[i] : starts[i]; + ends[i] = ends[i] < 0 ? dx_vec_dims[axes[i]] + ends[i] + : std::min(ends[i], dx_vec_dims[axes[i]]); + offsets[axes[i]] = starts[i]; + slice_dims[axes[i]] = ends[i] - starts[i]; + } + + mkldnn::memory::data_type dout_type = + framework::ToMKLDNNDataType(dout->type()); + mkldnn::memory::desc md(dout_vec_dims, platform::MKLDNNGetDataType(), + dout->format()); + mkldnn::memory::format_tag reorder_format_tag = + platform::GetMKLDNNFormat(md.reshape(slice_dims)); + + auto key = platform::CreateKey(dev_ctx, dout_vec_dims, axes, starts, ends, + reorder_format_tag, dout_type); + + platform::ReorderMKLDNNHandler reorder_handler( + slice_dims, dout->type(), dout_type, dev_ctx, onednn_engine, key); + + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + reorder_format_tag, platform::to_void_cast(dout->data())); + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + dx, dx_vec_dims, 0, reorder_format_tag, ctx.GetPlace()); + memset(dx->data(), 0, reorder_dst_memory_p->get_desc().get_size()); + + auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets, + reorder_dst_memory_p); + + auto reorder_p = + reorder_handler.AcquireReorder(slice_mem_p, reorder_src_memory_p); + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + reorder_p->execute(astream, *reorder_src_memory_p, *slice_mem_p); + astream.wait(); + + dx->set_layout(framework::DataLayout::kMKLDNN); + dx->set_format(reorder_format_tag); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(slice, MKLDNN, paddle::platform::CPUPlace, + ops::SliceMKLDNNKernel, + ops::SliceMKLDNNKernel); + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(slice_grad, MKLDNN, paddle::platform::CPUPlace, + ops::SliceGradMKLDNNKernel, + ops::SliceGradMKLDNNKernel); \ No newline at end of file diff --git a/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc index afbe330305b7e..8a58d9f26f87b 100644 --- a/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/split_mkldnn_op.cc @@ -105,7 +105,7 @@ class SplitMKLDNNKernel : public framework::OpKernel { for (size_t i = 0; i < outs_number; ++i) { auto out_vec_dims = framework::vectorize(outs[i]->dims()); - auto slice_mem_p = reorder_handler.AcquireSrcSubmemory( + auto slice_mem_p = reorder_handler.AcquireSubmemory( out_vec_dims, offset, reorder_src_memory_p, i); auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( diff --git a/paddle/fluid/operators/optimizers/momentum_op_xpu.cc b/paddle/fluid/operators/optimizers/momentum_op_xpu.cc index 932368e810edd..5624312d9a728 100644 --- a/paddle/fluid/operators/optimizers/momentum_op_xpu.cc +++ b/paddle/fluid/operators/optimizers/momentum_op_xpu.cc @@ -44,10 +44,10 @@ class MomentumOpXPUKernel : public framework::OpKernel { auto grad = ctx.Input("Grad"); auto& dev_ctx = ctx.template device_context(); - int r = xpu::momentum( - 
dev_ctx.x_context(), param->data(), velocity->data(), - grad->data(), lr, use_nesterov, mu, param_out->numel(), - param_out->data(), velocity_out->data()); + int r = xpu::momentum(dev_ctx.x_context(), param->data(), + velocity->data(), grad->data(), + param_out->data(), velocity_out->data(), + param_out->numel(), lr, use_nesterov, mu); if (r == xpu::Error_t::INVALID_PARAM) { PADDLE_ENFORCE_EQ( r, xpu::Error_t::SUCCESS, diff --git a/paddle/fluid/operators/partial_sum_op.cc b/paddle/fluid/operators/partial_sum_op.cc index 76e493836855e..fb60ab54b77b2 100644 --- a/paddle/fluid/operators/partial_sum_op.cc +++ b/paddle/fluid/operators/partial_sum_op.cc @@ -143,7 +143,8 @@ class PartialSumOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr( "use_mkldnn", "(bool, default false) Indicates if MKL-DNN kernel will be used") - .SetDefault(false); + .SetDefault(false) + .AsExtra(); AddAttr("start_index", "The start index of tensor wanted to be added.") .SetDefault(0); AddAttr("length", "The length of tensor wanted to be added.") diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index c9c1750b8569f..ae7e1c07b1496 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -279,7 +279,8 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault({}); AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false); + .SetDefault(false) + .AsExtra(); AddComment(R"DOC( Reshape Operator. diff --git a/paddle/fluid/operators/shard_index_op_npu.cc b/paddle/fluid/operators/shard_index_op_npu.cc new file mode 100644 index 0000000000000..83b5d12330d67 --- /dev/null +++ b/paddle/fluid/operators/shard_index_op_npu.cc @@ -0,0 +1,118 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
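+
+// shard_index maps a global class id to a shard-local id:
+//   shard_size = (index_num + nshards - 1) / nshards
+//   out = (in / shard_size == shard_id) ? in % shard_size : ignore_value
+// CANN has no single op for this, so the kernel below composes it from the
+// Mod, FloorDiv, Equal and Select runners on the NPU stream.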
+ +#include "paddle/fluid/operators/shard_index_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; +template +class ShardIndexNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + VLOG(4) << "start kernel"; + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + int index_num = context.Attr("index_num"); + int nshards = context.Attr("nshards"); + int shard_id = context.Attr("shard_id"); + int ignore_value = context.Attr("ignore_value"); + + PADDLE_ENFORCE_GT( + index_num, 0, + platform::errors::InvalidArgument( + "The value 'index_num' for Op(shard_index) must be greater than 0, " + "but the value given is %d.", + index_num)); + PADDLE_ENFORCE_GT(nshards, 0, + platform::errors::InvalidArgument( + "The value 'nshard' for Op(shard_index) must be " + "greater than 0, but the value given is %d.", + nshards)); + PADDLE_ENFORCE_GE( + shard_id, 0, + platform::errors::InvalidArgument( + "The value 'shard_id' for Op(shard_index) must be greater or " + "equal to 0, but the value given is %d.", + shard_id)); + PADDLE_ENFORCE_LT( + shard_id, nshards, + platform::errors::InvalidArgument( + "The value 'shard_id' for Op(shard_index) must be less than " + "nshards (%d), but the value given is %d.", + nshards, shard_id)); + + int shard_size = (index_num + nshards - 1) / nshards; + + auto place = context.GetPlace(); + out->Resize(in->dims()); + out->set_lod(in->lod()); + out->mutable_data(place); + + Tensor tmp(in->type()); + tmp.mutable_data(framework::DDim({1}), place); + FillNpuTensorWithConstant(&tmp, shard_size); + + Tensor condition(framework::proto::VarType::BOOL); + condition.mutable_data(in->dims(), place); + + Tensor tmp2(in->type()); + tmp2.mutable_data(in->dims(), place); + + Tensor tmp3(in->type()); + tmp3.mutable_data(in->dims(), place); + + auto stream = + context.template device_context() + .stream(); + + NpuOpRunner runner; + runner.AddInputs({*in, tmp}); + runner.AddOutputs({tmp2}); + runner.SetType("Mod"); + runner.Run(stream); + + NpuOpRunner runner1; + runner1.AddInputs({*in, tmp}); + runner1.AddOutputs({tmp3}); + runner1.SetType("FloorDiv"); + runner1.Run(stream); + + FillNpuTensorWithConstant(&tmp, shard_id); + NpuOpRunner runner2; + runner2.AddInputs({tmp3, tmp}); + runner2.AddOutputs({condition}); + runner2.SetType("Equal"); + runner2.Run(stream); + + Tensor tmp4(in->type()); + tmp4.mutable_data(in->dims(), place); + FillNpuTensorWithConstant(&tmp4, ignore_value); + tmp4.Resize(in->dims()); + + NpuOpRunner runner3; + runner3.AddInputs({condition, tmp2, tmp4}); + runner3.AddOutputs({*out}); + runner3.SetType("Select"); + runner3.Run(stream); + } +}; +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; +REGISTER_OP_NPU_KERNEL(shard_index, ops::ShardIndexNPUKernel, + ops::ShardIndexNPUKernel); diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index a55959385f627..ac50ccea9eee4 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -132,6 +132,26 @@ class SliceOp : public framework::OperatorWithKernel { if (platform::is_cuda_pinned_place(in_tensor.place())) { return framework::OpKernelType(in_tensor.type(), ctx.device_context()); } + +#ifdef PADDLE_WITH_MKLDNN + auto input_data_type = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input"); + + if 
(this->CanMKLDNNBeUsed(ctx, input_data_type)) { + // OneDNN uses blocking format, which cannot be always supported with + // reorders, because if blocked dimension is not divisible by 8 or + // 16(depending on which blocking format is used) submemory cannot be + // created, so in that scenario a fallback is needed + auto tmp_md = dnnl::memory::desc( + framework::vectorize(ctx.Input("Input")->dims()), + dnnl::memory::data_type::f32, ctx.Input("Input")->format()); + if (tmp_md.data.format_desc.blocking.inner_nblks == 0) + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(in_tensor.type(), in_tensor.place()); } return framework::OpKernelType( @@ -216,6 +236,14 @@ class SliceOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault({}); AddAttr>("decrease_axis", "(list) decrease_axis") .SetDefault({}); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr( + "mkldnn_data_type", + "(string, default \"float32\"). Data type of mkldnn kernel") + .SetDefault("float32") + .InEnum({"float32", "bfloat16"}); AddComment(R"DOC( Slice Operator. @@ -278,12 +306,32 @@ class SliceOpGrad : public framework::OperatorWithKernel { ctx->SetOutputDim(x_grad_name, x_dims); } } + framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + // OneDNN uses blocking format, which cannot be always supported with + // reorders, because if blocked dimension is not divisible by 8 or + // 16(depending on which blocking format is used) submemory cannot be + // created, so in that scenario a fallback is needed + auto tmp_md = dnnl::memory::desc( + framework::vectorize( + ctx.Input(framework::GradVarName("Out"))->dims()), + dnnl::memory::data_type::f32, + ctx.Input(framework::GradVarName("Out"))->format()); + if (tmp_md.data.format_desc.blocking.inner_nblks == 0) + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } + framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const Tensor &tensor, const framework::OpKernelType &expected_kernel_type) const override { diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc index f81ac8882d107..5bd699e08abbc 100644 --- a/paddle/fluid/operators/split_op.cc +++ b/paddle/fluid/operators/split_op.cc @@ -78,10 +78,9 @@ class SplitOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { - // OneDNN uses blocking format, which cannot be always - // supported with reorders, because if blocked dimension is not divisible - // by - // 8 or 16(depending on which blocking format is used) submemory cannot be + // OneDNN uses blocking format, which cannot be always supported with + // reorders, because if blocked dimension is not divisible by 8 or + // 16(depending on which blocking format is used) submemory cannot be // created, so in that scenario a fallback is 
needed auto tmp_md = dnnl::memory::desc( framework::vectorize(ctx.Input("X")->dims()), diff --git a/paddle/fluid/operators/strided_slice_op_npu.cc b/paddle/fluid/operators/strided_slice_op_npu.cc old mode 100755 new mode 100644 index deafdc5633a15..eb9377cc6381a --- a/paddle/fluid/operators/strided_slice_op_npu.cc +++ b/paddle/fluid/operators/strided_slice_op_npu.cc @@ -226,14 +226,204 @@ class StridedSliceNPUKernel : public framework::OpKernel { } }; +template +class StridedSliceGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Variable* input_var = ctx.InputVar("Input"); + bool is_tensor_array = input_var->IsType(); + PADDLE_ENFORCE_EQ(is_tensor_array, false, + platform::errors::InvalidArgument( + "Tensor array as input is not supported.")); + int rank = ctx.Input("Input")->dims().size(); + + switch (rank) { + case 1: + StridedSliceGradCompute<1>(ctx); + break; + case 2: + StridedSliceGradCompute<2>(ctx); + break; + case 3: + StridedSliceGradCompute<3>(ctx); + break; + case 4: + StridedSliceGradCompute<4>(ctx); + break; + case 5: + StridedSliceGradCompute<5>(ctx); + break; + case 6: + StridedSliceGradCompute<6>(ctx); + break; + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "The rank of input is supported up to 6.")); + break; + } + } + + private: + template + void StridedSliceGradCompute(const framework::ExecutionContext& ctx) const { + auto place = ctx.GetPlace(); + auto& dev_ctx = + ctx.template device_context(); + + auto* input = ctx.Input("Input"); + auto input_dims = input->dims(); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("Input")); + dx->mutable_data(input_dims, place); + + auto starts_int = ctx.Attr>("starts"); + auto ends_int = ctx.Attr>("ends"); + auto strides_int = ctx.Attr>("strides"); + + std::vector starts(starts_int.begin(), starts_int.end()); + std::vector ends(ends_int.begin(), ends_int.end()); + std::vector strides(strides_int.begin(), strides_int.end()); + + auto axes = ctx.Attr>("axes"); + auto infer_flags = ctx.Attr>("infer_flags"); + auto decrease_axis = ctx.Attr>("decrease_axis"); + + auto list_new_ends_tensor = + ctx.MultiInput("EndsTensorList"); + auto list_new_starts_tensor = + ctx.MultiInput("StartsTensorList"); + auto list_new_strides_tensor = + ctx.MultiInput("StridesTensorList"); + + if (list_new_starts_tensor.size() > 0) { + starts = GetDataFromTensorList(list_new_starts_tensor); + } else if (ctx.HasInput("StartsTensor")) { + auto* starts_tensor = ctx.Input("StartsTensor"); + starts = GetDataFromTensor(starts_tensor); + } + + if (list_new_ends_tensor.size() > 0) { + ends = GetDataFromTensorList(list_new_ends_tensor); + } else if (ctx.HasInput("EndsTensor")) { + auto* ends_tensor = ctx.Input("EndsTensor"); + ends = GetDataFromTensor(ends_tensor); + } + + if (list_new_strides_tensor.size() > 0) { + strides = GetDataFromTensorList(list_new_strides_tensor); + } else if (ctx.HasInput("StridesTensor")) { + auto* strides_tensor = ctx.Input("StridesTensor"); + strides = GetDataFromTensor(strides_tensor); + } + + std::vector out_dims_vector(input_dims.size(), -1); + StridedSliceOutDims(starts, ends, strides, axes, infer_flags, input_dims, + decrease_axis, out_dims_vector.data(), axes.size(), + false); + + std::vector reverse_vector(starts.size(), 0); + StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), + reverse_vector.data(), input_dims, infer_flags, + decrease_axis, 
starts.size()); + + std::vector starts_indices_vector(D, 0); + std::vector ends_indices_vector(out_dims_vector.begin(), + out_dims_vector.end()); + std::vector strides_indices_vector(D, 1); + + for (size_t axis = 0; axis < axes.size(); axis++) { + int axis_index = axes[axis]; + starts_indices_vector[axis_index] = starts[axis]; + ends_indices_vector[axis_index] = ends[axis]; + strides_indices_vector[axis_index] = strides[axis]; + } + + Tensor starts_indices_tensor; + Tensor ends_indices_tensor; + Tensor strides_indices_tensor; + + starts_indices_tensor.mutable_data({D}, place); + ends_indices_tensor.mutable_data({D}, place); + strides_indices_tensor.mutable_data({D}, place); + + TensorFromVector(starts_indices_vector, dev_ctx, &starts_indices_tensor); + TensorFromVector(ends_indices_vector, dev_ctx, &ends_indices_tensor); + TensorFromVector(strides_indices_vector, dev_ctx, &strides_indices_tensor); + + std::vector input_dims_vector; + for (int i = 0; i < input_dims.size(); i++) { + input_dims_vector.push_back(input_dims[i]); + } + Tensor input_dims_tensor; + TensorFromVector(input_dims_vector, dev_ctx, &input_dims_tensor); + + bool need_reverse = false; + for (size_t axis = 0; axis < axes.size(); axis++) { + if (reverse_vector[axis] == 1) { + need_reverse = true; + break; + } + } + + auto stream = dev_ctx.stream(); + framework::NPUAttributeMap attr_input = {{"begin_mask", 0}, + {"end_mask", 0}, + {"ellipsis_mask", 0}, + {"new_axis_mask", 0}, + {"shrink_axis_mask", 0}}; + + if (need_reverse) { + Tensor reverse_axis; + std::vector reverse_axis_vector; + for (size_t axis = 0; axis < axes.size(); axis++) { + if (reverse_vector[axis] == 1) { + reverse_axis_vector.push_back(axes[axis]); + } + } + reverse_axis.mutable_data( + {static_cast(reverse_axis_vector.size())}, place); + TensorFromVector(reverse_axis_vector, dev_ctx, &reverse_axis); + + Tensor dout_tmp; + dout_tmp.mutable_data(dout->dims(), place); + const auto& runner_reverse = + NpuOpRunner("ReverseV2", {*dout, reverse_axis}, {dout_tmp}); + runner_reverse.Run(stream); + + const auto& runner = + NpuOpRunner("StridedSliceGrad", + {input_dims_tensor, starts_indices_tensor, + ends_indices_tensor, strides_indices_tensor, dout_tmp}, + {*dx}, attr_input); + runner.Run(stream); + } else { + const auto& runner = + NpuOpRunner("StridedSliceGrad", + {input_dims_tensor, starts_indices_tensor, + ends_indices_tensor, strides_indices_tensor, *dout}, + {*dx}, attr_input); + runner.Run(stream); + } + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_NPU_KERNEL( + strided_slice, ops::StridedSliceNPUKernel, + ops::StridedSliceNPUKernel, + ops::StridedSliceNPUKernel, + ops::StridedSliceNPUKernel, + ops::StridedSliceNPUKernel); + REGISTER_OP_NPU_KERNEL( - strided_slice, - ops::StridedSliceNPUKernel, - ops::StridedSliceNPUKernel, - ops::StridedSliceNPUKernel, - ops::StridedSliceNPUKernel, - ops::StridedSliceNPUKernel); + strided_slice_grad, + ops::StridedSliceGradNPUKernel, + ops::StridedSliceGradNPUKernel, + ops::StridedSliceGradNPUKernel, + ops::StridedSliceGradNPUKernel, + ops::StridedSliceGradNPUKernel); diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h new file mode 100644 index 0000000000000..aa6a369728839 --- /dev/null +++ b/paddle/fluid/operators/svd_helper.h @@ -0,0 +1,372 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/functors.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +namespace math { +using Tensor = framework::Tensor; +using InTensors = std::vector; +using OutTensors = std::vector; +using OpName = std::string; + +template +void EigenSvd(const T* X, T* U, T* VH, T* S, int rows, int cols, + int full = false) { + auto flag = Eigen::DecompositionOptions::ComputeThinU | + Eigen::DecompositionOptions::ComputeThinV; + if (full) { + flag = Eigen::DecompositionOptions::ComputeFullU | + Eigen::DecompositionOptions::ComputeFullV; + } + Eigen::BDCSVD< + Eigen::Matrix> + svd(2, 2, flag); + /*NOTE(xiongkun03) Eigen::Matrix API need non-const pointer.*/ + T* input = const_cast(X); + auto m = Eigen::Map< + Eigen::Matrix>( + input, rows, cols); + svd.compute(m); + Eigen::Matrix V_trans = + svd.matrixV().transpose(); + memcpy(U, svd.matrixU().data(), svd.matrixU().size() * sizeof(T)); + memcpy(VH, V_trans.data(), V_trans.size() * sizeof(T)); + memcpy(S, svd.singularValues().data(), + svd.singularValues().size() * sizeof(T)); +} + +template +void BatchSvd(const T* X, T* U, T* VH, T* S, int rows, int cols, int batches, + int full = false) { + int stride = rows * cols; + int k = std::min(rows, cols); + int stride_u = full ? rows * rows : k * rows; + int stride_v = full ? cols * cols : k * cols; + for (int i = 0; i < batches; ++i) { + EigenSvd(X + i * stride, U + i * stride_u, VH + i * stride_v, S + i * k, + rows, cols, full); + } + return; +} + +template +struct PowFunctor { + PowFunctor(const T* input, T* output, int64_t numel, float exp) + : input_(input), output_(output), numel_(numel), exp_(exp) {} + + HOSTDEVICE void operator()(int64_t idx) const { + output_[idx] = pow(input_[idx], exp_); + } + const T* input_; + T* output_; + int64_t numel_; + float exp_; +}; + +static std::vector GetBroadcastShape(InTensors ins) { + // TODO(xiongkun03) check the operators and output + PADDLE_ENFORCE_EQ(ins.size(), 2, platform::errors::InvalidArgument( + "GetBroadcastShape Receive 2 tensors" + "but got [%d]", + ins.size())); + auto x_dim = ins[0]->dims(); + auto y_dim = ins[1]->dims(); + std::vector broadcast_shape = + (x_dim.size() > y_dim.size() ? 
framework::vectorize(x_dim) + : framework::vectorize(y_dim)); + int rank_min = std::min(x_dim.size(), y_dim.size()); + int rank_x = x_dim.size(); + int rank_y = y_dim.size(); + int final_rank = broadcast_shape.size(); + for (int i = 1; i <= rank_min; ++i) { + if (x_dim[rank_x - i] == y_dim[rank_y - i]) { + broadcast_shape[final_rank - i] = x_dim[rank_x - i]; + continue; + } + if (x_dim[rank_x - i] == 1) { + broadcast_shape[final_rank - i] = y_dim[rank_y - i]; + continue; + } + if (y_dim[rank_y - i] == 1) { + broadcast_shape[final_rank - i] = x_dim[rank_x - i]; + continue; + } + PADDLE_THROW(platform::errors::InvalidArgument( + "Wrong Input Shape in broadcast operator: " + "Input(X)'s shape must follow the broadcast rule with Input(Y)'s " + "shape, but received [%s] (X) vs [%s] (Y).", + x_dim, y_dim)); + } + return broadcast_shape; +} + +template +struct DeviceIndependenceTensorOperations { + // 1. Device indenpendence, for kernel reuse. + // 2. Input and output is always tensor type. + // 3. output Tensor is alway allocated + // 4. Basic Tensor operator is supported + // 5. The Reused Operator Kernel should only be considered as + // a wrap function + using NameInTensorMap = + std::map>; + using NameOutTensor = std::vector; + + explicit DeviceIndependenceTensorOperations( + const framework::ExecutionContext& context) + : context(context) {} + + framework::Tensor Pow(const framework::Tensor& x, float exp) { + framework::Tensor out; + auto for_range = GetForRange(x.numel()); + int numel = x.numel(); + PowFunctor functor(x.data(), out.mutable_data(x.dims(), x.place()), + numel, exp); + for_range(functor); + return out; + } + framework::Tensor Matmul(const framework::Tensor& mat_a, + const framework::Tensor& mat_b, bool trans_a = false, + bool trans_b = false) { + framework::AttributeMap attrs; + attrs["trans_x"] = trans_a; + attrs["trans_y"] = trans_b; + NameInTensorMap inputs({{"X", {&mat_a}}, {"Y", {&mat_b}}}); + auto a_dim = mat_a.dims(); + auto b_dim = mat_b.dims(); + std::vector x_vec = framework::vectorize(a_dim); + x_vec[x_vec.size() - 2] = a_dim[a_dim.size() - (trans_a ? 1 : 2)]; + x_vec[x_vec.size() - 1] = b_dim[b_dim.size() - (trans_b ? 
2 : 1)]; + return CreateOpRunAndReturnTensor("matmul_v2", inputs, attrs, x_vec); + } + // transpose the last two dimision + framework::Tensor Transpose(const framework::Tensor& x) { + framework::Tensor out; + auto x_dim = x.dims(); + auto x_vec = framework::vectorize(x_dim); + int rank = x_vec.size(); + std::swap(x_vec[rank - 1], x_vec[rank - 2]); + std::vector out_shape = x_vec; + std::vector axis(rank); + for (int i = 0; i < rank; ++i) { + axis[i] = i; + } + std::swap(axis[rank - 1], axis[rank - 2]); + framework::AttributeMap attrs; + attrs["axis"] = axis; + NameInTensorMap inputs({{"X", {&x}}}); + return CreateOpRunAndReturnTensor("transpose2", inputs, attrs, out_shape, + {"Out", "XShape"}); + } + + framework::Tensor Diag(const framework::Tensor& x, int offset = 0, + int padding_value = 0) { + framework::AttributeMap attrs; + attrs["offset"] = offset; + attrs["padding_value"] = padding_value; + NameInTensorMap inputs({{"X", {&x}}}); + int x_rank = x.dims().size(); + std::vector out_shape; + if (x_rank == 2) { + PADDLE_ENFORCE_EQ(x.dims()[0], x.dims()[1], + platform::errors::InvalidArgument( + "if X is a Matrix, then X must be square")); + out_shape.push_back(x.dims()[0]); + } else if (x_rank == 1) { + out_shape.push_back(x.dims()[0]); + out_shape.push_back(x.dims()[0]); + } else { + PADDLE_THROW( + platform::errors::InvalidArgument("Rank must less or equal than 2")); + } + return CreateOpRunAndReturnTensor("diag_v2", inputs, attrs, out_shape); + } + + framework::Tensor Add(const framework::Tensor& x, + const framework::Tensor& y) { + InTensors ins({&x, &y}); + framework::AttributeMap attrs; + attrs["axis"] = -1; + std::vector out_shape = GetBroadcastShape({&x, &y}); + NameInTensorMap inputs({{"X", {&x}}, {"Y", {&y}}}); + return CreateOpRunAndReturnTensor("elementwise_add", inputs, attrs, + out_shape); + } + + framework::Tensor Mul(const framework::Tensor& x, + const framework::Tensor& y) { + InTensors ins({&x, &y}); + framework::AttributeMap attrs; + attrs["axis"] = -1; + std::vector out_shape = GetBroadcastShape({&x, &y}); + NameInTensorMap inputs({{"X", {&x}}, {"Y", {&y}}}); + return CreateOpRunAndReturnTensor("elementwise_mul", inputs, attrs, + out_shape); + } + + framework::Tensor Sub(const framework::Tensor& x, + const framework::Tensor& y) { + InTensors ins({&x, &y}); + framework::AttributeMap attrs; + attrs["axis"] = -1; + std::vector out_shape = GetBroadcastShape({&x, &y}); + NameInTensorMap inputs({{"X", {&x}}, {"Y", {&y}}}); + return CreateOpRunAndReturnTensor("elementwise_sub", inputs, attrs, + out_shape); + } + + const framework::Tensor Unsqueeze(const framework::Tensor& x, int axis = 0) { + // don't copy data, only change the dims + framework::Tensor out; + out.ShareDataWith(x); + std::vector out_shape = framework::vectorize(x.dims()); + if (axis >= 0) { + auto index = (out_shape.begin() + axis); + out_shape.insert(index, 1); + } else if (axis < 0) { + auto index = (out_shape.end() + axis + 1); + out_shape.insert(index, 1); + } + out.Resize(framework::make_ddim(out_shape)); + return out; + } + + framework::Tensor Zeros(std::vector shape, + framework::proto::VarType::Type dtype, + float fill_value) { + framework::AttributeMap attrs; + attrs["dtype"] = dtype; + attrs["shape"] = shape; + attrs["value"] = fill_value; + NameInTensorMap inputs({}); + return CreateOpRunAndReturnTensor("fill_constant", inputs, attrs, shape); + } + + framework::Tensor Infinits(std::vector shape, + framework::proto::VarType::Type dtype) { + framework::AttributeMap attrs; + attrs["dtype"] = dtype; + 
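+    // The fill value is passed through fill_constant's "str_value" attribute
+    // (set below) so that non-finite constants such as "inf" can be encoded.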
attrs["shape"] = shape; + attrs["str_value"] = std::string("inf"); + NameInTensorMap inputs({}); + return CreateOpRunAndReturnTensor("fill_constant", inputs, attrs, shape); + } + + framework::Tensor Eye(int n, framework::proto::VarType::Type dtype) { + auto output = Zeros({n}, dtype, 1); + auto ret = Diag(output); + return ret; + } + + framework::Tensor Slice(const framework::Tensor& x, std::vector axes, + std::vector starts, std::vector ends) { + std::vector new_axes = axes; + NameInTensorMap inputs({{"Input", {&x}}}); + std::vector out_shape = framework::vectorize(x.dims()); + int rank = out_shape.size(); + PADDLE_ENFORCE_EQ( + axes.size(), starts.size(), + platform::errors::InvalidArgument("Slice Operator Argument Invalided")); + PADDLE_ENFORCE_EQ( + ends.size(), starts.size(), + platform::errors::InvalidArgument("Slice Operator Argument Invalided")); + for (unsigned int i = 0; i < axes.size(); ++i) { + int axis = axes[i]; + if (axis < 0) axis = rank + axis; + new_axes[i] = axis; // change negative to positive + int st = starts[i]; + int ed = ends[i]; + PADDLE_ENFORCE_GT(ed, st, + platform::errors::InvalidArgument( + "C++ Slice Operation Not Support End < Start")); + out_shape[axis] = ed - st; + } + framework::AttributeMap attrs; + attrs["axes"] = new_axes; + attrs["starts"] = starts; + attrs["ends"] = ends; + return CreateOpRunAndReturnTensor("slice", inputs, attrs, out_shape); + } + + private: + const framework::ExecutionContext& context; + BlasT GetBlas() { + return math::GetBlas(context); + } + platform::ForRange GetForRange(int numel) { + auto& dev_ctx = context.template device_context(); + return platform::ForRange(dev_ctx, numel); + } + + framework::Tensor CreateOpRunAndReturnTensor( + const std::string& type, const NameInTensorMap& inputs, + const framework::AttributeMap& attrs, std::vector out_shape, + NameOutTensor out_str = {"Out"}) { + // varialble set dims must be LoDTensor / SelectedRowTensor + framework::Scope& local_scope = context.scope().NewScope(); + + framework::VariableNameMap op_outputs; + for (auto out_name : out_str) { + local_scope.Var("tmp_" + out_name)->GetMutable(); + op_outputs[out_name].emplace_back("tmp_" + out_name); + } + auto out_var = local_scope.Var("tmp_Out"); // return the Out + // create Out Tensor and allocat memory + out_var->GetMutable()->mutable_data( + framework::make_ddim(out_shape), context.GetPlace()); + // framework::make_ddim(out_shape) + framework::VariableNameMap op_inputs; + int counter = 0; + for (auto item : inputs) { + auto& tensors = item.second; + std::vector name_vector; + for (auto each_tensor : tensors) { + // create score variable and reset the tensor. 
+ std::string _name = "tmp" + std::to_string(counter++); + auto in_var = local_scope.Var(_name); // create + framework::LoDTensor tmp_tns; + tmp_tns.ShareDataWith(*each_tensor); // tensor -> lodtensor + (*in_var->GetMutable()) = + tmp_tns; // initialize and set value + name_vector.emplace_back(_name); + } + op_inputs[item.first] = name_vector; + } + auto op = + framework::OpRegistry::CreateOp(type, op_inputs, op_outputs, attrs); + op->Run(local_scope, context.GetPlace()); + framework::Tensor out; + out.ShareDataWith(*(out_var->GetMutable())); + out.Resize(framework::make_ddim(out_shape)); + context.scope().DeleteScope(&local_scope); + return out; + } +}; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/svd_op.cc b/paddle/fluid/operators/svd_op.cc new file mode 100644 index 0000000000000..90c138c578883 --- /dev/null +++ b/paddle/fluid/operators/svd_op.cc @@ -0,0 +1,163 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/svd_op.h" +#include +#include +#include +#include +#include "paddle/fluid/framework/ddim.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +namespace paddle { +namespace operators { + +using DDim = framework::DDim; +static DDim UDDim(const DDim& x_dim, int k) { + // get x_dim and return the ddim of U + auto x_vec = vectorize(x_dim); + x_vec[x_vec.size() - 1] = k; + return framework::make_ddim(x_vec); +} +static DDim VHDDim(const DDim& x_dim, int k) { + // get x_dim and return the ddim of U + auto x_vec = vectorize(x_dim); + x_vec[x_vec.size() - 2] = k; + return framework::make_ddim(x_vec); +} +static DDim SDDim(const DDim& x_dim, int k) { + // get x_dim and return the ddim of U + auto x_vec = vectorize(x_dim); + x_vec[x_vec.size() - 2] = k; + x_vec.erase(x_vec.end() - 1); // rank - 1 + return framework::make_ddim(x_vec); +} + +class SvdOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "svd"); + OP_INOUT_CHECK(ctx->HasOutput("U"), "Output", "U", "svd"); + OP_INOUT_CHECK(ctx->HasOutput("VH"), "Output", "VH", "svd"); + OP_INOUT_CHECK(ctx->HasOutput("S"), "Output", "S", "svd"); + + auto in_dims = ctx->GetInputDim("X"); + int x_rank = in_dims.size(); + PADDLE_ENFORCE_GE(in_dims.size(), 2, + platform::errors::InvalidArgument( + "the rank of input must greater than 2")); + int m = in_dims[x_rank - 2]; + int n = in_dims[x_rank - 1]; + int k = std::min(m, n); + const bool full_uv = ctx->Attrs().Get("full_matrices"); + ctx->SetOutputDim("U", !full_uv ? UDDim(in_dims, k) : UDDim(in_dims, m)); + ctx->SetOutputDim("VH", !full_uv ? 
VHDDim(in_dims, k) : VHDDim(in_dims, n)); + ctx->SetOutputDim("S", SDDim(in_dims, k)); + + ctx->ShareLoD("X", /*->*/ "U"); + ctx->ShareLoD("X", /*->*/ "VH"); + ctx->ShareLoD("X", /*->*/ "S"); + } +}; + +class SvdOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of svd op."); + AddOutput("U", "(Tensor), The output U tensor of svd op."); + AddOutput("S", "(Tensor), The output S tensor of svd op."); + AddOutput("VH", "(Tensor), The output VH tensor of svd op."); + AddAttr("full_matrices", + "(bool, default false) Only Compute the thin U and V" + "when set as True, the gradient have some random " + "attribute.") + .SetDefault(false); + AddComment(R"DOC( +Svd Operator. + +This operator is used to perform SVD operation for batched matrics $X$. +$$U, S, VH = svd(X)$$ + +)DOC"); + } +}; + +class SvdGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("U")), "Input", + "U@Grad", "SvdGrad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("VH")), "Input", + "VH@Grad", "SvdGrad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("S")), "Input", + "S@Grad", "SvdGrad"); + OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SvdGrad"); + OP_INOUT_CHECK(ctx->HasInput("S"), "Input", "S", "SvdGrad"); + OP_INOUT_CHECK(ctx->HasInput("VH"), "Input", "VH", "SvdGrad"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + "X@Grad", "SvdGrad"); + + auto d_x = ctx->GetInputDim(("X")); + ctx->SetOutputDim(framework::GradVarName("X"), d_x); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(dtype, ctx.GetPlace()); + } +}; + +template +class SvdGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr retv) const override { + retv->SetType("svd_grad"); + retv->SetInput(framework::GradVarName("U"), this->OutputGrad("U")); + retv->SetInput(framework::GradVarName("VH"), this->OutputGrad("VH")); + retv->SetInput(framework::GradVarName("S"), this->OutputGrad("S")); + retv->SetInput("U", this->Output("U")); + retv->SetInput("VH", this->Output("VH")); + retv->SetInput("S", this->Output("S")); + retv->SetInput("X", this->Input("X")); + retv->SetAttrMap(this->Attrs()); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(svd, ops::SvdOp, ops::SvdOpMaker, + ops::SvdGradMaker, + ops::SvdGradMaker); + +REGISTER_OPERATOR(svd_grad, ops::SvdGradOp); + +REGISTER_OP_CPU_KERNEL(svd, ops::SvdCPUKernel, + ops::SvdCPUKernel); + +REGISTER_OP_CPU_KERNEL( + svd_grad, ops::SvdGradKernel, + ops::SvdGradKernel); diff --git a/paddle/fluid/operators/svd_op.cu b/paddle/fluid/operators/svd_op.cu new file mode 100644 index 0000000000000..ade7496d64622 --- /dev/null +++ b/paddle/fluid/operators/svd_op.cu @@ -0,0 +1,175 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef PADDLE_WITH_HIP +// HIP not support cusolver + +#include +#include +#include +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/svd_op.h" +#include "paddle/fluid/platform/dynload/cusolver.h" + +namespace paddle { +namespace operators { + +template +class SvdGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto& dev_ctx = + context.template device_context(); + + const Tensor* x = context.Input("X"); + Tensor* U = context.Output("U"); + Tensor* VH = context.Output("VH"); + Tensor* S = context.Output("S"); + const bool full_matrices = context.Attr("full_matrices"); + + auto& dims = x->dims(); + int batch_count = 1; + for (int i = 0; i < dims.size() - 2; i++) { + batch_count *= dims[i]; + } + int rank = dims.size(); + int m = dims[rank - 2]; + int n = dims[rank - 1]; + + auto* vh_data = VH->mutable_data(context.GetPlace()); + auto* s_data = S->mutable_data(context.GetPlace()); + auto* u_data = U->mutable_data(context.GetPlace()); + // NOTE:(@xiongkun03) + // matrices are assumed to be stored in column-major order in cusolver + // then view A as n x m and do A^T SVD, we can avoid transpose + // Must Copy X once, because the gesvdj will change the origin input matrix + Tensor x_tmp; + TensorCopy(*x, context.GetPlace(), &x_tmp); + auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_count); + int* info_ptr = reinterpret_cast(info->ptr()); + + GesvdjBatched(dev_ctx, batch_count, n, m, std::min(m, n), + x_tmp.mutable_data(context.GetPlace()), vh_data, u_data, + s_data, info_ptr, !full_matrices); + + framework::DDim UT_dim = U->dims(); + std::swap(UT_dim[rank - 1], UT_dim[rank - 2]); // Get the dim of UT_dim + U->Resize(UT_dim); // U is entirely UT + auto dito = + math::DeviceIndependenceTensorOperations(context); + auto tmp_U = dito.Transpose(*U); + U->ShareDataWith(tmp_U); // U becomse UT, aka VT + } + void GesvdjBatched(const platform::CUDADeviceContext& dev_ctx, int batchSize, + int m, int n, int k, T* A, T* U, T* V, T* S, int* info, + int thin_UV = 1) const; +}; + +template <> +void SvdGPUKernel::GesvdjBatched( + const platform::CUDADeviceContext& dev_ctx, int batchSize, int m, int n, + int k, float* A, float* U, float* V, float* S, int* info, + int thin_UV) const { + /* compute singular vectors */ + const cusolverEigMode_t jobz = + CUSOLVER_EIG_MODE_VECTOR; /* compute singular vectors */ + gesvdjInfo_t gesvdj_params = NULL; + int lda = m; + int ldu = m; + int ldt = n; + int lwork = 0; + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSgesvdj_bufferSize( + handle, jobz, thin_UV, m, n, A, lda, S, U, ldu, V, ldt, &lwork, + gesvdj_params)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); + float* workspace_ptr = reinterpret_cast(workspace->ptr()); + int stride_A = lda * n; + int stride_U = ldu * (thin_UV ? k : m); + int stride_V = ldt * (thin_UV ? 
k : n); + for (int i = 0; i < batchSize; ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSgesvdj( + handle, jobz, thin_UV, m, n, A + stride_A * i, lda, S + k * i, + U + stride_U * i, ldu, V + stride_V * i, ldt, workspace_ptr, lwork, + info, gesvdj_params)); + // check the error info + int error_info; + memory::Copy(platform::CPUPlace(), &error_info, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), info, + sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + error_info, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info)); + } + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params)); +} + +template <> +void SvdGPUKernel::GesvdjBatched( + const platform::CUDADeviceContext& dev_ctx, int batchSize, int m, int n, + int k, double* A, double* U, double* V, double* S, int* info, + int thin_UV) const { + /* compute singular vectors */ + const cusolverEigMode_t jobz = + CUSOLVER_EIG_MODE_VECTOR; /* compute singular vectors */ + gesvdjInfo_t gesvdj_params = NULL; + int lda = m; + int ldu = m; + int ldt = n; + int lwork = 0; + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params)); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnDgesvdj_bufferSize( + handle, jobz, thin_UV, m, n, A, lda, S, U, ldu, V, ldt, &lwork, + gesvdj_params)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); + double* workspace_ptr = reinterpret_cast(workspace->ptr()); + int stride_A = lda * n; + int stride_U = ldu * (thin_UV ? k : m); + int stride_V = ldt * (thin_UV ? k : n); + for (int i = 0; i < batchSize; ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnDgesvdj( + handle, jobz, thin_UV, m, n, A + stride_A * i, lda, S + k * i, + U + stride_U * i, ldu, V + stride_V * i, ldt, workspace_ptr, lwork, + info, gesvdj_params)); + // check the error info + int error_info; + memory::Copy(platform::CPUPlace(), &error_info, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), info, + sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + error_info, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info)); + } + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params)); +} + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(svd, ops::SvdGPUKernel, + ops::SvdGPUKernel); +REGISTER_OP_CUDA_KERNEL( + svd_grad, ops::SvdGradKernel, + ops::SvdGradKernel); +#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/svd_op.h b/paddle/fluid/operators/svd_op.h new file mode 100644 index 0000000000000..1910effbeaa54 --- /dev/null +++ b/paddle/fluid/operators/svd_op.h @@ -0,0 +1,145 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
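The header that follows holds both the CPU forward kernel and the gradient kernel shared by the CPU and CUDA registrations. The forward path hands every matrix in the batch to math::BatchSvd; as a rough illustration of the factorization it performs (the Eigen-based sketch below is an assumption made for illustration, not code from this patch), a single thin SVD producing the same U, S and VH outputs could look like this, with the full_matrices attribute presumably mapping to Eigen::ComputeFullU / Eigen::ComputeFullV:

#include <Eigen/Dense>
#include <iostream>

int main() {
  Eigen::MatrixXf A = Eigen::MatrixXf::Random(4, 3);
  // Thin SVD: U is 4x3, S has 3 entries, VH is 3x3 (the full_matrices=false case).
  Eigen::BDCSVD<Eigen::MatrixXf> svd(A, Eigen::ComputeThinU | Eigen::ComputeThinV);
  Eigen::MatrixXf U  = svd.matrixU();
  Eigen::VectorXf S  = svd.singularValues();
  Eigen::MatrixXf VH = svd.matrixV().transpose();  // V^H equals V^T for real matrices
  std::cout << (U * S.asDiagonal() * VH - A).norm() << "\n";  // ~0: reconstruction check
  return 0;
}

Two details of the surrounding code are easy to miss. In the CUDA kernel above, cuSOLVER assumes column-major storage, so the row-major m x n input is handed to gesvdj as its own transpose (hence the swapped m/n arguments and the swapped U/VH output buffers); the U buffer therefore comes back transposed, which the later Resize plus dito.Transpose(*U) step corrects. In SvdGradKernel below, F is built so that F_ij = 1 / (sigma_j^2 - sigma_i^2) with zeros on the diagonal (adding Diag(Infinits({k})) makes the diagonal infinite before Pow(F, -1) takes the reciprocal), which is the coefficient matrix of the usual SVD backward formula.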
+ +#pragma once + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/operators/svd_helper.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; +using DDim = framework::DDim; + +template +class SvdCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* x = context.Input("X"); + Tensor* U = context.Output("U"); + Tensor* VH = context.Output("VH"); + Tensor* S = context.Output("S"); + int full = context.Attr("full_matrices"); + + /*Create Tensors and output, set the dim ...*/ + auto numel = x->numel(); + auto* x_data = x->data(); + auto x_dims = x->dims(); + int rows = x_dims[x_dims.size() - 2]; + int cols = x_dims[x_dims.size() - 1]; + int k = std::min(rows, cols); + int col_u = full ? rows : k; + int col_v = full ? cols : k; + int batches = numel / (rows * cols); + auto* U_out = U->mutable_data>( + context.GetPlace(), + size_t(batches * rows * col_u * sizeof(math::Real))); + auto* VH_out = VH->mutable_data>( + context.GetPlace(), + size_t(batches * col_v * cols * sizeof(math::Real))); + auto* S_out = S->mutable_data>( + context.GetPlace(), size_t(batches * k * sizeof(math::Real))); + + /*SVD Use the Eigen Library*/ + math::BatchSvd(x_data, U_out, VH_out, S_out, rows, cols, batches, full); + } +}; + +template +class SvdGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + const framework::Tensor& U_const = *ctx.Input("U"); + const framework::Tensor& VH_const = *ctx.Input("VH"); + const framework::Tensor& S = *ctx.Input("S"); + framework::Tensor& dX = + *ctx.Output(framework::GradVarName("X")); + const framework::Tensor& dU_const = + *ctx.Input(framework::GradVarName("U")); + const framework::Tensor& dVH_const = + *ctx.Input(framework::GradVarName("VH")); + + const bool full = ctx.Attr("full_matrices"); + int m = dX.dims()[dX.dims().size() - 2]; + int n = dX.dims()[dX.dims().size() - 1]; + int k = S.dims()[S.dims().size() - 1]; + auto dito = math::DeviceIndependenceTensorOperations(ctx); + framework::Tensor U, VH, dU, dV, dVH; + if (full) { + // if full_matrices is set, slice the U and VT to k columns + U = dito.Slice(U_const, {-1}, {0}, {k}); + VH = dito.Slice(VH_const, {-2}, {0}, {k}); + dU = dito.Slice(dU_const, {-1}, {0}, {k}); + dVH = dito.Slice(dVH_const, {-2}, {0}, {k}); + } else { + U = U_const; + VH = VH_const; + dU = dU_const; + dVH = dVH_const; + } + auto s_inverse = dito.Pow(S, -1); + auto s_square = dito.Pow(S, 2); + auto F = + dito.Sub(dito.Unsqueeze(s_square, -2), dito.Unsqueeze(s_square, -1)); + F = dito.Add(F, dito.Diag(dito.Infinits({k}, U.type()))); + F = dito.Pow(F, -1); + Tensor sigma_term; + Tensor u_term; + Tensor v_term; + + if (ctx.HasInput(framework::GradVarName("S"))) { + const framework::Tensor& gS = + *ctx.Input(framework::GradVarName("S")); + sigma_term = dito.Mul(dito.Unsqueeze(gS, -2), U); + sigma_term = dito.Matmul(sigma_term, VH); + } + + if (ctx.HasInput(framework::GradVarName("U"))) { + auto UTG = dito.Matmul(U, dU, true, false); + auto GTU = dito.Matmul(dU, U, true, false); + u_term = dito.Mul(dito.Mul(dito.Sub(UTG, GTU), F), dito.Unsqueeze(S, -2)); + u_term = dito.Matmul(U, u_term); + if (m > k) { + auto project = + dito.Sub(dito.Eye(m, U.type()), dito.Matmul(U, U, false, true)); + u_term = 
dito.Add(u_term, dito.Mul(dito.Matmul(project, dU), + dito.Unsqueeze(s_inverse, -2))); + } + u_term = dito.Matmul(u_term, VH); + } + + if (ctx.HasInput(framework::GradVarName("VH"))) { + auto UTG = dito.Matmul(VH, dVH, false, true); + auto GTU = dito.Matmul(dVH, VH, false, true); + v_term = dito.Mul(dito.Matmul(dito.Mul(dito.Sub(UTG, GTU), F), VH), + dito.Unsqueeze(S, -1)); + if (n > k) { + auto project = + dito.Sub(dito.Eye(n, U.type()), dito.Matmul(VH, VH, true, false)); + v_term = dito.Add(v_term, dito.Mul(dito.Matmul(dVH, project), + dito.Unsqueeze(s_inverse, -1))); + } + v_term = dito.Matmul(U, v_term); + } + + dX.ShareDataWith(dito.Add(dito.Add(u_term, sigma_term), v_term)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/tensor_formatter.cc b/paddle/fluid/operators/tensor_formatter.cc index f1b64f042c3c0..a0cda54b31b4c 100644 --- a/paddle/fluid/operators/tensor_formatter.cc +++ b/paddle/fluid/operators/tensor_formatter.cc @@ -119,10 +119,10 @@ void TensorFormatter::FormatData(const framework::LoDTensor& print_tensor, ? print_tensor.numel() : std::min(summarize_, print_tensor.numel()); const T* data = nullptr; + framework::LoDTensor cpu_tensor; if (is_cpu_place(print_tensor.place())) { data = print_tensor.data(); } else { - framework::LoDTensor cpu_tensor; platform::CPUPlace cpu_place; TensorCopy(print_tensor, cpu_place, &cpu_tensor); #ifdef PADDLE_WITH_ASCEND_CL diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index da7f824b3a6f1..18ee5d71541e0 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -119,14 +119,16 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { "tensor's axes according to the values given."); AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false); + .SetDefault(false) + .AsExtra(); AddAttr( "data_format", "(string, default NCHW) Only used in " "An optional string from: \"NHWC\", \"NCHW\". " "Defaults to \"NHWC\". Specify the data format of the output data, " "the input will be transformed automatically. ") - .SetDefault("AnyLayout"); + .SetDefault("AnyLayout") + .AsExtra(); AddAttr( "use_quantizer", "(bool, default false) " @@ -262,7 +264,9 @@ class Transpose2OpMaker : public TransposeOpMaker { public: void Make() override { TransposeOpMaker::Make(); - AddOutput("XShape", "(Tensor)The output tensor.").AsIntermediate(); + AddOutput("XShape", "(Tensor)The output tensor.") + .AsIntermediate() + .AsExtra(); } }; diff --git a/paddle/fluid/operators/transpose_op.cu b/paddle/fluid/operators/transpose_op.cu index a462bbb4834ac..383fc6a5b9b32 100644 --- a/paddle/fluid/operators/transpose_op.cu +++ b/paddle/fluid/operators/transpose_op.cu @@ -12,650 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include -#include -#include - -#include "paddle/fluid/framework/gpu_utils.h" +#include "paddle/fluid/operators/transpose_op.cu.h" #include "paddle/fluid/operators/transpose_op.h" -#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/float16.h" -#include "paddle/fluid/platform/gpu_launch_config.h" namespace paddle { namespace operators { -using Tensor = framework::Tensor; -using Dim3 = framework::Dim3; -using Index3 = framework::Index3; - -struct EqualTo { - constexpr bool operator()(int a, int b) const { return a == b; } -}; - -struct GreaterThan { - constexpr bool operator()(int a, int b) const { return a > b; } -}; - -// Value can be decided in compile time. -template -constexpr bool CheckProperTileSize(int tile_long, int tile_short, int size_T, - FUN op) { - return (size_T == 16 && ((tile_long == INT_32 && op(tile_short, 4)) || - (tile_long == 2 * INT_32 && op(tile_short, 4)) || - (tile_long == 4 * INT_32 && op(tile_short, 4)) || - (tile_long == 8 * INT_32 && op(tile_short, 2)))) || - (size_T == 8 && ((tile_long == INT_32 && op(tile_short, 15)) || - (tile_long == 2 * INT_32 && op(tile_short, 15)) || - (tile_long == 4 * INT_32 && op(tile_short, 8)) || - (tile_long == 8 * INT_32 && op(tile_short, 4)) || - (tile_long == 16 * INT_32 && op(tile_short, 2)))) || - ((size_T == 4 || size_T == 2 || size_T == 1) && - ((tile_long == INT_32 && op(tile_short, 15)) || - (tile_long == 2 * INT_32 && op(tile_short, 15)) || - (tile_long == 4 * INT_32 && op(tile_short, 8)) || - (tile_long == 8 * INT_32 && op(tile_short, 4)) || - (tile_long == 16 * INT_32 && op(tile_short, 2)) || - (tile_long == 16 * INT_32 && op(tile_short, 2)))); -} - -constexpr bool CheckLongTileSize(int tile_long, int tile_short, int size_T) { - return CheckProperTileSize(tile_long, tile_short, size_T, EqualTo()); -} - -constexpr bool CheckOutsideTileSize(int tile_long, int tile_short, int size_T) { - return CheckProperTileSize(tile_long, tile_short, size_T, GreaterThan()); -} - -constexpr bool CheckNonLongTileSize(int tile_long, int tile_short, int size_T) { - return !CheckOutsideTileSize(tile_long, tile_short, size_T) && - (CheckOutsideTileSize(tile_long * 2, tile_short, size_T) || - CheckOutsideTileSize(tile_long, tile_short + 1, size_T)) && - !CheckLongTileSize(tile_long, tile_short, size_T); -} - -// Use SM to do data transfer, load a tile into SM then store out. -// All tile read and write are colascing, so can speedup memory copy -template -__global__ void TilingSwapDim1And2(const T* __restrict__ input, Dim3 input_dims, - T* __restrict__ output) { - assert(blockDim.x == NumThreads); - assert(blockDim.y == 1); - assert(blockDim.z == 1); - assert(gridDim.y == 1); - assert(gridDim.z == 1); - - constexpr int BlockReadRows = NumThreads / TileY; - constexpr int BlockWriteRows = NumThreads / TileX; - - // One extra line in the inner dimension to avoid share memory bank conflict. 
- __shared__ __align__( - alignof(T)) char share_mem_ptr[TileX * (TileY + 1) * sizeof(T)]; - typedef T(*ShareMemory)[TileY + 1]; - - ShareMemory tile_sm = reinterpret_cast(share_mem_ptr); - - int x = threadIdx.x; - - Dim3 output_dims = { - input_dims[0], input_dims[2], input_dims[1], - }; - - // Align dim to Tiles - Dim3 tile_aligned_input_dim = { - input_dims[0], (input_dims[1] + TileX - 1) / TileX, - (input_dims[2] + TileY - 1) / TileY, - }; - - // Converts block idx to tile index, each block process a tile - Index3 input_block_tile_index = - ConvertTensorIndex(blockIdx.x, tile_aligned_input_dim); - - // Compute real index align to tile:0, 32, 64... - Index3 block_tile_index_in_input = { - input_block_tile_index[0], input_block_tile_index[1] * TileX, - input_block_tile_index[2] * TileY, - }; - - // Compute block flat index against input dims. - int input_origin_block_flat_index = - FlatTensorIndex(block_tile_index_in_input, input_dims); - - bool full_tile = true; - int tile_width = TileY; - - // Last row is not full. - if (input_block_tile_index[2] == tile_aligned_input_dim[2] - 1) { - tile_width = input_dims[2] - (tile_aligned_input_dim[2] - 1) * TileY; - full_tile &= false; - } - - int tile_height = TileX; - - if (input_block_tile_index[1] == tile_aligned_input_dim[1] - 1) { - tile_height = input_dims[1] - (tile_aligned_input_dim[1] - 1) * TileX; - full_tile &= false; - } - - constexpr int in_effective_thread_num = NumThreads / TileY * TileY; - - if (x < in_effective_thread_num) { - // Read a tile from input using block. - int x_i = x / TileY; - int x_j = x % TileY; - int input_ind = input_origin_block_flat_index + x_i * input_dims[2] + x_j; - int input_inc = BlockReadRows * input_dims[2]; - - if (full_tile) { -#pragma unroll - for (int ind_i = x_i; ind_i < (TileX); ind_i += BlockReadRows) { - tile_sm[ind_i][x_j] = input[input_ind]; - input_ind += input_inc; - } - } else { - if (x_j < tile_width) { -#pragma unroll - for (int ind_i = x_i; ind_i < (tile_height); ind_i += BlockReadRows) { - tile_sm[ind_i][x_j] = input[input_ind]; - input_ind += input_inc; - } - } - } - } - - __syncthreads(); - - // Store sm value back to out - Index3 output_block_tile_index = { - input_block_tile_index[0], input_block_tile_index[2], - input_block_tile_index[1], - }; - - Index3 block_tile_index_in_output = { - output_block_tile_index[0], output_block_tile_index[1] * TileY, - output_block_tile_index[2] * TileX, - }; - - int output_origin_block_flat_index = - FlatTensorIndex(block_tile_index_in_output, output_dims); - - constexpr int out_effective_thread_num = NumThreads / TileX * TileX; - - if (x < out_effective_thread_num) { - int x_i = x / TileX; - int x_j = x % TileX; - int output_ind = - output_origin_block_flat_index + x_i * output_dims[2] + x_j; - int output_inc = BlockWriteRows * output_dims[2]; - - if (full_tile) { -#pragma unroll - for (int ind_i = x_i; ind_i < (TileY); ind_i += BlockWriteRows) { - output[output_ind] = tile_sm[x_j][ind_i]; - output_ind += output_inc; - } - } else { - if (x_j < tile_height) { -#pragma unroll - for (int ind_i = x_i; ind_i < (tile_width); ind_i += BlockWriteRows) { - output[output_ind] = tile_sm[x_j][ind_i]; - output_ind += output_inc; - } - } - } - } -} - -// This function will find combination of long_side X short_side in backups -template -bool SelectProperTileSize(std::vector>* tiles) { - PADDLE_ENFORCE_LE( - TSIZE, 16, - platform::errors::InvalidArgument( - "The tile size should smaller than 16, but received is:%d.", TSIZE)); - - PADDLE_ENFORCE_EQ( - (TSIZE & 
(TSIZE - 1)), 0, - platform::errors::InvalidArgument( - "Data types should be powers of 2, but reived size is:%d.", TSIZE)); - - const int kMaxLongSideLen = 1024; - const int kMaxShortSideLen = 15; - - for (int long_side = 32; long_side <= kMaxLongSideLen; long_side *= 2) { - for (int short_side = 2; short_side <= kMaxShortSideLen; short_side += 1) { - if (CheckLongTileSize(long_side, short_side, TSIZE)) { - tiles->push_back(std::make_pair(long_side, short_side)); - - if (short_side == 2) return true; - - break; - } - } - } - return false; -} - -// Use system built in type -template -struct SystemElemType; -template <> -struct SystemElemType<1> { - using type = uint8_t; -}; -template <> -struct SystemElemType<2> { - using type = uint16_t; -}; -template <> -struct SystemElemType<4> { - using type = uint32_t; -}; -template <> -struct SystemElemType<8> { - using type = uint64_t; -}; -template <> -struct SystemElemType<16> { - using type = float4; -}; - -template -void LaunchNarrowDims2TransposeKernel(const platform::CUDADeviceContext& d, - int tile_size_i, int tile_size_j, - int total_tiles_count, const T* input, - const Dim3& input_dims, T* output) { - constexpr int NumThreads = tile_long; - if (tile_size_i <= tile_long && tile_size_j <= tile_short) { - TilingSwapDim1And2< - T, NumThreads, tile_long, - tile_short><<>>( - input, input_dims, output); - } else { - TilingSwapDim1And2< - T, NumThreads, tile_short, - tile_long><<>>( - input, input_dims, output); - } -} - -template -struct NarrowDims2TransposeDispatch { - static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i, - int tile_size_j, int total_tiles_count, - const T* input, const Dim3& input_dims, T* output) { - PADDLE_ENFORCE_EQ( - (tile_long & (tile_long - 1)), 0, - platform::errors::InvalidArgument( - "The length of the longer side of the tile should be power of 2." - " But received value is:%d.", - tile_long)); - - bool request_satisfied = std::max(tile_size_i, tile_size_j) <= tile_long && - std::min(tile_size_i, tile_size_j) <= tile_short; - - if (request_satisfied) { - LaunchNarrowDims2TransposeKernel( - d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, - output); - return; - } - - const bool long_side_request_not_satisfied = - std::max(tile_size_i, tile_size_j) > tile_long; - - if (long_side_request_not_satisfied) { - NarrowDims2TransposeDispatch::DoTranspose( - d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, - output); - } else { - NarrowDims2TransposeDispatch::DoTranspose( - d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, - output); - } - } -}; - -// If Not long tile size, goto this function when compile. -template -struct NarrowDims2TransposeDispatch< - T, tile_long, tile_short, - typename std::enable_if< - CheckNonLongTileSize(tile_long, tile_short, sizeof(T)), void>::type> { - static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i, - int tile_size_j, int total_tiles_count, - const T* input, const Dim3& input_dims, T* output) { - PADDLE_ENFORCE_EQ( - (tile_long & (tile_long - 1)), 0, - platform::errors::InvalidArgument( - "The length of the longer side of the tile should be power of 2." 
- " But received value is:%d.", - tile_long)); - - bool request_satisfied = std::max(tile_size_i, tile_size_j) <= tile_long && - std::min(tile_size_i, tile_size_j) <= tile_short; - - if (request_satisfied) { - LaunchNarrowDims2TransposeKernel( - d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, - output); - return; - } - - NarrowDims2TransposeDispatch::DoTranspose( - d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, - output); - } -}; - -// If long tile size, goto this function when compile. -template -struct NarrowDims2TransposeDispatch< - T, tile_long, tile_short, - typename std::enable_if::type> { - static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i, - int tile_size_j, int total_tiles_count, - const T* input, const Dim3& input_dims, T* output) { - PADDLE_ENFORCE_EQ( - (tile_long & (tile_long - 1)), 0, - platform::errors::InvalidArgument( - "The length of the longer side of the tile should be power of 2," - " but received is:%d.", - tile_long)); - - LaunchNarrowDims2TransposeKernel( - d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, - output); - } -}; - -template -void SwapDim1And2InNarrow(const platform::CUDADeviceContext& d, const T* input, - const Dim3& input_dims, T* output, - const int kMinTileSize) { - // First get available tile sizes for the data type requested as backups - std::vector> tile_sele; - auto ret = SelectProperTileSize(&tile_sele); - PADDLE_ENFORCE_EQ( - ret, true, - platform::errors::InvalidArgument( - "SelectProperTileSize should return true, but return value is:%d.", - ret)); - - int tile_long_edge = 0; - int tile_short_edge = 0; - float lowest_cost = std::numeric_limits::max(); - int input_long_edge = std::max(input_dims[1], input_dims[2]); - - // Find the tile size that best suit in inputs. - for (auto tile_size_pair : tile_sele) { - int proposed_tile_long_edge = tile_size_pair.first; - // data may not aligned to tile, so some threads wasted, we need - // to find least wasted threads, which means we need to find tile - // can split input properly, in another words: num_wasted_threads=0. - int num_wasted_threads = input_long_edge - - framework::CeilOrFloor( - input_long_edge, proposed_tile_long_edge) * - proposed_tile_long_edge; - - int num_full_tiles = framework::CeilOrFloor( - input_long_edge, proposed_tile_long_edge); - - float cost = num_wasted_threads; - - if (cost <= lowest_cost) { - tile_long_edge = proposed_tile_long_edge; - tile_short_edge = tile_size_pair.second; - lowest_cost = cost; - } - // break as we already find best tile size. - if (cost == 0) break; - } - - // The tile size we select should be match with input dim, long side to long - // short side to short. - // First set long side as i if dim1 > Tile min size, then set dim2 as j. - int select_tile_size_i = - input_dims[1] >= kMinTileSize ? tile_long_edge : input_dims[1]; - int select_tile_size_j = - input_dims[1] >= kMinTileSize ? input_dims[2] : tile_long_edge; - - // Check if i is long edge, if not set i as short. - select_tile_size_i = select_tile_size_i == tile_long_edge - ? tile_long_edge - : std::min(select_tile_size_i, tile_short_edge); - - // Check if j is long edge, if not set j as short. - select_tile_size_j = select_tile_size_j == tile_long_edge - ? tile_long_edge - : std::min(select_tile_size_j, tile_short_edge); - - // Here finally get proper long X short tile size. 
- Dim3 input_dims_aligned = { - input_dims[0], - framework::CeilOrFloor(input_dims[1], select_tile_size_i), - framework::CeilOrFloor(input_dims[2], select_tile_size_j), - }; - - int total_tiles_count = - input_dims_aligned[0] * input_dims_aligned[1] * input_dims_aligned[2]; - - // Suppose T can be replaced by system builtin types - using ElemType = typename SystemElemType::type; - - NarrowDims2TransposeDispatch::DoTranspose( - d, select_tile_size_i, select_tile_size_j, total_tiles_count, - reinterpret_cast(input), input_dims, - reinterpret_cast(output)); -} - -// This is for case that cannot do coalescing read and write. -// Or input is too small to split into tiles. -template -__global__ void TransposeSimpleKernel(int nthreads, const T* __restrict__ input, - Dim3 input_dims, T* __restrict__ output) { - Dim3 output_dims; - output_dims[pos0] = input_dims[0]; - output_dims[pos1] = input_dims[1]; - output_dims[pos2] = input_dims[2]; - - CUDA_KERNEL_LOOP(output_index, nthreads) { - Index3 output_tensor_index = ConvertTensorIndex(output_index, output_dims); - - Index3 input_tensor_index; - input_tensor_index[0] = output_tensor_index[pos0]; - input_tensor_index[1] = output_tensor_index[pos1]; - input_tensor_index[2] = output_tensor_index[pos2]; - - int input_index = FlatTensorIndex(input_tensor_index, input_dims); - - output[output_index] = input[input_index]; - } -} - -// Here suppose convert all tensor to dim3, so just change dim1 and 2. -template -void SendSwapDim1And2InTranspose(const platform::CUDADeviceContext& d, - const T* input, const Dim3& input_dims, - T* output) { - // Suppose tile size > 16 - static const int kMinTileSize = 16; - static const int kMinNarrowTileSize = 96; - - bool large_tile = - input_dims[1] >= kMinTileSize && input_dims[2] >= kMinTileSize; - bool narrow_tile = input_dims[1] >= kMinNarrowTileSize || - input_dims[2] >= kMinNarrowTileSize; - if (large_tile) { - // If input is large square, such as 32X32, use SM to do copy. - // suppose 32 X 32 gives best performance, and 8 warp in block. - constexpr int kTileSize = 32; - constexpr int kNumThreads = 256; - - Dim3 input_dims_aligned = { - input_dims[0], - framework::CeilOrFloor(input_dims[1], kTileSize), - framework::CeilOrFloor(input_dims[2], kTileSize), - }; - - int total_tiles_count = - input_dims_aligned[0] * input_dims_aligned[1] * input_dims_aligned[2]; - - TilingSwapDim1And2< - T, kNumThreads, kTileSize, - kTileSize><<>>( - input, input_dims, output); - - } else if (narrow_tile) { - // If input shape is like Rect, such as 2X100, use Narrow tile size. - // It makes things complicated, because need to find a tile can coverr - // input and also reach best coalescing. 
- SwapDim1And2InNarrow(d, input, input_dims, output, kMinTileSize); - } else { - // If input shape is small, such as 8X8, just do simple copy - int total_elements = input_dims[0] * input_dims[1] * input_dims[2]; - auto config = GetGpuLaunchConfig1D(d, total_elements); - TransposeSimpleKernel<<< - config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>( - total_elements, input, input_dims, output); - } -} - -template -struct SwapDim1And2InTranspose { - typedef platform::CUDADeviceContext Device; - void operator()(const Device& d, const T* in, - const std::vector& combined_dims, T* out) { - Dim3 input_dims = {static_cast(combined_dims[0]), - static_cast(combined_dims[1]), - static_cast(combined_dims[2])}; - SendSwapDim1And2InTranspose(d, in, input_dims, out); - } -}; - -template -struct SwapDim0And2InTranspose { - typedef platform::CUDADeviceContext Device; - void operator()(const Device& d, const T* in, - const std::vector& combined_dims, T* out) { - Dim3 input_dims = {static_cast(combined_dims[0]), - static_cast(combined_dims[1]), - static_cast(combined_dims[2])}; - - size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2]; - auto config = GetGpuLaunchConfig1D(d, total_size); - - TransposeSimpleKernel<<< - config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>( - total_size, in, input_dims, out); - } -}; - -// This function is to combine dimension. fox example: -// (0, 1, 3, 2) --> (0, 2, 1) -inline void CombineTransposeDim3(const framework::DDim& shape, - const std::vector& perm, - std::vector* new_perm, - framework::DDim* new_dims) { - PADDLE_ENFORCE_EQ(shape.size(), perm.size(), - platform::errors::InvalidArgument( - " shape should have the save dim with perm, but" - " received shape size is:%d, perm size is:%d.", - shape.size(), perm.size())); - - std::vector dim_vec; - if (shape.size() == 1) { - // If input dimension is already 1, no need to combine dim. - new_perm->resize(1); - (*new_perm)[0] = perm[0]; - dim_vec.push_back(shape[0]); - *new_dims = framework::make_ddim(dim_vec); - return; - } - std::vector new_dim_pos(shape.size(), -1); - std::vector combined_dims(shape.size(), 0); - int cur_head = perm[0]; - new_dim_pos[cur_head] = 0; - combined_dims[0] = shape[cur_head]; - int dim_idx = 0; - for (int perm_idx = 1; perm_idx < shape.size(); ++perm_idx) { - // combine consecutive dimensions. - if (cur_head + 1 == perm[perm_idx]) { - cur_head = perm[perm_idx]; - combined_dims[dim_idx] *= shape[cur_head]; - } else { - // Else start a new dimension. - cur_head = perm[perm_idx]; - dim_idx++; - new_dim_pos[cur_head] = dim_idx; - combined_dims[dim_idx] = shape[cur_head]; - } - } - - new_perm->resize(dim_idx + 1); - - dim_idx = 0; - for (int i = 0; i < new_dim_pos.size(); ++i) { - if (new_dim_pos[i] >= 0) { - int new_perm_idx = new_dim_pos[i]; - (*new_perm)[dim_idx] = new_perm_idx; - dim_vec.push_back(combined_dims[new_perm_idx]); - dim_idx++; - } - } - - *new_dims = framework::make_ddim(dim_vec); -} - -template -struct TransposeSimple { - static bool run(const platform::CUDADeviceContext& ctx, const Tensor& in, - const std::vector perm, Tensor* out) { - // First reduce the dimensions of the input tensor if possible. - std::vector new_perm; - framework::DDim new_dims; - CombineTransposeDim3(in.dims(), perm, &new_perm, &new_dims); - - // Only use tile copy GPU kernel when dimension is 2 or 3. 
- int dims = new_dims.size(); - std::vector new_dim_vec = framework::vectorize(new_dims); - if (dims < 2 || dims > 3) return false; - auto in_data = in.data(); - auto out_data = out->data(); - // In most cases, dim will not greater than 3 after combine. - switch (dims) { - case 2: - if (new_perm[0] == 1 && new_perm[1] == 0) { - // Add the first dimension size as 1. - new_dim_vec.insert(new_dim_vec.begin(), 1); - SwapDim1And2InTranspose()(ctx, in_data, new_dim_vec, out_data); - return true; - } - break; - case 3: - // In this case, suppose we can do coalescing read and write in tile. - if (new_perm == std::vector({0, 2, 1})) { - SwapDim1And2InTranspose()(ctx, in_data, new_dim_vec, out_data); - return true; - } else if (new_perm == std::vector({2, 1, 0})) { - // Maybe can optimize later, find a way to do coalescing memory copy. - // But I think it depends on the data size. If span is not large, - // maybe - // can do coalescing. - SwapDim0And2InTranspose()(ctx, in_data, new_dim_vec, out_data); - return true; - } else { - return false; - } - break; - default: - return false; - } - return false; - } -}; - template class TransposeGPUKernel : public framework::OpKernel { public: @@ -676,11 +39,7 @@ class TransposeGPUKernel : public framework::OpKernel { std::vector axis = context.Attr>("axis"); int ndims = axis.size(); const auto& dev_ctx = context.template device_context(); - auto ret = TransposeSimple::run(dev_ctx, *x_tensor, axis, out_tensor); - if (!ret) { - TransCompute(ndims, dev_ctx, *x_tensor, out_tensor, - axis); - } + TransposeGPUKernelDriver(dev_ctx, ndims, *x_tensor, axis, out_tensor); } }; template @@ -711,12 +70,8 @@ class TransposeGradGPUKernel : public framework::OpKernel { int ndims = axis.size(); const auto& dev_ctx = context.template device_context(); - auto ret = TransposeSimple::run(dev_ctx, *out_grad_tensor, reversed_axis, - x_grad_tensor); - if (!ret) { - TransCompute(ndims, dev_ctx, *out_grad_tensor, - x_grad_tensor, reversed_axis); - } + TransposeGPUKernelDriver(dev_ctx, ndims, *out_grad_tensor, reversed_axis, + x_grad_tensor); } }; diff --git a/paddle/fluid/operators/transpose_op.cu.h b/paddle/fluid/operators/transpose_op.cu.h new file mode 100644 index 0000000000000..784d97b543fbd --- /dev/null +++ b/paddle/fluid/operators/transpose_op.cu.h @@ -0,0 +1,667 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/gpu_utils.h" +#include "paddle/fluid/operators/transpose_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_launch_config.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using Dim3 = framework::Dim3; +using Index3 = framework::Index3; + +struct EqualTo { + constexpr bool operator()(int a, int b) const { return a == b; } +}; + +struct GreaterThan { + constexpr bool operator()(int a, int b) const { return a > b; } +}; + +// Value can be decided in compile time. 
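+// These constexpr predicates are meant to be evaluated at compile time: SelectProperTileSize
+// uses them to enumerate the valid (long, short) tile pairs for a given element size, and the
+// enable_if specializations of NarrowDims2TransposeDispatch below use them to terminate the
+// template recursion once a legal tile shape has been reached.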
+template +constexpr bool CheckProperTileSize(int tile_long, int tile_short, int size_T, + FUN op) { + return (size_T == 16 && ((tile_long == INT_32 && op(tile_short, 4)) || + (tile_long == 2 * INT_32 && op(tile_short, 4)) || + (tile_long == 4 * INT_32 && op(tile_short, 4)) || + (tile_long == 8 * INT_32 && op(tile_short, 2)))) || + (size_T == 8 && ((tile_long == INT_32 && op(tile_short, 15)) || + (tile_long == 2 * INT_32 && op(tile_short, 15)) || + (tile_long == 4 * INT_32 && op(tile_short, 8)) || + (tile_long == 8 * INT_32 && op(tile_short, 4)) || + (tile_long == 16 * INT_32 && op(tile_short, 2)))) || + ((size_T == 4 || size_T == 2 || size_T == 1) && + ((tile_long == INT_32 && op(tile_short, 15)) || + (tile_long == 2 * INT_32 && op(tile_short, 15)) || + (tile_long == 4 * INT_32 && op(tile_short, 8)) || + (tile_long == 8 * INT_32 && op(tile_short, 4)) || + (tile_long == 16 * INT_32 && op(tile_short, 2)) || + (tile_long == 16 * INT_32 && op(tile_short, 2)))); +} + +constexpr bool CheckLongTileSize(int tile_long, int tile_short, int size_T) { + return CheckProperTileSize(tile_long, tile_short, size_T, EqualTo()); +} + +constexpr bool CheckOutsideTileSize(int tile_long, int tile_short, int size_T) { + return CheckProperTileSize(tile_long, tile_short, size_T, GreaterThan()); +} + +constexpr bool CheckNonLongTileSize(int tile_long, int tile_short, int size_T) { + return !CheckOutsideTileSize(tile_long, tile_short, size_T) && + (CheckOutsideTileSize(tile_long * 2, tile_short, size_T) || + CheckOutsideTileSize(tile_long, tile_short + 1, size_T)) && + !CheckLongTileSize(tile_long, tile_short, size_T); +} + +// Use SM to do data transfer, load a tile into SM then store out. +// All tile read and write are colascing, so can speedup memory copy +template +__global__ void TilingSwapDim1And2(const T* __restrict__ input, Dim3 input_dims, + T* __restrict__ output) { + assert(blockDim.x == NumThreads); + assert(blockDim.y == 1); + assert(blockDim.z == 1); + assert(gridDim.y == 1); + assert(gridDim.z == 1); + + constexpr int BlockReadRows = NumThreads / TileY; + constexpr int BlockWriteRows = NumThreads / TileX; + + // One extra line in the inner dimension to avoid share memory bank conflict. + __shared__ __align__( + alignof(T)) char share_mem_ptr[TileX * (TileY + 1) * sizeof(T)]; + typedef T(*ShareMemory)[TileY + 1]; + + ShareMemory tile_sm = reinterpret_cast(share_mem_ptr); + + int x = threadIdx.x; + + Dim3 output_dims = { + input_dims[0], input_dims[2], input_dims[1], + }; + + // Align dim to Tiles + Dim3 tile_aligned_input_dim = { + input_dims[0], (input_dims[1] + TileX - 1) / TileX, + (input_dims[2] + TileY - 1) / TileY, + }; + + // Converts block idx to tile index, each block process a tile + Index3 input_block_tile_index = + ConvertTensorIndex(blockIdx.x, tile_aligned_input_dim); + + // Compute real index align to tile:0, 32, 64... + Index3 block_tile_index_in_input = { + input_block_tile_index[0], input_block_tile_index[1] * TileX, + input_block_tile_index[2] * TileY, + }; + + // Compute block flat index against input dims. + int input_origin_block_flat_index = + FlatTensorIndex(block_tile_index_in_input, input_dims); + + bool full_tile = true; + int tile_width = TileY; + + // Last row is not full. 
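+  // A tile on the right or bottom edge of the input may be only partially filled; tile_width
+  // and tile_height are shrunk below so that threads never read or write past the real data.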
+ if (input_block_tile_index[2] == tile_aligned_input_dim[2] - 1) { + tile_width = input_dims[2] - (tile_aligned_input_dim[2] - 1) * TileY; + full_tile &= false; + } + + int tile_height = TileX; + + if (input_block_tile_index[1] == tile_aligned_input_dim[1] - 1) { + tile_height = input_dims[1] - (tile_aligned_input_dim[1] - 1) * TileX; + full_tile &= false; + } + + constexpr int in_effective_thread_num = NumThreads / TileY * TileY; + + if (x < in_effective_thread_num) { + // Read a tile from input using block. + int x_i = x / TileY; + int x_j = x % TileY; + int input_ind = input_origin_block_flat_index + x_i * input_dims[2] + x_j; + int input_inc = BlockReadRows * input_dims[2]; + + if (full_tile) { +#pragma unroll + for (int ind_i = x_i; ind_i < (TileX); ind_i += BlockReadRows) { + tile_sm[ind_i][x_j] = input[input_ind]; + input_ind += input_inc; + } + } else { + if (x_j < tile_width) { +#pragma unroll + for (int ind_i = x_i; ind_i < (tile_height); ind_i += BlockReadRows) { + tile_sm[ind_i][x_j] = input[input_ind]; + input_ind += input_inc; + } + } + } + } + + __syncthreads(); + + // Store sm value back to out + Index3 output_block_tile_index = { + input_block_tile_index[0], input_block_tile_index[2], + input_block_tile_index[1], + }; + + Index3 block_tile_index_in_output = { + output_block_tile_index[0], output_block_tile_index[1] * TileY, + output_block_tile_index[2] * TileX, + }; + + int output_origin_block_flat_index = + FlatTensorIndex(block_tile_index_in_output, output_dims); + + constexpr int out_effective_thread_num = NumThreads / TileX * TileX; + + if (x < out_effective_thread_num) { + int x_i = x / TileX; + int x_j = x % TileX; + int output_ind = + output_origin_block_flat_index + x_i * output_dims[2] + x_j; + int output_inc = BlockWriteRows * output_dims[2]; + + if (full_tile) { +#pragma unroll + for (int ind_i = x_i; ind_i < (TileY); ind_i += BlockWriteRows) { + output[output_ind] = tile_sm[x_j][ind_i]; + output_ind += output_inc; + } + } else { + if (x_j < tile_height) { +#pragma unroll + for (int ind_i = x_i; ind_i < (tile_width); ind_i += BlockWriteRows) { + output[output_ind] = tile_sm[x_j][ind_i]; + output_ind += output_inc; + } + } + } + } +} + +// This function will find combination of long_side X short_side in backups +template +bool SelectProperTileSize(std::vector>* tiles) { + PADDLE_ENFORCE_LE( + TSIZE, 16, + platform::errors::InvalidArgument( + "The tile size should smaller than 16, but received is:%d.", TSIZE)); + + PADDLE_ENFORCE_EQ( + (TSIZE & (TSIZE - 1)), 0, + platform::errors::InvalidArgument( + "Data types should be powers of 2, but reived size is:%d.", TSIZE)); + + const int kMaxLongSideLen = 1024; + const int kMaxShortSideLen = 15; + + for (int long_side = 32; long_side <= kMaxLongSideLen; long_side *= 2) { + for (int short_side = 2; short_side <= kMaxShortSideLen; short_side += 1) { + if (CheckLongTileSize(long_side, short_side, TSIZE)) { + tiles->push_back(std::make_pair(long_side, short_side)); + + if (short_side == 2) return true; + + break; + } + } + } + return false; +} + +// Use system built in type +template +struct SystemElemType; +template <> +struct SystemElemType<1> { + using type = uint8_t; +}; +template <> +struct SystemElemType<2> { + using type = uint16_t; +}; +template <> +struct SystemElemType<4> { + using type = uint32_t; +}; +template <> +struct SystemElemType<8> { + using type = uint64_t; +}; +template <> +struct SystemElemType<16> { + using type = float4; +}; + +template +void LaunchNarrowDims2TransposeKernel(const 
platform::CUDADeviceContext& d, + int tile_size_i, int tile_size_j, + int total_tiles_count, const T* input, + const Dim3& input_dims, T* output) { + constexpr int NumThreads = tile_long; + if (tile_size_i <= tile_long && tile_size_j <= tile_short) { + TilingSwapDim1And2< + T, NumThreads, tile_long, + tile_short><<>>( + input, input_dims, output); + } else { + TilingSwapDim1And2< + T, NumThreads, tile_short, + tile_long><<>>( + input, input_dims, output); + } +} + +template +struct NarrowDims2TransposeDispatch { + static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i, + int tile_size_j, int total_tiles_count, + const T* input, const Dim3& input_dims, T* output) { + PADDLE_ENFORCE_EQ( + (tile_long & (tile_long - 1)), 0, + platform::errors::InvalidArgument( + "The length of the longer side of the tile should be power of 2." + " But received value is:%d.", + tile_long)); + + bool request_satisfied = std::max(tile_size_i, tile_size_j) <= tile_long && + std::min(tile_size_i, tile_size_j) <= tile_short; + + if (request_satisfied) { + LaunchNarrowDims2TransposeKernel( + d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + output); + return; + } + + const bool long_side_request_not_satisfied = + std::max(tile_size_i, tile_size_j) > tile_long; + + if (long_side_request_not_satisfied) { + NarrowDims2TransposeDispatch::DoTranspose( + d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + output); + } else { + NarrowDims2TransposeDispatch::DoTranspose( + d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + output); + } + } +}; + +// If Not long tile size, goto this function when compile. +template +struct NarrowDims2TransposeDispatch< + T, tile_long, tile_short, + typename std::enable_if< + CheckNonLongTileSize(tile_long, tile_short, sizeof(T)), void>::type> { + static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i, + int tile_size_j, int total_tiles_count, + const T* input, const Dim3& input_dims, T* output) { + PADDLE_ENFORCE_EQ( + (tile_long & (tile_long - 1)), 0, + platform::errors::InvalidArgument( + "The length of the longer side of the tile should be power of 2." + " But received value is:%d.", + tile_long)); + + bool request_satisfied = std::max(tile_size_i, tile_size_j) <= tile_long && + std::min(tile_size_i, tile_size_j) <= tile_short; + + if (request_satisfied) { + LaunchNarrowDims2TransposeKernel( + d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + output); + return; + } + + NarrowDims2TransposeDispatch::DoTranspose( + d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + output); + } +}; + +// If long tile size, goto this function when compile. 
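+// (enable_if selects this specialization when tile_long x tile_short is already a valid
+// long-side tile for sizeof(T); the recursive growth of the tile in the generic dispatch
+// above stops here and the kernel is launched directly.)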
+template +struct NarrowDims2TransposeDispatch< + T, tile_long, tile_short, + typename std::enable_if::type> { + static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i, + int tile_size_j, int total_tiles_count, + const T* input, const Dim3& input_dims, T* output) { + PADDLE_ENFORCE_EQ( + (tile_long & (tile_long - 1)), 0, + platform::errors::InvalidArgument( + "The length of the longer side of the tile should be power of 2," + " but received is:%d.", + tile_long)); + + LaunchNarrowDims2TransposeKernel( + d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + output); + } +}; + +template +void SwapDim1And2InNarrow(const platform::CUDADeviceContext& d, const T* input, + const Dim3& input_dims, T* output, + const int kMinTileSize) { + // First get available tile sizes for the data type requested as backups + std::vector> tile_sele; + auto ret = SelectProperTileSize(&tile_sele); + PADDLE_ENFORCE_EQ( + ret, true, + platform::errors::InvalidArgument( + "SelectProperTileSize should return true, but return value is:%d.", + ret)); + + int tile_long_edge = 0; + int tile_short_edge = 0; + float lowest_cost = std::numeric_limits::max(); + int input_long_edge = std::max(input_dims[1], input_dims[2]); + + // Find the tile size that best suit in inputs. + for (auto tile_size_pair : tile_sele) { + int proposed_tile_long_edge = tile_size_pair.first; + // data may not aligned to tile, so some threads wasted, we need + // to find least wasted threads, which means we need to find tile + // can split input properly, in another words: num_wasted_threads=0. + int num_wasted_threads = input_long_edge - + framework::CeilOrFloor( + input_long_edge, proposed_tile_long_edge) * + proposed_tile_long_edge; + + int num_full_tiles = framework::CeilOrFloor( + input_long_edge, proposed_tile_long_edge); + + float cost = num_wasted_threads; + + if (cost <= lowest_cost) { + tile_long_edge = proposed_tile_long_edge; + tile_short_edge = tile_size_pair.second; + lowest_cost = cost; + } + // break as we already find best tile size. + if (cost == 0) break; + } + + // The tile size we select should be match with input dim, long side to long + // short side to short. + // First set long side as i if dim1 > Tile min size, then set dim2 as j. + int select_tile_size_i = + input_dims[1] >= kMinTileSize ? tile_long_edge : input_dims[1]; + int select_tile_size_j = + input_dims[1] >= kMinTileSize ? input_dims[2] : tile_long_edge; + + // Check if i is long edge, if not set i as short. + select_tile_size_i = select_tile_size_i == tile_long_edge + ? tile_long_edge + : std::min(select_tile_size_i, tile_short_edge); + + // Check if j is long edge, if not set j as short. + select_tile_size_j = select_tile_size_j == tile_long_edge + ? tile_long_edge + : std::min(select_tile_size_j, tile_short_edge); + + // Here finally get proper long X short tile size. + Dim3 input_dims_aligned = { + input_dims[0], + framework::CeilOrFloor(input_dims[1], select_tile_size_i), + framework::CeilOrFloor(input_dims[2], select_tile_size_j), + }; + + int total_tiles_count = + input_dims_aligned[0] * input_dims_aligned[1] * input_dims_aligned[2]; + + // Suppose T can be replaced by system builtin types + using ElemType = typename SystemElemType::type; + + NarrowDims2TransposeDispatch::DoTranspose( + d, select_tile_size_i, select_tile_size_j, total_tiles_count, + reinterpret_cast(input), input_dims, + reinterpret_cast(output)); +} + +// This is for case that cannot do coalescing read and write. 
+// Or input is too small to split into tiles. +template +__global__ void TransposeSimpleKernel(int nthreads, const T* __restrict__ input, + Dim3 input_dims, T* __restrict__ output) { + Dim3 output_dims; + output_dims[pos0] = input_dims[0]; + output_dims[pos1] = input_dims[1]; + output_dims[pos2] = input_dims[2]; + + CUDA_KERNEL_LOOP(output_index, nthreads) { + Index3 output_tensor_index = ConvertTensorIndex(output_index, output_dims); + + Index3 input_tensor_index; + input_tensor_index[0] = output_tensor_index[pos0]; + input_tensor_index[1] = output_tensor_index[pos1]; + input_tensor_index[2] = output_tensor_index[pos2]; + + int input_index = FlatTensorIndex(input_tensor_index, input_dims); + + output[output_index] = input[input_index]; + } +} + +// Here suppose convert all tensor to dim3, so just change dim1 and 2. +template +void SendSwapDim1And2InTranspose(const platform::CUDADeviceContext& d, + const T* input, const Dim3& input_dims, + T* output) { + // Suppose tile size > 16 + static const int kMinTileSize = 16; + static const int kMinNarrowTileSize = 96; + + bool large_tile = + input_dims[1] >= kMinTileSize && input_dims[2] >= kMinTileSize; + bool narrow_tile = input_dims[1] >= kMinNarrowTileSize || + input_dims[2] >= kMinNarrowTileSize; + if (large_tile) { + // If input is large square, such as 32X32, use SM to do copy. + // suppose 32 X 32 gives best performance, and 8 warp in block. + constexpr int kTileSize = 32; + constexpr int kNumThreads = 256; + + Dim3 input_dims_aligned = { + input_dims[0], + framework::CeilOrFloor(input_dims[1], kTileSize), + framework::CeilOrFloor(input_dims[2], kTileSize), + }; + + int total_tiles_count = + input_dims_aligned[0] * input_dims_aligned[1] * input_dims_aligned[2]; + + TilingSwapDim1And2< + T, kNumThreads, kTileSize, + kTileSize><<>>( + input, input_dims, output); + + } else if (narrow_tile) { + // If input shape is like Rect, such as 2X100, use Narrow tile size. + // It makes things complicated, because need to find a tile can coverr + // input and also reach best coalescing. + SwapDim1And2InNarrow(d, input, input_dims, output, kMinTileSize); + } else { + // If input shape is small, such as 8X8, just do simple copy + int total_elements = input_dims[0] * input_dims[1] * input_dims[2]; + auto config = GetGpuLaunchConfig1D(d, total_elements); + TransposeSimpleKernel<<< + config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>( + total_elements, input, input_dims, output); + } +} + +template +struct SwapDim1And2InTranspose { + typedef platform::CUDADeviceContext Device; + void operator()(const Device& d, const T* in, + const std::vector& combined_dims, T* out) { + Dim3 input_dims = {static_cast(combined_dims[0]), + static_cast(combined_dims[1]), + static_cast(combined_dims[2])}; + SendSwapDim1And2InTranspose(d, in, input_dims, out); + } +}; + +template +struct SwapDim0And2InTranspose { + typedef platform::CUDADeviceContext Device; + void operator()(const Device& d, const T* in, + const std::vector& combined_dims, T* out) { + Dim3 input_dims = {static_cast(combined_dims[0]), + static_cast(combined_dims[1]), + static_cast(combined_dims[2])}; + + size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2]; + auto config = GetGpuLaunchConfig1D(d, total_size); + + TransposeSimpleKernel<<< + config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>( + total_size, in, input_dims, out); + } +}; + +// This function is to combine dimension. 
fox example: +// (0, 1, 3, 2) --> (0, 2, 1) +inline void CombineTransposeDim3(const framework::DDim& shape, + const std::vector& perm, + std::vector* new_perm, + framework::DDim* new_dims) { + PADDLE_ENFORCE_EQ(shape.size(), perm.size(), + platform::errors::InvalidArgument( + " shape should have the save dim with perm, but" + " received shape size is:%d, perm size is:%d.", + shape.size(), perm.size())); + + std::vector dim_vec; + if (shape.size() == 1) { + // If input dimension is already 1, no need to combine dim. + new_perm->resize(1); + (*new_perm)[0] = perm[0]; + dim_vec.push_back(shape[0]); + *new_dims = framework::make_ddim(dim_vec); + return; + } + std::vector new_dim_pos(shape.size(), -1); + std::vector combined_dims(shape.size(), 0); + int cur_head = perm[0]; + new_dim_pos[cur_head] = 0; + combined_dims[0] = shape[cur_head]; + int dim_idx = 0; + for (int perm_idx = 1; perm_idx < shape.size(); ++perm_idx) { + // combine consecutive dimensions. + if (cur_head + 1 == perm[perm_idx]) { + cur_head = perm[perm_idx]; + combined_dims[dim_idx] *= shape[cur_head]; + } else { + // Else start a new dimension. + cur_head = perm[perm_idx]; + dim_idx++; + new_dim_pos[cur_head] = dim_idx; + combined_dims[dim_idx] = shape[cur_head]; + } + } + + new_perm->resize(dim_idx + 1); + + dim_idx = 0; + for (int i = 0; i < new_dim_pos.size(); ++i) { + if (new_dim_pos[i] >= 0) { + int new_perm_idx = new_dim_pos[i]; + (*new_perm)[dim_idx] = new_perm_idx; + dim_vec.push_back(combined_dims[new_perm_idx]); + dim_idx++; + } + } + + *new_dims = framework::make_ddim(dim_vec); +} + +template +struct TransposeSimple { + static bool run(const platform::CUDADeviceContext& ctx, const Tensor& in, + const std::vector perm, Tensor* out) { + // First reduce the dimensions of the input tensor if possible. + std::vector new_perm; + framework::DDim new_dims; + CombineTransposeDim3(in.dims(), perm, &new_perm, &new_dims); + + // Only use tile copy GPU kernel when dimension is 2 or 3. + int dims = new_dims.size(); + std::vector new_dim_vec = framework::vectorize(new_dims); + if (dims < 2 || dims > 3) return false; + auto in_data = in.data(); + auto out_data = out->data(); + // In most cases, dim will not greater than 3 after combine. + switch (dims) { + case 2: + if (new_perm[0] == 1 && new_perm[1] == 0) { + // Add the first dimension size as 1. + new_dim_vec.insert(new_dim_vec.begin(), 1); + SwapDim1And2InTranspose()(ctx, in_data, new_dim_vec, out_data); + return true; + } + break; + case 3: + // In this case, suppose we can do coalescing read and write in tile. + if (new_perm == std::vector({0, 2, 1})) { + SwapDim1And2InTranspose()(ctx, in_data, new_dim_vec, out_data); + return true; + } else if (new_perm == std::vector({2, 1, 0})) { + // Maybe can optimize later, find a way to do coalescing memory copy. + // But I think it depends on the data size. If span is not large, + // maybe + // can do coalescing. 
+ SwapDim0And2InTranspose()(ctx, in_data, new_dim_vec, out_data); + return true; + } else { + return false; + } + break; + default: + return false; + } + return false; + } +}; + +template +void TransposeGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx, + const int ndims, const Tensor& in, + const std::vector perm, Tensor* out) { + auto ret = TransposeSimple::run(dev_ctx, in, perm, out); + if (!ret) { + TransCompute(ndims, dev_ctx, in, out, perm); + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/aligned_vector.h b/paddle/fluid/platform/aligned_vector.h new file mode 100644 index 0000000000000..7d014f6bdcb0b --- /dev/null +++ b/paddle/fluid/platform/aligned_vector.h @@ -0,0 +1,77 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.1 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.1 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace platform { + +// Aligned vector generates vectorized load/store on CUDA. +template +struct alignas(sizeof(T) * Size) AlignedVector { + T val[Size]; + + HOSTDEVICE inline const T& operator[](int i) const { return val[i]; } + HOSTDEVICE inline T& operator[](int i) { return val[i]; } +}; + +template +HOSTDEVICE inline void Load(const T* addr, AlignedVector* vec) { + const AlignedVector* addr_vec = + reinterpret_cast*>(addr); + *vec = *addr_vec; +} + +template +HOSTDEVICE inline void Store(const AlignedVector& vec, T* addr) { + AlignedVector* addr_vec = + reinterpret_cast*>(addr); + *addr_vec = vec; +} + +/* +* Only the address of input data is the multiplier of 1,2,4, vectorized load +* with corresponding multiplier-value is possible. Moreover, the maximum length +* of vectorized load is 128 bits once. Hence, valid length of vectorized load +* shall be determined under both former constraints. +*/ +template +int GetVectorizedSize(const T* pointer) { + constexpr int max_load_bits = 128; + int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T); + uint64_t address = reinterpret_cast(pointer); + constexpr int vec8 = std::alignment_of>::value; // NOLINT + constexpr int vec4 = std::alignment_of>::value; // NOLINT + constexpr int vec2 = std::alignment_of>::value; // NOLINT + if (address % vec8 == 0) { + /* + * Currently, decide to deal with no more than 4 data once while adopting + * vectorization load/store, if performance test shows that dealing with + * 8 data once in vectorization load/store does get optimized, return code + * below can be changed into " return std::min(8, valid_vec_size); " . 
+ */ + return std::min(4, valid_vec_size); + } else if (address % vec4 == 0) { + return std::min(4, valid_vec_size); + } else if (address % vec2 == 0) { + return std::min(2, valid_vec_size); + } else { + return 1; + } +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device_memory_aligment.cc b/paddle/fluid/platform/device_memory_aligment.cc index 383dbd23ca0a5..8261c866d073d 100644 --- a/paddle/fluid/platform/device_memory_aligment.cc +++ b/paddle/fluid/platform/device_memory_aligment.cc @@ -37,6 +37,9 @@ size_t Alignment(size_t size, const platform::Place &place, int align_size) { #endif } } + if (is_npu_place(place)) { + size += 32; // required by ascendcl + } size_t remaining = size % alignment; return remaining == 0 ? size : size + (alignment - remaining); } diff --git a/paddle/fluid/platform/dynload/cusolver.h b/paddle/fluid/platform/dynload/cusolver.h index 561f20af45ab5..42583b60680b9 100644 --- a/paddle/fluid/platform/dynload/cusolver.h +++ b/paddle/fluid/platform/dynload/cusolver.h @@ -55,7 +55,13 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP); #if CUDA_VERSION >= 9020 #define CUSOLVER_ROUTINE_EACH_R1(__macro) \ __macro(cusolverDnSpotrfBatched); \ - __macro(cusolverDnDpotrfBatched); + __macro(cusolverDnDpotrfBatched); \ + __macro(cusolverDnSgesvdj_bufferSize); \ + __macro(cusolverDnDestroyGesvdjInfo); \ + __macro(cusolverDnCreateGesvdjInfo); \ + __macro(cusolverDnDgesvdj_bufferSize); \ + __macro(cusolverDnSgesvdj); \ + __macro(cusolverDnDgesvdj); CUSOLVER_ROUTINE_EACH_R1(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP) #endif diff --git a/paddle/fluid/platform/fast_divmod.h b/paddle/fluid/platform/fast_divmod.h index 02f9d5441281c..f26c4fdd17ad7 100644 --- a/paddle/fluid/platform/fast_divmod.h +++ b/paddle/fluid/platform/fast_divmod.h @@ -15,22 +15,17 @@ limitations under the License. */ #pragma once #include -#include "paddle/fluid/platform/hostdevice.h" +#include "paddle/fluid/platform/aligned_vector.h" #define INT_BITS 32 namespace paddle { namespace platform { -template -struct alignas(sizeof(T) * Size) CudaAlignedVector { - T val[Size]; -}; - struct FastDivMod { // 1st value represents the result of input number divides by recorded divisor // 2nd value represents the result of input number modulo by recorded divisor - using DivModT = CudaAlignedVector; + using DivModT = AlignedVector; FastDivMod() {} HOSTDEVICE FastDivMod(uint32_t d) : divisor(d) { @@ -65,39 +60,5 @@ struct FastDivMod { uint32_t multiplier; }; -/* -* Only the address of input data is the multiplier of 1,2,4, vectorized load -* with corresponding multiplier-value is possible. Moreover, the maximum length -* of vectorized load is 128 bits once. Hence, valid length of vectorized load -* shall be determined under both former constraints. -*/ -template -int GetVectorizedSize(const T *pointer) { - constexpr int max_load_bits = 128; - int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T); - uint64_t address = reinterpret_cast(pointer); - constexpr int vec8 = - std::alignment_of>::value; // NOLINT - constexpr int vec4 = - std::alignment_of>::value; // NOLINT - constexpr int vec2 = - std::alignment_of>::value; // NOLINT - if (address % vec8 == 0) { - /* - * Currently, decide to deal with no more than 4 data once while adopting - * vectorization load/store, if performance test shows that dealing with - * 8 data once in vectorization load/store does get optimized, return code - * below can be changed into " return std::min(8, valid_vec_size); " . 
- */ - return std::min(4, valid_vec_size); - } else if (address % vec4 == 0) { - return std::min(4, valid_vec_size); - } else if (address % vec2 == 0) { - return std::min(2, valid_vec_size); - } else { - return 1; - } -} - } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index f18eab3246547..0274a2cea8ef4 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -100,6 +100,7 @@ DEFINE_string( npu_config_path, "", "The absolute path of configuration json file, like: /tmp/config.json. " "If proveided, it will be passed to aclInit()."); +DEFINE_int32(min_loss_scaling, 1, "set minmum loss scaling value!"); #endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index e6442ded6b5ae..370d9b3925226 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -1090,9 +1090,9 @@ class ReorderMKLDNNHandler : public MKLDNNHandler { return this->AcquireMemory(dims_, dtype_, fmt, ptr, "@user_src_mem_p"); } - std::shared_ptr AcquireSrcSubmemory( + std::shared_ptr AcquireSubmemory( const std::vector& dims, const std::vector& offset, - const std::shared_ptr& mem_p, int submemory_number) { + const std::shared_ptr& mem_p, int submemory_number = 0) { std::string local_key = key_; local_key.append("@submem") .append(std::to_string(submemory_number)) diff --git a/paddle/fluid/platform/xpu/xpu2_op_list.h b/paddle/fluid/platform/xpu/xpu2_op_list.h index a8b2962d4acaf..0989f2156877f 100644 --- a/paddle/fluid/platform/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/xpu/xpu2_op_list.h @@ -29,6 +29,8 @@ using XPUOpMap = std::unordered_map; XPUOpMap& get_kl2_ops() { // KL1支持的op,通过op_name, data_type, place来索引 static XPUOpMap s_xpu2_kernels{ + {"label_smooth", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, {"elementwise_sub", @@ -73,6 +75,10 @@ XPUOpMap& get_kl2_ops() { {"elementwise_min_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, + {"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"batch_norm_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, // AddMore }; diff --git a/paddle/fluid/pybind/global_value_getter_setter.cc b/paddle/fluid/pybind/global_value_getter_setter.cc index 4824a34e843bb..dd45443a04113 100644 --- a/paddle/fluid/pybind/global_value_getter_setter.cc +++ b/paddle/fluid/pybind/global_value_getter_setter.cc @@ -98,6 +98,8 @@ DECLARE_string(selected_xpus); #ifdef PADDLE_WITH_ASCEND_CL // device management DECLARE_string(selected_npus); +// set minmum loss scaling value +DECLARE_int32(min_loss_scaling); #endif #ifdef PADDLE_WITH_DISTRIBUTE @@ -385,6 +387,7 @@ static void RegisterGlobalVarGetterSetter() { #ifdef PADDLE_WITH_ASCEND_CL REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_selected_npus); + REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_min_loss_scaling); #endif #ifdef PADDLE_WITH_DITRIBUTE diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 9e8fcd3ae82a3..d9c9ff4dec057 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -1702,6 +1702,17 @@ set -x } function parallel_test_base_npu() { + # skipping if no NPU related files changed + if [ 
${SKIP_NPU_TEST:-ON} == "ON" ] ; then + fetch_upstream_develop_if_not_exist + git diff --name-only remotes/upstream/$BRANCH + npu_cc_changes=$(git diff --name-only remotes/upstream/$BRANCH | grep "op_npu.cc" || true) + npu_py_changes=$(git diff --name-only remotes/upstream/$BRANCH | grep "op_npu.py" || true) + if [ -z "${npu_cc_changes}" ] && [ -z "${npu_py_changes}" ] ; then + echo "NO NPU operators files changed, skip NPU unit tests!" + exit 0 + fi + fi mkdir -p ${PADDLE_ROOT}/build cd ${PADDLE_ROOT}/build/python/paddle/fluid/tests/unittests/npu if [ ${WITH_TESTING:-ON} == "ON" ] ; then diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index ce338275b2935..24a7a666fb4f8 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -100,6 +100,7 @@ from .tensor.linalg import histogram # noqa: F401 from .tensor.linalg import mv # noqa: F401 from .tensor.linalg import matrix_power # noqa: F401 +from .tensor.linalg import svd # noqa: F401 from .tensor.logic import equal # noqa: F401 from .tensor.logic import greater_equal # noqa: F401 from .tensor.logic import greater_than # noqa: F401 @@ -493,6 +494,7 @@ 'sqrt', 'cholesky', 'matrix_power', + 'svd', 'randperm', 'linspace', 'reshape', diff --git a/python/paddle/distributed/auto_parallel/attribute.py b/python/paddle/distributed/auto_parallel/attribute.py index 0ca1b7e9444d0..879e94b83733c 100644 --- a/python/paddle/distributed/auto_parallel/attribute.py +++ b/python/paddle/distributed/auto_parallel/attribute.py @@ -14,6 +14,7 @@ import copy from collections import defaultdict +from paddle.fluid import core class TensorDistributedAttribute: @@ -77,6 +78,8 @@ def mark_as_parameter(self): self._is_parameter = True def is_valid(self): + if self.get_owner_tensor().type == core.VarDesc.VarType.READER: + return True tensor_shape = self.get_owner_tensor().desc.shape() if len(tensor_shape) != len(self.get_dims_mapping()): return False @@ -222,6 +225,8 @@ def mark_as_parameter(self, name): self._is_parameters[name] = True def is_valid(self): + if "read" in self.get_owner_op().type: + return True for name in self.get_owner_op().desc.input_arg_names(): dims_mapping = self.get_input_dims_mapping(name) shape = self.get_input_shape(name) diff --git a/python/paddle/distributed/auto_parallel/context.py b/python/paddle/distributed/auto_parallel/context.py index ff2adc7eacf91..bddf93682557a 100644 --- a/python/paddle/distributed/auto_parallel/context.py +++ b/python/paddle/distributed/auto_parallel/context.py @@ -15,9 +15,11 @@ import copy from collections import defaultdict from paddle.fluid import framework +from paddle.fluid import core from .attribute import TensorDistributedAttribute from .attribute import OperatorDistributedAttribute from .utils import append_distributed_attr_suffix +from .interface import _g_process_mesh_map # There always exists a default context for user. And user can set it to another one. 
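Alongside the svd_op build entries and the cusolver gesvdj symbols earlier in this patch, svd is re-exported at the paddle top level. A minimal usage sketch (the reduced-SVD output shapes assume the conventional full_matrices=False default, which is not spelled out in this diff):

    import paddle

    x = paddle.to_tensor([[1.0, 2.0], [1.0, 3.0], [4.0, 6.0]], dtype='float64')
    u, s, vh = paddle.svd(x)  # exposed by the __init__.py change in this patch
    # reduced SVD: u is 3x2, s holds 2 singular values, vh is 2x2
    print(u.shape, s.shape, vh.shape)
    # reconstruct x from the factors as a sanity check
    print(paddle.matmul(u * s, vh))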
DEFAULT_DISTRIBUTED_CONTEXT = None @@ -49,6 +51,20 @@ def __init__(self): self._op_distributed_attr_map_for_program = {} self._tensor_distributed_attr_map_for_graph = {} self._op_distributed_attr_map_for_graph = {} + # The following is a hard code and will be removed in the future + self._data_parallel_axis = None + self._model_parallel_axis = None + self._process_mesh = _g_process_mesh_map.get(0, None) + if self._process_mesh is not None: + if self._process_mesh.ndim == 1: + self._data_parallel_axis = 0 + self._model_parallel_axis = 0 + else: + self._data_parallel_axis = 0 + self._model_parallel_axis = 1 + else: + self._data_parallel_axis = -1 + self._model_parallel_axis = -1 def is_initialized_for_program(self): return self._is_initialized_for_program @@ -99,6 +115,19 @@ def set_op_distributed_attr_for_graph(self, op_node, op_dist_attr): op_node_id = op_node.id() self._op_distributed_attr_map_for_graph[op_node_id] = op_dist_attr + def set_process_mesh(self, process_mesh): + self._process_mesh = process_mesh + if self._process_mesh is not None: + if self._process_mesh.ndim == 1: + self._data_parallel_axis = 0 + self._model_parallel_axis = 0 + else: + self._data_parallel_axis = 0 + self._model_parallel_axis = 1 + else: + self._data_parallel_axis = -1 + self._model_parallel_axis = -1 + def initialize_distributed_attr_for_program(self, program): if self._is_initialized_for_program: return @@ -377,3 +406,11 @@ def amend_distributed_attr_for_program(self): if dims_mapping[i] != -1 and process_mesh_shape[ dims_mapping[i]] > tensor_shape[i]: dims_mapping[i] = -1 + + def _get_data_parallel_info(self): + # This function is a hard code, and will be obsoleted in the future + return self._data_parallel_axis, self._process_mesh + + def _get_model_parallel_info(self): + # This function is a hard code, and will be obsoleted in the future + return self._model_parallel_axis, self._process_mesh diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py index 4c408345f1739..348edaef68198 100644 --- a/python/paddle/distributed/auto_parallel/interface.py +++ b/python/paddle/distributed/auto_parallel/interface.py @@ -184,6 +184,13 @@ def parent(self): "parent with id %d does not exist." % self._parent_id) return _g_process_mesh_map[self._parent_id] + @property + def ndim(self): + r""" + Get the number of dimension of ProcessMesh. + """ + return len(self._topology) + def set_placement(self, order): """ Set the map from logical processes to physical ones using the @@ -229,6 +236,13 @@ def set_placement(self, order): for idx, l_id in enumerate(logical_order): _user_defined_physical_map[l_id] = order[idx] + def _reset_global_process_mesh_map(self): + """ + Remove all process mesh in _g_process_mesh_map, make it empty. 
+ """ + + _g_process_mesh_map = dict() + def __eq__(self, other): assert other and isinstance(other, ProcessMesh) if self.topology != other.topology or self.process_group != other.process_group: diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index c5e253c0e0b17..ef2f50834490f 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -33,6 +33,8 @@ def get_impls(self): class DistributedOperatorImpl: def __init__(self): self._name = None + self._forward_implemented = False + self._backward_implemented = False def forward(self, dist_ctx, *args, **kwargs): raise NotImplementedError("Please Implement this method in Subclass.") diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index 81d3925bb5dcc..5d1cfcbf69e4d 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -22,6 +22,12 @@ from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping +from paddle.fluid import core, unique_name +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.framework import Program, Parameter, Variable, program_guard +from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from ..process import new_process_group +from ..utils import _get_comm_group class DistributedEmbedding(DistributedOperator): @@ -39,6 +45,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): def __init__(self, name): super(DistributedEmbeddingImpl, self).__init__() self._name = name + self._forward_implemented = True + self._backward_implemented = False def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. 
""" @@ -92,6 +100,110 @@ def update_dims_mapping(self, op_dist_attr): return changed + def forward(self, serial_op): + def static_handle(dst_block, + src_op, + op_dist_attr, + input_name_mapping, + output_name_mapping, + rank_id=0): + assert len( + input_name_mapping + ) == 2, "row_parallel_embedding take 2 inputs variable but got {}".format( + input_name_mapping) + assert len( + output_name_mapping + ) == 1, "row_parallel_embedding take 2 inputs variable but got {}".format( + output_name_mapping) + assert len( + input_name_mapping['Ids'] + ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format( + input_name_mapping['Ids']) + assert len( + input_name_mapping['W'] + ) == 1, "row_parallel_embedding input W take 1 variable but got {}".format( + input_name_mapping['W']) + assert len( + output_name_mapping['Out'] + ) == 1, "row_parallel_embedding input Out take 1 variable but got {}".format( + input_name_mapping['Out']) + + Ids_var = dst_block.var(input_name_mapping['Ids'][0]) + Weight_var = dst_block.var(input_name_mapping['W'][0]) + Out_var = dst_block.var(output_name_mapping['Out'][0]) + + # got dist attribute info + embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping( + Weight_var.name)[0] + process_mesh_shape = op_dist_attr.get_process_mesh().topology + process_mesh_group = op_dist_attr.get_process_mesh().process_group + + # caculate embedding offset + # TODO generalize here, using cartisian product to allow any dimensional mesh shape + mesh_shape = len(process_mesh_shape) + assert mesh_shape <= 2, "row_parallel_embedding only support 1 or 2 dimensional process mesh, but got {}".format( + process_mesh_shape) + num_partition = process_mesh_shape[embedding_row_dim_mapping] + # TODO generalize here, support any mesh group + if mesh_shape == 1: + relative_idx = process_mesh_group.index(rank_id) + else: + relative_idx = rank_id % num_partition + + per_part_size = Weight_var.shape[0] + relative_idx = relative_idx * per_part_size + + # TODO caculate ring id + model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( + )._get_model_parallel_info() + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, + model_parallel_axis, rank_id) + group = new_process_group(group_ranks) + + # append op + check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'], + 'c_embedding') + + intermediate_var_0 = dst_block.create_var( + name=unique_name.generate_with_ignorable_key(".".join( + ["c_embedding", 'tmp'])), + dtype=Weight_var.dtype, + shape=Out_var.shape, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=Out_var.stop_gradient) + + check_variable_and_dtype( + Out_var, 'tensor', + ['float16', 'float32', 'float64', 'int32', 'int64'], + 'c_allreduce_sum') + + dst_block.append_op( + type='c_embedding', + inputs={'Ids': [Ids_var], + 'W': [Weight_var]}, + outputs={'Out': [intermediate_var_0]}, + attrs={"start_index": relative_idx}) + + # use_model_parallel + dst_block.append_op( + type='c_allreduce_sum', + inputs={'X': [intermediate_var_0]}, + outputs={'Out': [Out_var]}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'use_model_parallel': True, + }) + + if in_dygraph_mode(): + raise NotImplementedError( + "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( + "matmul", 0)) + else: + return static_handle + register_distributed_operator_impl("lookup_table_v2", DistributedEmbeddingImpl("row_parallel")) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py 
b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index fbeb0edd41897..9059feeaf8525 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -22,6 +22,12 @@ from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping +from paddle.fluid import core, unique_name +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.framework import Program, Parameter, Variable, program_guard +from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from ..process import new_process_group +from ..utils import _get_comm_group def _update_dims_mapping_for_matmul(op_dist_attr): @@ -37,7 +43,6 @@ def _update_dims_mapping_for_matmul(op_dist_attr): y_dims_mapping_len = len(y_dims_mapping) out_dims_mapping_len = len(out_dims_mapping) - # print("before", x_dims_mapping, y_dims_mapping, out_dims_mapping) # Add dim mapping to Make sure the length dims_mapping be at least 2 if x_dims_mapping_len == 1: x_dims_mapping.insert(0, -1) @@ -109,7 +114,6 @@ def _update_dims_mapping_for_matmul(op_dist_attr): if y_dims_mapping_len == 1: y_dims_mapping.pop(1) - # print("after", x_dims_mapping, y_dims_mapping, out_dims_mapping) assert len(x_dims_mapping) == x_dims_mapping_len assert len(y_dims_mapping) == y_dims_mapping_len assert len(out_dims_mapping) == out_dims_mapping_len @@ -131,6 +135,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): def __init__(self, name): super(DistributedMatmulImpl0, self).__init__() self._name = name + self._forward_implemented = True + self._backward_implemented = False def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. 
""" @@ -170,12 +176,101 @@ def update_dims_mapping(self, op_dist_attr): changed = True return changed + def forward(self, serial_op): + def static_handle(dst_block, + src_op, + op_dist_attr, + input_name_mapping, + output_name_mapping, + rank_id=0): + assert len( + input_name_mapping + ) == 2, "col_parallel_linear take 2 inputs variable but got {}".format( + input_name_mapping) + assert len( + output_name_mapping + ) == 1, "col_parallel_linear take 2 inputs variable but got {}".format( + output_name_mapping) + assert len( + input_name_mapping['X'] + ) == 1, "col_parallel_linear input X take 1 variable but got {}".format( + input_name_mapping['X']) + assert len( + input_name_mapping['Y'] + ) == 1, "col_parallel_linear input Y take 1 variable but got {}".format( + input_name_mapping['Y']) + assert len( + output_name_mapping['Out'] + ) == 1, "col_parallel_linear input Out take 1 variable but got {}".format( + input_name_mapping['Out']) + X_var = dst_block.var(input_name_mapping['X'][0]) + Weight_var = dst_block.var(input_name_mapping['Y'][0]) + Out_var = dst_block.var(output_name_mapping['Out'][0]) + + # TODO infer logic comm presentation + model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( + )._get_model_parallel_info() + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, + model_parallel_axis, rank_id) + group = new_process_group(group_ranks) + + intermediate_var_0 = dst_block.create_var( + name=unique_name.generate_with_ignorable_key(".".join( + ["c_identity", 'tmp'])), + dtype=X_var.dtype, + shape=X_var.shape, + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=X_var.stop_gradient) + + check_variable_and_dtype( + X_var, 'tensor', + ['float16', 'float32', 'float64', 'int32', 'int64'], + '_c_identity') + + dst_block.append_op( + type='c_identity', + inputs={'X': [X_var]}, + outputs={'Out': intermediate_var_0}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'use_model_parallel': True, + }) + + check_variable_and_dtype(intermediate_var_0, 'x', + ['float16', 'float32', 'float64'], + 'linear') + check_dtype(intermediate_var_0.dtype, 'dtype', + ['float16', 'float32', 'float64'], 'linear') + attrs = { + 'transpose_X': False, + 'transpose_Y': False, + 'alpha': 1, + } + inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} + dst_block.append_op( + type='matmul', + inputs=inputs, + outputs={'Out': Out_var}, + attrs=attrs) + + if in_dygraph_mode(): + raise NotImplementedError( + "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( + "matmul", 0)) + else: + return static_handle + # RowParallel class DistributedMatmulImpl1(DistributedOperatorImpl): def __init__(self, name): super(DistributedMatmulImpl1, self).__init__() self._name = name + self._forward_implemented = True + self._backward_implemented = False def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. 
""" @@ -217,6 +312,86 @@ def update_dims_mapping(self, op_dist_attr): changed = True return changed + def forward(self, serial_op): + def static_handle(dst_block, + src_op, + op_dist_attr, + input_name_mapping, + output_name_mapping, + rank_id=0): + assert len( + input_name_mapping + ) == 2, "col_parallel_linear take 2 inputs variable but got {}".format( + input_name_mapping) + assert len( + output_name_mapping + ) == 1, "col_parallel_linear take 2 inputs variable but got {}".format( + output_name_mapping) + assert len( + input_name_mapping['X'] + ) == 1, "col_parallel_linear input X take 1 variable but got {}".format( + input_name_mapping['X']) + assert len( + input_name_mapping['Y'] + ) == 1, "col_parallel_linear input Y take 1 variable but got {}".format( + input_name_mapping['Y']) + assert len( + output_name_mapping['Out'] + ) == 1, "col_parallel_linear input Out take 1 variable but got {}".format( + input_name_mapping['Out']) + X_var = dst_block.var(input_name_mapping['X'][0]) + Weight_var = dst_block.var(input_name_mapping['Y'][0]) + Out_var = dst_block.var(output_name_mapping['Out'][0]) + + # TODO infer logic comm presentation + model_parallel_axis, process_mesh = op_dist_attr.get_owner_context( + )._get_model_parallel_info() + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, + model_parallel_axis, rank_id) + group = new_process_group(group_ranks) + + check_variable_and_dtype( + X_var, 'x', ['float16', 'float32', 'float64'], 'linear') + check_dtype(X_var.dtype, 'dtype', + ['float16', 'float32', 'float64'], 'linear') + attrs = { + 'transpose_X': False, + 'transpose_Y': False, + 'alpha': 1, + } + inputs = {'X': X_var, 'Y': Weight_var} + intermediate_var_0 = dst_block.create_var( + shape=Out_var.shape, + dtype=Out_var.dtype, + type=Out_var.type, + lod_level=Out_var.lod_level, + persistable=False, + is_data=False, + need_check_feed=Out_var.desc.need_check_feed()) + dst_block.append_op( + type='matmul', + inputs=inputs, + outputs={'Out': intermediate_var_0}, + attrs=attrs) + + dst_block.append_op( + type='c_allreduce_sum', + inputs={'X': intermediate_var_0}, + outputs={'Out': Out_var}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'use_model_parallel': True + }) + + if in_dygraph_mode(): + raise NotImplementedError( + "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( + "matmul", 0)) + else: + return static_handle + # ReplicateParallel class DistributedMatmulImpl2(DistributedOperatorImpl): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py index 40da0e2f6093f..e7fbe9cfebad8 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py @@ -22,6 +22,10 @@ from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping +from paddle.fluid import core, unique_name +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.framework import Program, Parameter, Variable, program_guard +from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype class DistributedReshape2(DistributedOperator): @@ -37,6 +41,8 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): def __init__(self, name): super(DistributedReshapeImpl0, self).__init__() self._name = name + self._forward_implemented = True + self._backward_implemented = False 
def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. """ @@ -91,11 +97,90 @@ def update_dims_mapping(self, op_dist_attr): return changed + def forward(self, serial_op): + def static_handle(dst_block, + src_op, + op_dist_attr, + input_name_mapping, + output_name_mapping, + rank_id=0): + assert len( + input_name_mapping + ) == 3, "Dist op of Reshape take 3 inputs variable but got {}".format( + input_name_mapping) + assert len( + output_name_mapping + ) == 2, "Dist op of Reshape take 2 inputs variable but got {}".format( + output_name_mapping) + assert len( + input_name_mapping['X'] + ) == 1, "Dist op of Reshape input X take 1 variable but got {}".format( + input_name_mapping['X']) + assert len( + input_name_mapping['ShapeTensor'] + ) <= 1, "Dist op of Reshape input ShapeTensor take 0 or 1 variable but got {}".format( + input_name_mapping['ShapeTensor']) + assert len( + input_name_mapping['Shape'] + ) <= 1, "Dist op of Reshape input Shape take 0 or 1 variable but got {}".format( + input_name_mapping['Shape']) + assert len( + output_name_mapping['Out'] + ) == 1, "Dist op of Reshape input Out take 1 variable but got {}".format( + input_name_mapping['Out']) + assert len( + output_name_mapping['XShape'] + ) == 1, "Dist op of Reshape input XShape take 1 variable but got {}".format( + input_name_mapping['XShape']) + + X_var = dst_block.var(input_name_mapping['X'][0]) + Out_var = dst_block.var(output_name_mapping['Out'][0]) + XShape_var = dst_block.var(output_name_mapping['XShape'][0]) + shape_list = src_op.desc.attr("shape") + ShapeTensor_var_list = [] + for name in input_name_mapping['ShapeTensor']: + ShapeTensor_var_list.append(name) + Shape_var_list = [] + for name in input_name_mapping['Shape']: + Shape_var_list.append(name) + + # got dist attribute info + dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) + process_mesh_shape = op_dist_attr.get_process_mesh().topology + + # modify target shape + for idx, axis in enumerate(dim_mapping): + if axis >= 0: + if len(shape_list) > idx: + shape_list[idx] = shape_list[idx] // process_mesh_shape[ + axis] + + # create op + new_op_desc = dst_block.desc.append_op() + new_op_desc.copy_from(src_op.desc) + new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list) + new_op_desc.set_input('Shape', Shape_var_list) + new_op_desc.set_input('X', [X_var.name]) + new_op_desc.set_output('XShape', [XShape_var.name]) + new_op_desc.set_output('Out', [Out_var.name]) + new_op_desc._set_attr('shape', shape_list) + + dst_block._sync_with_cpp() + + if in_dygraph_mode(): + raise NotImplementedError( + "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( + "matmul", 0)) + else: + return static_handle + class DistributedReshapeImpl1(DistributedOperatorImpl): def __init__(self, name): super(DistributedReshapeImpl1, self).__init__() self._name = name + self._forward_implemented = True + self._backward_implemented = False def is_process_mesh_compatible(self, op_dist_attr): """ No restriction for now. 
""" @@ -150,6 +235,83 @@ def update_dims_mapping(self, op_dist_attr): return changed + def forward(self, serial_op): + def static_handle(dst_block, + src_op, + op_dist_attr, + input_name_mapping, + output_name_mapping, + rank_id=0): + assert len( + input_name_mapping + ) == 3, "Dist op of Reshape take 3 inputs variable but got {}".format( + input_name_mapping) + assert len( + output_name_mapping + ) == 2, "Dist op of Reshape take 2 inputs variable but got {}".format( + output_name_mapping) + assert len( + input_name_mapping['X'] + ) == 1, "Dist op of Reshape input X take 1 variable but got {}".format( + input_name_mapping['X']) + assert len( + input_name_mapping['ShapeTensor'] + ) <= 1, "Dist op of Reshape input ShapeTensor take 0 or 1 variable but got {}".format( + input_name_mapping['ShapeTensor']) + assert len( + input_name_mapping['Shape'] + ) <= 1, "Dist op of Reshape input Shape take 0 or 1 variable but got {}".format( + input_name_mapping['Shape']) + assert len( + output_name_mapping['Out'] + ) == 1, "Dist op of Reshape input Out take 1 variable but got {}".format( + input_name_mapping['Out']) + assert len( + output_name_mapping['XShape'] + ) == 1, "Dist op of Reshape input XShape take 1 variable but got {}".format( + input_name_mapping['XShape']) + + X_var = dst_block.var(input_name_mapping['X'][0]) + Out_var = dst_block.var(output_name_mapping['Out'][0]) + XShape_var = dst_block.var(output_name_mapping['XShape'][0]) + shape_list = src_op.desc.attr("shape") + ShapeTensor_var_list = [] + for name in input_name_mapping['ShapeTensor']: + ShapeTensor_var_list.append(name) + Shape_var_list = [] + for name in input_name_mapping['Shape']: + Shape_var_list.append(name) + + # got dist attribute info + dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) + process_mesh_shape = op_dist_attr.get_process_mesh().topology + + # modify target shape + for idx, axis in enumerate(dim_mapping): + if axis >= 0: + if len(shape_list) > idx: + shape_list[idx] = shape_list[idx] // process_mesh_shape[ + axis] + + # create op + new_op_desc = dst_block.desc.append_op() + new_op_desc.copy_from(src_op.desc) + new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list) + new_op_desc.set_input('Shape', Shape_var_list) + new_op_desc.set_input('X', [X_var.name]) + new_op_desc.set_output('XShape', [XShape_var.name]) + new_op_desc.set_output('Out', [Out_var.name]) + new_op_desc._set_attr('shape', shape_list) + + dst_block._sync_with_cpp() + + if in_dygraph_mode(): + raise NotImplementedError( + "Dist op for [{}] with idx [{}] is NOT implemented yet.".format( + "matmul", 0)) + else: + return static_handle + register_distributed_operator_impl("reshape2", DistributedReshapeImpl0("add_one_dim_back")) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py index fad11aadf8020..dc78bdee1fb14 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py @@ -47,7 +47,6 @@ def is_input_compatible(self, op_dist_attr): x_name = op_desc.input('X')[0] axis = op_desc.attr('axis') x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - # print("softmax axis", axis) if axis != -1 and axis != len(x_dims_mapping) - 1: return False diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py new file mode 100755 index 0000000000000..03497f2967c80 --- /dev/null +++ 
b/python/paddle/distributed/auto_parallel/partitioner.py
@@ -0,0 +1,925 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+import copy
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import core
+from paddle.fluid import framework as framework
+from paddle.fluid import core, unique_name
+from paddle.fluid.framework import Program, Parameter, Variable, program_guard
+from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
+from paddle.fluid.backward import append_backward, _some_in_set_, _append_grad_suffix_
+from paddle.distributed.auto_parallel.operators.common import get_distributed_operator
+from paddle.distributed.auto_parallel.operators.common import find_best_compatible_distributed_operator_impl
+from paddle.fluid.clip import GradientClipBase, GradientClipByNorm, error_clip_callback, append_gradient_clip_ops, ClipGradByGlobalNorm
+from paddle.distributed.fleet.base.distributed_strategy import DistributedStrategy
+from paddle.distributed.auto_parallel.context import DistributedContext
+from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op
+from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
+from .process import new_process_group
+from .interface import _g_process_mesh_map
+from .utils import _get_comm_group
+
+__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
+
+
+class Partitioner(object):
+    """
+    warning:: Partitioner is experimental and subject to change.
+
+    Partitioner converts one program into another program.
+    Given a serial program that has been auto-completed with shard annotations, the Partitioner
+    converts the serial program into a "distributed" program. The Partitioner modifies the serial
+    program in the following two ways, which are also the major differences between a serial and a distributed program:
+        1. partition op: replace a serial op with its corresponding dist op inferred from the shard annotation
+        2. partition var: if a var is sharded, modify its shape according to its shard annotation
+
+    Partitioner is supposed to be called by the auto parallel framework, not directly by the user.
+
+    Example:
+        ....
+        import paddle.distributed.auto_parallel as auto
+        from paddle.fluid.distributed_attribute import get_default_distributed_context
+        from paddle.distributed import fleet
+        from paddle.distributed.auto_parallel.partitioner import Partitioner
+
+        # create serial program with forward only
+        with static.program_guard(serial_main_program, serial_start_program):
+            model = create_model(config)
+            tokens = static.data(name="tokens", shape=[batch_size, sequence_len], dtype='int64')
+            labels = static.data(name="labels", shape=[batch_size, sequence_len], dtype='int64')
+            loss_mask = static.data(name="loss_mask", shape=[batch_size, sequence_len], dtype='int64')
+            preds = model(tokens)
+            loss = criterion(preds, labels, loss_mask)
+
+        # auto completion
+        auto.ProcessMesh(shape=[2, 4], process_group=[0, 1, 2, 3, 4, 5, 6, 7])
+        annotated_main_program = auto.complete_annotation(serial_main_program)
+        auto_parallel_context = get_default_distributed_context()
+
+        # distributed strategy & rank info
+        rank_id = paddle.distributed.get_rank()
+        dist_strategy = fleet.DistributedStrategy()
+
+        # create partitioner
+        partitioner = Partitioner(dist_strategy, auto_parallel_context, rank_id)
+
+        # create dist program with forward only
+        # for distributed inference, use partitioned_main_prog from here
+        partitioned_main_prog, partitioned_startup_prog = partitioner.transpile_forward(complete_train_program, start_program)
+
+        # create dist program with forward/backward/update
+        # for distributed training, use partitioned_main_prog from here
+        dist_params_grads = partitioner.apply_backward(loss, complete_train_program, start_program, partitioned_main_prog, partitioned_startup_prog)
+        optimizer = paddle.fluid.optimizer.AdamOptimizer(
+            learning_rate=0.00001,
+            beta1=0.9,
+            beta2=0.999,
+            epsilon=1e-08,
+            grad_clip=None)
+        opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, partitioned_main_prog, partitioned_startup_prog)
+    """
+
+    def __init__(self, dist_strategy, auto_parallel_context, rank_id=0):
+        """
+        Args:
+            dist_strategy (paddle.fleet.distributed_strategy): used to determine the user-defined distributed strategy.
+            auto_parallel_context (paddle.fluid.DistributedContext): used to access the distributed_attr of var & op. Every Partitioner object may maintain its own DistributedContext member, and partitions the program based on that shard scenario.
+            rank_id (int): global rank id to which the partitioned distributed program belongs.
+ """ + + if not isinstance(dist_strategy, DistributedStrategy): + raise TypeError( + "dist_strategy be paddle.fleet.base.DistributedStrategy, got %s here" + % type(dist_strategy)) + + if not isinstance(auto_parallel_context, DistributedContext): + raise TypeError( + "auto_parallel_context be paddle.fluid.DistributedContext, got %s here" + % type(auto_parallel_context)) + + self._dist_strategy = dist_strategy + self._auto_parallel_context = auto_parallel_context + self._rank_id = rank_id + self._serial2dist_varname_mapping = {} + self._dist_varname_suffix = "" + + # TODO if there is some dist op that is not compatible + # with auto_backward in forward, the following flag + # should be set to False + self._compatible_with_auto_backward = True + + # data parallelism + self._enable_data_parallel = False + self._dp_degree = 0 + self._dp_group = None + + # tensor parallelism + self._enable_tensor_parallel = False + self._tp_degree = 0 + self._tp_group = None + + def transpile_forward(self, serial_main_program, serial_startup_program): + """ + take serial forward programs with shard annotation, create a new distributed forward programs based on the serial ones. + instead of modify the input programs inplace, this function will preserve the inputs and create new program for output. + + beside replace the serial op with its dist op, if user has defined other strategy in fleet.distributed_strategy, and if + those strategy need to transpile (modify) the forward network program, those forward program modification should also be done within this + function in auto parallel scenario, in order to facilitate distributed inference/evaluation which need to DECOUPLE strategy specific forward transpilation with fleet.distributed_optimizer.minimize(). + + by now the fleet.distributed_strategy that need transpile forward program are following: + 1. (optimizer) sharding + + Args: + main_program (paddle.fluid.framework.program): serial main program with forward network only + startup_program (paddle.fluid.framework.program): serial startup program with forward network only + + return: + main_program (paddle.fluid.framework.program): distributed main program with forward network only + startup_program (paddle.fluid.framework.program): distributed startup program with forward network only + """ + + dist_main_program, dist_startup_program = self.transpile_forward_impl( + serial_main_program, serial_startup_program) + return dist_main_program, dist_startup_program + + def apply_backward(self, + serial_loss, + serial_main_program, + serial_startup_program, + dist_main_program, + dist_startup_program, + parameter_list=None, + no_grad_set=None, + callbacks=None): + """ + A complete training neural network is made up of forward and backward propagation. + This function is to generate the dist backward program for the distributed forward program. + + By now, the current automatical backward mechanism in paddle framework might NOT handle the backward generation for + some dist ops correctly, some so we now have two ways to genenate the backward program: + 1. dist_forward_program --> auto_backward --> dist_backward_program (if auto_backward could handle all dist op) + 2. serial_forward_program --> auto_backward --> serial_backward_program --> dist_op_backward_transpile --> dist_backward_program (if auto_backward could not handle all dist op) + + the backprogram is append the input dist program inplaced. 
+ + Args: + serial_loss (Variable) the loss in serial program that to be minimized + serial_main_program (paddle.fluid.framework.program): serial main program with forward network only + serial_startup_program (paddle.fluid.framework.program): serial startup program with forward network only + dist_main_program (paddle.fluid.framework.program): dist main program with forward network only + dist_startup_program (paddle.fluid.framework.program): dist startup program with forward network only + parameter_list (Iterable, optional): Iterable of ``Variable`` or ``Variable.name`` to update + to minimize ``loss``. The default value is None, at this time all parameters + will be updated. + no_grad_set (set, optional): Set of ``Variable`` or ``Variable.name`` that don't need + to be updated. The default value is None. + callbacks (list, optional): list of callable objects to run when appending backward + operator for one parameter. The default value is None. + + return: + params_grads (list) list of tuple that contain param and its grad variable + """ + params_grads = self.apply_backward_impl( + serial_loss, serial_main_program, serial_startup_program, + dist_main_program, dist_startup_program) + return params_grads + + def apply_optimize(self, user_define_optimizer, params_grads, + dist_main_program, dist_startup_program): + """ + append update related ops to the program: clip, weight decay, ops + filter optimize op if sharding is enable + naive gradient synchronization before update + + Args: + user_define_optimizer (paddle.fluid.optimizer): + params_grads (list) list of tuple that contain param and its grad variable + dist_main_program (paddle.fluid.framework.program): dist main program with forward & backward network + dist_startup_program (paddle.fluid.framework.program): dist startup program with forward & backward network + """ + + optimize_ops = self.apply_optimize_impl(user_define_optimizer, + params_grads, dist_main_program, + dist_startup_program) + + return optimize_ops + + def transpile_forward_impl(self, main_program, startup_program): + + if not isinstance(main_program, (Program)): + raise TypeError( + "dist_strategy be paddle.fluid.framework.program, got %s here" % + type(main_program)) + + if not isinstance(startup_program, (Program)): + raise TypeError( + "auto_parallel_context be paddle.fluid.framework.program, got %s here" + % type(startup_program)) + + # check if shard annotated serial program valid + if not self._is_valid_annotated_program(main_program): + raise RuntimeError( + "Not all vars or ops are annotated in main program !") + + # determine parallelism mode + self._determine_parallel_mode(main_program) + + # dist op & partition vars + new_main_prog, new_startup_program = self._dist_var_op_forward_transpile( + main_program, startup_program) + + # Sharding + if self._dist_strategy.sharding: + new_main_prog, new_startup_program = self._sharding_forward_transpile( + new_main_prog, new_startup_program) + + return new_main_prog, new_startup_program + + def apply_backward_impl(self, + serial_loss, + serial_main_program, + serial_startup_program, + dist_main_program, + dist_startup_program, + parameter_list=None, + no_grad_set=None, + callbacks=None): + """ + """ + + params_grads = self._dist_var_op_backward_transpile( + serial_loss, serial_main_program, serial_startup_program, + dist_main_program, dist_startup_program) + # Sharding + if self._dist_strategy.sharding: + self._sharding_backward_transpile(new_main_prog, + new_startup_program) + + # Data Parallel pass + if 
self._enable_data_parallel: + self._gradient_sync_transpile(dist_main_program, + dist_startup_program) + + return params_grads + + def apply_optimize_impl(self, user_define_optimizer, params_grads, + dist_main_program, dist_startup_program): + """ + append update related ops to the program: clip, weight decay, ops + filter optimize op if sharding is enable + naive gradient synchronization before update + + Args: + user_define_optimizer (paddle.fluid.optimizer): + params_grads (list) list of tuple that contain param and its grad variable + dist_main_program (paddle.fluid.framework.program): dist main program with forward & backward network + dist_startup_program (paddle.fluid.framework.program): dist startup program with forward & backward network + """ + + if self._dist_strategy.sharding: + params_grads = sharding_optimize_transpile( + params_grads, dist_main_program, dist_startup_program) + + optimize_ops = self._optimize_transpile(user_define_optimizer, + params_grads, dist_main_program, + dist_startup_program) + + return optimize_ops + + def _dist_var_op_forward_transpile(self, + serial_main_program, + serial_startup_program=None): + """ + 1. partition variables + 2. replace local op with corresponding dist op + """ + + partitioned_main_prog = fluid.Program() + partitioned_global_block = partitioned_main_prog.global_block() + serial_global_block = serial_main_program.global_block() + serial_ops = serial_main_program.global_block().ops + + # transpile main program + for op in serial_ops: + + # partititon input variables + for serial_input_varname in op.desc.input_arg_names(): + if serial_input_varname not in self._serial2dist_varname_mapping: + new_varname = serial_input_varname + self._dist_varname_suffix + if serial_global_block.has_var(serial_input_varname): + _partition_var(self._auto_parallel_context, + serial_global_block, + partitioned_global_block, + serial_input_varname, new_varname) + else: + assert serial_input_varname in __varname_not_in_block__ + + self._serial2dist_varname_mapping[ + serial_input_varname] = new_varname + + # partition output vars + for serial_output_varname in op.desc.output_arg_names(): + if serial_output_varname not in self._serial2dist_varname_mapping: + new_varname = serial_output_varname + self._dist_varname_suffix + _partition_var(self._auto_parallel_context, + serial_global_block, + partitioned_global_block, + serial_output_varname, new_varname) + self._serial2dist_varname_mapping[ + serial_output_varname] = new_varname + + # partition op + if _found_match_dist_op(self._auto_parallel_context, op): + # replace with corresponding dist op + _insert_dist_op(op, partitioned_global_block, + self._serial2dist_varname_mapping, + self._auto_parallel_context, self._rank_id) + else: + # replicate op + _insert_src_op(op, partitioned_global_block, + self._serial2dist_varname_mapping) + + # transpile startup program + if serial_startup_program == None: + partitioned_startup_prog = None + else: + partitioned_startup_prog = fluid.Program() + # create parameter + partitioned_startup_global_block = partitioned_startup_prog.global_block( + ) + param2shape = {} + for var in partitioned_main_prog.list_vars(): + if isinstance(var, Parameter): + _partition_parameter(self._auto_parallel_context, var, + partitioned_startup_global_block, + var.name, var.shape) + param2shape[var.name] = var.shape + + # copy initializer + for op in serial_startup_program.global_block().ops: + output_vars = op.desc.output_arg_names() + assert len( + output_vars + ) == 1, "initializer should 
output only ONE variable, but got [{}]".format( + str(op.desc)) + assert self._serial2dist_varname_mapping[output_vars[ + 0]] in param2shape, "try to initialize [{}] which is not a Parameter".format( + output_vars[0]) + new_op_desc = partitioned_startup_global_block.desc.append_op() + new_op_desc.copy_from(op.desc) + new_op_desc._rename_output( + output_vars[0], + self._serial2dist_varname_mapping[output_vars[0]]) + new_op_desc._set_attr("shape", param2shape[ + self._serial2dist_varname_mapping[output_vars[0]]]) + partitioned_startup_global_block._sync_with_cpp() + + # MP broadcast not split parameter + # NOTE Theoretically, the MP param init broadcast should be handled by + # each dist op itself. but if we insert the broadcast op at that moment, the broadcast + # will before the initializer, which lead to a undertermined case. + if self._enable_tensor_parallel: + param_to_sync = [] + for param in partitioned_startup_prog.all_parameters(): + if not self._is_var_distributed(param): + param_to_sync.append(param) + # FIXME the ring id should be set by autoparallel.mapping module + # it should be determined by dp groups butfixed it here for hacking + partitioned_startup_global_block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': self._tp_group.id, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Forward + }) + partitioned_startup_global_block.append_op( + type='c_sync_comm_stream', + inputs={'X': param_to_sync}, + outputs={'Out': param_to_sync}, + attrs={ + 'ring_id': self._tp_group.id, + OP_ROLE_KEY: OpRole.Forward + }) + partitioned_startup_global_block._sync_with_cpp() + + # DP init param broadcast + if self._enable_data_parallel: + # parameters initialization synchronization + param_to_sync = [] + + for param in partitioned_startup_global_block.all_parameters(): + param_to_sync.append(param) + + # FIXME the ring id should be set by autoparallel.mapping module + # it should be determined by dp groups butfixed it here for hacking + partitioned_startup_global_block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': self._dp_group.id, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Forward + }) + partitioned_startup_global_block.append_op( + type='c_sync_comm_stream', + inputs={'X': param_to_sync}, + outputs={'Out': param_to_sync}, + attrs={ + 'ring_id': self._dp_group.id, + OP_ROLE_KEY: OpRole.Forward + }) + partitioned_startup_global_block._sync_with_cpp() + + return partitioned_main_prog, partitioned_startup_prog + + def _dist_var_op_backward_transpile(self, + serial_loss, + serial_main_program, + serial_startup_program, + dist_main_program, + dist_startup_program, + parameter_list=None, + no_grad_set=None, + callbacks=None): + """ + so far, the auto_backward case only guarantee the correcotness of backward ops for curtain Dist ops: + 1. NV-Megatron-like parallel embedding + 2. NV-Megatron-like row parallel linear + 3. NV-Megatron-like col parallel linear + """ + + if self._compatible_with_auto_backward: + assert isinstance( + serial_loss, Variable), "The target loss should be an Variable." + dist_loss = self._serial_varname2dist_var(serial_loss.name, + dist_main_program) + + assert len(dist_loss.shape) == 1 and dist_loss.shape[0] == 1, \ + "The dist loss.shape should be (1L,), but the current dist loss.shape is {}. 
" \ + "Maybe that you should call fluid.layers.mean to process the current loss.".format( + dist_loss.shape) + + # update parameter list + if parameter_list: + parameter_list = [ + self._serial_varname2dist_var(param.name, dist_main_program) + for param in parameter_list + ] + + # update parameter no_grad_set + if no_grad_set: + no_grad_set = [ + self._serial_varname2dist_var(param.name, dist_main_program) + for param in no_grad_set + ] + + return _auto_backward( + dist_loss, + dist_startup_program, + parameter_list=parameter_list, + no_grad_set=no_grad_set, + callbacks=callbacks) + # replace dist grad ops + else: + raise RuntimeError("transpile NOT implemented !") + + def _optimize_transpile(self, user_define_optimizer, params_grads, + main_program, startup_program): + + with program_guard(main_program, startup_program): + optimize_ops = user_define_optimizer.apply_gradients(params_grads) + + return optimize_ops + + def _is_valid_annotated_program(self, program): + + # TODO (ZJ-LIANG) should check all block + ops = program.global_block().ops + vars_ = program.list_vars() + op_dist_attrs = [ + self._auto_parallel_context.get_op_distributed_attr_for_program(op) + for op in ops + ] + var_dist_attrs = [ + self._auto_parallel_context.get_tensor_distributed_attr_for_program( + var) for var in vars_ + ] + + all_ops_annotated = all(dist_attr is not None + for dist_attr in op_dist_attrs) + all_vars_annotated = all(dist_attr is not None + for dist_attr in var_dist_attrs) + + return all_ops_annotated and all_vars_annotated + + def _serial_varname2dist_var(self, serial_varname, dist_program): + assert serial_varname in self._serial2dist_varname_mapping, "The serial var [{}] is not found in var name mapping".format( + serial_varname) + dist_varname = self._serial2dist_varname_mapping[serial_varname] + + assert dist_program.global_block().has_var( + dist_varname + ), "The dist var [{}] is not found in dist program".format(dist_varname) + dist_var = dist_program.global_block().var(dist_varname) + + return dist_var + + def _determine_parallel_mode(self, program): + """ + determine the parallelism that is enabled + NOTE a hard rule and should be updated in future + """ + + for param in program.all_parameters(): + if self._is_var_distributed(param): + self._enable_tensor_parallel = True + break + + for var in program.list_vars(): + var_dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program( + var) + if not var_dist_attr.is_parameter(): + mapping = var_dist_attr.get_dims_mapping() + mesh = var_dist_attr.get_process_mesh().topology + if mapping[0] >= 0 and mesh[mapping[0]] > 1: + self._enable_data_parallel = True + break + + # tensor parallelism + if self._enable_tensor_parallel: + model_parallel_axis, process_mesh = self._auto_parallel_context._get_model_parallel_info( + ) + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, + model_parallel_axis, self._rank_id) + self._tp_degree = len(group_ranks) + self._tp_group = new_process_group(group_ranks) + + # data parallelism + data_parallel_axis, process_mesh = self._auto_parallel_context._get_data_parallel_info( + ) + if self._enable_data_parallel: + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, + data_parallel_axis, self._rank_id) + self._dp_degree = len(group_ranks) + self._dp_group = new_process_group(group_ranks) + + def _is_var_distributed(self, var): + + dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program( + var) + assert dist_attr is 
not None, "dist_attr of var [{}] is None".format( + var.name) + return _is_distributed(dist_attr) + + def _sharding_forward_transpile(self, main_prog, startup_program): + """ + this transpile conduct the modification in forward program need by sharding strategy + which majorly include: + 1. partition the parameter + 2. insert broadcast op + 3. insert sync op + + NOTE the transpile modification is inplace on the input program + """ + + raise NotImplementedError( + "Sharding is NOT support in AutoParallel yet!") + + def _sharding_backward_transpile(self, main_prog, startup_program): + """ + this transpile conduct the modification in backward program need by sharding strategy + which majorly include: + 1. partition the gradient + 2. insert broadcast op + 3. insert sync op + + NOTE the transpile modification is inplace on the input program + """ + + raise NotImplementedError( + "Sharding is NOT support in AutoParallel yet!") + + def _sharding_optimize_transpile(self, params_grads, dist_main_program, + dist_startup_program): + """ + shard params_grads + append the broadcast to sync parameters + """ + raise RuntimeError("sharding transpile is NOT implemented !") + + def _gradient_sync_transpile(self, main_program, startup_program): + """ + append the gradient allreduce ops for all parameters' grad in case of Data Parallel + """ + + # scale loss by dp degree + main_global_block = main_program.global_block() + for idx, op in reversed(list(enumerate(main_global_block.ops))): + if is_loss_grad_op(op): + loss_grad_var = main_global_block.vars[op.output_arg_names[0]] + main_global_block._insert_op_without_sync( + idx + 1, + type='scale', + inputs={'X': loss_grad_var}, + outputs={'Out': loss_grad_var}, + attrs={ + 'scale': 1.0 / self._dp_degree, + OP_ROLE_KEY: OpRole.Backward + }) + break + main_global_block._sync_with_cpp() + + # gradient synchronization + # NOTE naive gradient sync without overlapping + # so there is not need to sync between calc and comm + # collecting grad var + grad_to_sync = [] + for idx, op in reversed(list(enumerate(main_global_block.ops))): + if is_backward_op(op) and \ + OP_ROLE_VAR_KEY in op.attr_names: + op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY] + if len(op_role_var) != 0: + assert len(op_role_var) % 2 == 0 + for i in range(0, len(op_role_var), 2): + param, reduced_grad = op_role_var[i], op_role_var[i + 1] + assert (reduced_grad not in grad_to_sync) + grad_to_sync.append(reduced_grad) + if is_optimizer_op(op): + first_optimize_op_idx = idx + + # insert allreduce + for grad in grad_to_sync: + # FIXME the ring id should be set by autoparallel.mapping module + # it should be determined by dp groups butfixed it here for hacking + main_global_block.append_op( + type='c_allreduce_sum', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={ + 'ring_id': self._dp_group.id, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Backward + }) + main_global_block.append_op( + type='c_sync_comm_stream', + inputs={'X': grad_to_sync}, + outputs={'Out': grad_to_sync}, + attrs={'ring_id': self._dp_group.id, + OP_ROLE_KEY: OpRole.Backward}) + main_global_block._sync_with_cpp() + + +def _get_no_grad_set_name(no_grad_set): + no_grad_set_name = set() + if no_grad_set is not None: + if isinstance(no_grad_set, (set, list, tuple)): + for i, no_grad_var in enumerate(no_grad_set): + if isinstance(no_grad_var, framework.Variable): + no_grad_set_name.add(no_grad_var.name) + elif isinstance(no_grad_var, six.string_types): + no_grad_set_name.add(no_grad_var) + else: + raise TypeError( + 
"The type of no_grad_set's member must be paddle.fluid.Variable or str, but received %s." + % (type(no_grad_var))) + else: + raise TypeError( + "The type of no_grad_set should be set or list or tuple, but received {}". + format(type(no_grad_set))) + return no_grad_set_name + + +def _get_no_grad_set(loss, no_grad_set=None): + no_grad_set = _get_no_grad_set_name(no_grad_set) + parameters = loss.block.program.global_block().all_parameters() + param_no_trainable = set( + [param.name for param in parameters if param.trainable is False]) + # If the parameter is no trainable, it should not have a gradient. + no_grad_set.update(param_no_trainable) + + return no_grad_set + + +def _found_match_dist_op(auto_paralle_context, op): + dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(op) + dist_ops = get_distributed_operator(op.type) + + return dist_ops and dist_attr.get_impl_idx() >= 0 and dist_ops.get_impl( \ + dist_attr.get_impl_idx())._forward_implemented + + +def _auto_backward(loss, + startup_program=None, + parameter_list=None, + no_grad_set=None, + callbacks=None): + """ + modification is inplaced + """ + act_no_grad_set = _get_no_grad_set(loss, no_grad_set) + assert isinstance(loss, Variable), "The target loss should be an Variable." + + if callbacks is None: + callbacks = [error_clip_callback] + else: + assert (isinstance(callbacks, list)) + + assert len(loss.shape) == 1 and loss.shape[0] == 1, \ + "The loss.shape should be (1L,), but the current loss.shape is {}. " \ + "Maybe that you should call fluid.layers.mean to process the current loss.".format( + loss.shape) + + program = loss.block.program + with program_guard(program, startup_program): + params_grads = append_backward(loss, parameter_list, act_no_grad_set, + callbacks) + + return params_grads + + +def _is_distributed(dist_attr): + + mapping = dist_attr.get_dims_mapping() + mesh = dist_attr.get_process_mesh().topology + for idx in range(len(mapping)): + if mapping[idx] >= 0 and mesh[mapping[idx]] > 1: + return True + + return False + + +def _get_dist_shape(var, dist_attr): + + var_shape = var.shape + mapping = dist_attr.get_dims_mapping() + mesh = dist_attr.get_process_mesh().topology + assert len(var_shape) == len( + mapping + ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format( + var_shape, mapping) + new_shape = [] + for idx in range(len(var_shape)): + if var_shape[idx] == -1 or mapping[idx] == -1: + new_shape.append(var_shape[idx]) + else: + assert var_shape[idx] % mesh[mapping[ + idx]] == 0, "un-event partition: var_shape[idx]=[{}], mesh[{}]".format( + var_shape[idx], mesh[mapping[idx]]) + new_shape.append(var_shape[idx] // mesh[mapping[idx]]) + + return new_shape + + +def _partition_parameter(auto_paralle_context, src_var, dst_block, dst_varname, + dst_shape): + # NOTE hack to copied Parameter + # not initialized parameter, need to initialize it + copied_kwargs = {} + copied_kwargs['trainable'] = src_var.trainable + copied_kwargs['optimize_attr'] = src_var.optimize_attr + copied_kwargs['regularizer'] = src_var.regularizer + copied_kwargs['do_model_average'] = src_var.do_model_average + copied_kwargs['need_clip'] = src_var.need_clip + + param = Parameter( + block=dst_block, + type=src_var.type, + name=dst_varname, + shape=dst_shape, + dtype=src_var.dtype, + lod_level=src_var.lod_level, + error_clip=src_var.error_clip, + stop_gradient=src_var.stop_gradient, + is_data=src_var.is_data, + belong_to_optimizer=src_var.belong_to_optimizer, + **copied_kwargs) + + # set dist attr uid + # 
distributed_attr_uid = src_var.desc.get_distributed_attr_uid() + # param.desc.set_distributed_attr_uid(distributed_attr_uid) + dist_attr = copy.deepcopy( + auto_paralle_context.get_tensor_distributed_attr_for_program(src_var)) + dist_attr._owner_tensor = param + dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program( + src_var)._owner_context + auto_paralle_context.set_tensor_distributed_attr_for_program(param, + dist_attr) + + +def _partition_intermediate_var(auto_paralle_context, src_var, dst_block, + dst_varname, dst_shape): + var = dst_block.create_var( + type=src_var.type, + name=dst_varname, + shape=dst_shape, + dtype=src_var.dtype, + lod_level=src_var.lod_level, + persistable=src_var.persistable, + error_clip=src_var.error_clip, + stop_gradient=src_var.stop_gradient, + is_data=src_var.is_data, + belong_to_optimizer=src_var.belong_to_optimizer) + + # set dist attr uid + # distributed_attr_uid = src_var.desc.get_distributed_attr_uid() + # var.desc.set_distributed_attr_uid(distributed_attr_uid) + dist_attr = copy.deepcopy( + auto_paralle_context.get_tensor_distributed_attr_for_program(src_var)) + dist_attr._owner_tensor = var + dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program( + src_var)._owner_context + auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr) + + +def _partition_var(auto_paralle_context, src_block, dst_block, src_varname, + dst_varname): + """ + partition include: split + replicate + """ + src_var = src_block.var(src_varname) + + if src_var.type == core.VarDesc.VarType.READER: + dst_block.create_var( + type=src_var.type, + name=dst_varname, + persistable=True, + stop_gradient=True) + else: + dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program( + src_var) + target_shape = _get_dist_shape(src_var, dist_attr) + + if isinstance(src_var, Parameter): + _partition_parameter(auto_paralle_context, src_var, dst_block, + dst_varname, target_shape) + else: + _partition_intermediate_var(auto_paralle_context, src_var, + dst_block, dst_varname, target_shape) + + +def _insert_src_op(src_op, dst_block, varname_mapping): + + new_op_desc = dst_block.desc.append_op() + new_op_desc.copy_from(src_op.desc) + for local_varname in src_op.desc.input_arg_names(): + new_op_desc._rename_input(local_varname, varname_mapping[local_varname]) + for local_varname in src_op.desc.output_arg_names(): + new_op_desc._rename_output(local_varname, + varname_mapping[local_varname]) + dst_block._sync_with_cpp() + + +def _insert_dist_op(src_op, dst_block, varname_mapping, auto_paralle_context, + rank_id): + + # build input varname mapping + input_mapping = {} + for input_name in src_op.desc.input_names(): + varnames = [] + for varname in src_op.desc.input(input_name): + varnames.append(varname_mapping[varname]) + input_mapping[input_name] = varnames + + # build output varname mapping + output_mapping = {} + for output_name in src_op.desc.output_names(): + varnames = [] + for varname in src_op.desc.output(output_name): + varnames.append(varname_mapping[varname]) + output_mapping[output_name] = varnames + + # append dist op + dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(src_op) + dist_ops = get_distributed_operator(src_op.type) + append_op_handle = dist_ops.get_impl(dist_attr.get_impl_idx()).forward( + src_op) + append_op_handle( + dst_block, + src_op, + dist_attr, + input_mapping, + output_mapping, + rank_id=rank_id) diff --git 
a/python/paddle/distributed/auto_parallel/process.py b/python/paddle/distributed/auto_parallel/process.py new file mode 100644 index 0000000000000..b919645b96ccc --- /dev/null +++ b/python/paddle/distributed/auto_parallel/process.py @@ -0,0 +1,149 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import paddle +import paddle.fluid.core as core +from ..collective import _get_global_env +from ..collective import _new_ring_id +from ...fluid.framework import in_dygraph_mode +from ...fluid.layers.tensor import fill_constant + +LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = None +PROCESSOR_TO_PHYSICAL_PROCESS_MAP = None + + +def get_all_logical_process_set(): + from .interface import _g_process_mesh_map + all_logical_process_set = set(_g_process_mesh_map[0].process_group) + return all_logical_process_set + + +def get_logical_process_to_physical_process_map(): + global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP + return LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP + + +def set_logical_process_to_physical_process_map(mapping): + global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP + LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = mapping + + +def get_processor_to_physical_process_map(): + global PROCESSOR_TO_PHYSICAL_PROCESS_MAP + return PROCESSOR_TO_PHYSICAL_PROCESS_MAP + + +def set_processor_to_physical_process_map(mapping): + global PROCESSOR_TO_PHYSICAL_PROCESS_MAP + PROCESSOR_TO_PHYSICAL_PROCESS_MAP = mapping + + +PROCESS_GROUP_MAP = {} + + +def get_all_process_groups(): + global PROCESS_GROUP_MAP + return PROCESS_GROUP_MAP.values() + + +def new_process_group(ranks): + global PROCESS_GROUP_MAP + if not PROCESS_GROUP_MAP: + genv = _get_global_env() + PROCESS_GROUP_MAP["global_group"] = ProcessGroup( + 0, list(range(genv.world_size))) + # A key constructed from ranks is used in the global process group map + key = ''.join(map(str, sorted(ranks))) + if key not in PROCESS_GROUP_MAP: + num_groups = len(PROCESS_GROUP_MAP) + # Note: our process group may interfere with the original implementation + # so the created group id should start from the original _new_ring_id() + group_id = _new_ring_id() + num_groups + 1 + pg = ProcessGroup(group_id, ranks) + PROCESS_GROUP_MAP[key] = pg + return pg + else: + pg = PROCESS_GROUP_MAP[key] + return pg + + +# This implementation refers to lots of Paddle/python/paddle/distributed/collective.py, +# Fleet also has a collective helper which uses ops to initialize communication in +# Paddle/python/paddle/distributed/fleet/meta_optimizers/common.py. We use the first one +# because it seems simple. This should be enhanced to manage the process membership and +# the instantiation process in a more general way. In the future, the process group may +# handle the communication implementation choice. 
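A minimal usage sketch of the helpers above (illustrative only: the ranks are hypothetical and a launched paddle.distributed environment is assumed, since the first call also registers the global group):

    # groups are cached by a key built from the sorted ranks, so equivalent
    # requests return the same ProcessGroup object
    dp_group = new_process_group([0, 1, 2, 3])
    assert new_process_group([3, 2, 1, 0]) is dp_group
    print(dp_group)                  # id: ..., nranks: 4, ranks: 0, 1, 2, 3.
    print(dp_group.local_rank(2))    # position of global rank 2 inside the group -> 2
    # dp_group.instantiate() would then create the NCCL communicator; it must run on
    # every member rank and currently requires CUDA.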
+class ProcessGroup:
+    def __init__(self, group_id, ranks):
+        self._group_id = group_id
+        self._ranks = sorted(ranks)
+        self._nranks = len(self._ranks)
+        self._is_instantiate = False
+
+    @property
+    def id(self):
+        return self._group_id
+
+    # @property
+    # def key(self):
+    #     return ''.join(map(str, sorted(self._ranks)))
+
+    def local_rank(self, global_rank):
+        if global_rank in self._ranks:
+            return self._ranks.index(global_rank)
+        else:
+            assert False, \
+                "Rank {} doesn't belong to this group".format(global_rank)
+
+    def is_instantiate(self):
+        return self._is_instantiate
+
+    def instantiate(self):
+        if self._is_instantiate:
+            return
+        ring_id = self.id
+        genv = _get_global_env()
+        global_rank = genv.rank
+
+        if self._nranks >= 2:
+            strategy = core.ParallelStrategy()
+            strategy.nranks = self._nranks
+            strategy.local_rank = self.local_rank(global_rank)
+            strategy.trainer_endpoints = [
+                genv.trainer_endpoints[i] for i in self._ranks
+            ]
+            strategy.current_endpoint = genv.current_endpoint
+            strategy.nrings = 1
+
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(genv.device_id)
+                core.NCCLParallelContext(strategy,
+                                         place).init_with_ring_id(ring_id)
+            else:
+                assert False, ("No CUDA device found")
+
+        # TODO(shenliang03): This is a temporary solution to solve the problem of
+        # hang caused by cross-creation of new_group
+        tmp = paddle.to_tensor(
+            [1], dtype="int32") if in_dygraph_mode() else fill_constant(
+                [0], dtype="int32", value="1")
+        paddle.distributed.all_reduce(tmp, use_calc_stream=True)
+        paddle.distributed.wait(tmp)
+
+        self._is_instantiate = True
+
+    def __str__(self):
+        string = "id: {}, nranks: {}, ranks: {}.".format(
+            self.id, self._nranks, ", ".join(map(str, self._ranks)))
+        return string
diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
old mode 100644
new mode 100755
index a4a73ae5c0a64..c864375271b3c
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -14,6 +14,7 @@
 import threading

 import paddle.fluid.core as core
+import numpy as np


 def is_valid_list_index(list, index):
@@ -155,3 +156,125 @@ def print_program_with_distributed_attr(program, dist_context=None):
         print(program)
         set_default_distributed_context(original_default_context)
     lock.release()
+
+
+def _get_comm_group(processes, shape, axis, rank):
+    """
+    Given a rank and the process mesh it belongs to, compute the communication
+    peers of that rank along the given axis of the mesh.
+
+    Example: 16 processes managed in a 4-dimensional mesh with shape [2, 2, 2, 2];
+    the communication peers of rank 0 (included) are the following:
+        in axis 0: [0, 8]
+        in axis 1: [0, 4]
+        in axis 2: [0, 2]
+        in axis 3: [0, 1]
+    """
+
+    # NOTE _linear_idx2coordinate assumes the processes mesh starts with 0 and is continuous
+    # the .index() call is a trick to support meshes that do not start with 0 or are not continuous
+    rank_relative = processes.index(rank)
+    coordinate = _linear_idx2coordinate(shape, rank_relative)
+    coordinates_in_group = [coordinate[:] for i in range(shape[axis])]
+
+    # select comm group
+    for i in range(shape[axis]):
+        coordinates_in_group[i][axis] = i
+
+    ranks_in_group_relative = [
+        _coordinate2linear_idx(shape, coordinate)
+        for coordinate in coordinates_in_group
+    ]
+    ranks_in_group = [processes[idx] for idx in ranks_in_group_relative]
+
+    return sorted(ranks_in_group)
+
+
+def _coordinate2linear_idx(mesh_shape, coordinate):
+    """
+    convert a coordinate in the multidimensional mesh space into a scalar idx in the linear space.
+
+    it uses row-major order for the conversion, so the dimensions are ordered
+    [most_significant_dim, ..., least_significant_dim].
+    assume:
+
+        the size of the i-th dimension is: S[i]
+        the index in the j-th dimension is: I[j]
+
+    the linear_idx of an n-dimensional coordinate is:
+
+        I[0]   * (S[1] * S[2] * ... * S[n-1]) +
+        I[1]   * (S[2] * ... * S[n-1])        +
+        ...                                   +
+        I[n-2] * (S[n-1])                     +
+        I[n-1]
+
+    """
+    # NOTE the following function works based on a strong assumption:
+    # the processes in the mesh
+    # 1. start from 0
+    # 2. are continuous
+    # it will be wrong if the above conditions are not met,
+    # e.g. process_mesh = { process_groups = [7, 8, 9, 10, 12, 13, 14, 15], mesh = [2, 4]}
+    # if you want a more general mapping, use the cartesian product instead
+
+    assert len(mesh_shape) == len(
+        coordinate
+    ), "coordinate should have the same size as mesh shape, but got shape: {}, coordinate: {}".format(
+        mesh_shape, coordinate)
+    for i in range(len(mesh_shape)):
+        assert coordinate[
+            i] >= 0, "index in dimension [{}] is less than zero. coordinate: {}".format(
+                i, coordinate)
+        assert coordinate[i] < mesh_shape[
+            i], "index beyond the extent of dimension [{}]. shape: {}, coordinate: {}".format(
+                i, mesh_shape, coordinate)
+
+    base = mesh_shape[-1]
+    linear_idx = coordinate[-1]
+
+    # row major order
+    for i in range(len(mesh_shape) - 2, -1, -1):
+        linear_idx += base * coordinate[i]
+        base *= mesh_shape[i]
+
+    return linear_idx
+
+
+def _linear_idx2coordinate(mesh_shape, linear_idx):
+    """
+    map a linear scalar into the multidimensional mesh space and return its coordinate in that space.
+
+    it is the inverse function of _coordinate2linear_idx.
+    assume:
+
+        the size of the i-th dimension is: S[i]
+        the index in the j-th dimension is: I[j]
+
+    the coordinate given linear_idx is:
+
+        I[n-1] = linear_idx % S[n-1]
+        I[n-2] = (linear_idx // S[n-1]) % S[n-2]
+        I[n-3] = (linear_idx // (S[n-1] * S[n-2])) % S[n-3]
+        ....
+
+    """
+
+    assert linear_idx >= 0, "linear index [{}] is less than zero".format(
+        linear_idx)
+    assert linear_idx < np.prod(
+        mesh_shape
+    ), "linear index beyond the extent of mesh shape. shape: {}, linear index: {}".format(
+        mesh_shape, linear_idx)
+
+    base = 1
+    coordinate = [-1] * len(mesh_shape)
+
+    for i in reversed(range(len(mesh_shape))):
+        offset = linear_idx / base
+        coordinate[i] = int(offset % mesh_shape[i])
+        base *= mesh_shape[i]
+
+    # row major order
+    return coordinate
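A short worked example of the convention these helpers implement (illustrative only, not part of the patch): with the 16-process mesh of shape [2, 2, 2, 2] from the docstring above, the last axis varies fastest, so axis 0 carries the largest stride (8) and axis 3 the smallest (1).

    # sanity checks for _coordinate2linear_idx / _linear_idx2coordinate / _get_comm_group
    shape = [2, 2, 2, 2]
    assert _coordinate2linear_idx(shape, [1, 0, 0, 0]) == 8   # first dim: stride 2 * 2 * 2
    assert _coordinate2linear_idx(shape, [0, 0, 0, 1]) == 1   # last dim varies fastest
    assert _linear_idx2coordinate(shape, 10) == [1, 0, 1, 0]  # 10 = 1 * 8 + 1 * 2
    # communication peers of rank 0 along each mesh axis
    assert _get_comm_group(list(range(16)), shape, axis=0, rank=0) == [0, 8]
    assert _get_comm_group(list(range(16)), shape, axis=3, rank=0) == [0, 1]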
diff --git a/python/paddle/distributed/fleet/meta_optimizers/tensor_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/tensor_parallel_optimizer.py
index 5fbec7da0b5ed..9d099a2af24fa 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/tensor_parallel_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/tensor_parallel_optimizer.py
@@ -124,7 +124,7 @@ def _init_process_group(self):
         collective_helper._init_communicator(
             self.startup_program, self.current_endpoint, self.mp_endpoints,
             self.mp_rank, self.mp_ring_id, True, self.global_ring_id, True)
-        #self._broadcast_params(self.mp_ring_id, mp_mode=True)
+        self._broadcast_params(self.mp_ring_id, mp_mode=True)

         # Create dp rings
         if self.nranks > self.mp_degree:
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index fc7b39ede244d..706d64d8d35b6 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -42,11 +42,13 @@ def __init__(self, layers, hcg, strategy):
         self.accumulate_steps = self._strategy.pipeline_configs[
             'accumulate_steps']
+        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']
+
         self.num_stages = self._hcg.get_pipe_parallel_world_size()
         self.stage_id = self._hcg.get_stage_id()
         self.pp_group = self._hcg.get_pipe_parallel_group()

-        p2p.initialize_p2p_groups(hcg)
+        p2p.initialize_p2p_groups(hcg, self._using_cache)

         _initialize_recompute_hcg(hcg)
@@ -55,6 +57,8 @@ def __init__(self, layers, hcg, strategy):
         self.global_rank = self._hcg.get_global_rank()
         self.micro_batch_id = 0

+        self._compute_loss = True
+
         logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
             self.num_stages, self.stage_id))
@@ -85,6 +89,7 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
         self.lr_scheduler = lr_scheduler
         self.scaler = scaler
         self.data = data
+        self._compute_loss = True

         self._layers.train()
@@ -151,12 +156,57 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):

         self._layers.allreduce_shared_weight_gradients()

-        self.train_loss = self._reduce_final_loss()
+        self.train_loss = self._broadcast_final_loss()

         # optimizer
         self._optimizer_step()
         return self.train_loss

+    def eval_batch(self, data, compute_loss=False):
+        self._layers.eval()
+        self._compute_loss = compute_loss
+
+        # save data for eval
+        self.data = data
+        # store data id for micro_batch
+        self.micro_batch_id = 0
+
+        # store total loss of entire batch
+        self.total_loss = None
+
+        startup_steps = (self.num_stages - self.stage_id - 1)
+        startup_steps = min(startup_steps, self.accumulate_steps)
+        steady_steps = self.accumulate_steps - startup_steps
+
+        input_buffers = []
+        output_buffers = []
+
+        for step_id in range(startup_steps):
+            input_tensor = p2p.recv_forward()
+
+            output_tensor = self._forward_step(input_tensor)
+            p2p.send_forward(output_tensor)
+
+            input_buffers.append(input_tensor)
+            output_buffers.append(output_tensor)
+
+        if steady_steps > 0:
+            input_tensor = p2p.recv_forward()
+
+        for i in range(steady_steps):
+            last_iter = (i == (steady_steps - 1))
+
+            output_tensor = 
self._forward_step(input_tensor) + p2p.send_forward(output_tensor) + + input_buffers.append(input_tensor) + output_buffers.append(output_tensor) + + if not last_iter: + input_tensor = p2p.recv_forward() + + return self.total_loss if self._compute_loss else output_buffers + def _forward_step(self, input_tensor): if self.stage_id == 0: input_tensor = self._load_micro_batch(self.micro_batch_id) @@ -164,18 +214,21 @@ def _forward_step(self, input_tensor): output_tensor = self._layers.forward(input_tensor) if self.is_last_stage: - labels = self._load_micro_batch(self.micro_batch_id) - output_tensor = self._layers._loss_fn(output_tensor, labels) - assert isinstance( - output_tensor, paddle. - Tensor), "Currently, loss_fn should obtain Paddle.Tensor dtype" - - if self.accumulate_steps > 1: - output_tensor = output_tensor / self.accumulate_steps - - if self.total_loss is None: - self.total_loss = paddle.zeros_like(output_tensor) - self.total_loss += output_tensor.detach() + # train calculate loss for train + if self._compute_loss: + assert self._layers._loss_fn is not None, "loss function should exist to compute loss" + labels = self._load_micro_batch(self.micro_batch_id) + output_tensor = self._layers._loss_fn(output_tensor, labels) + assert isinstance( + output_tensor, paddle.Tensor + ), "Currently, loss_fn should obtain Paddle.Tensor dtype" + + if self.accumulate_steps > 1: + output_tensor = output_tensor / self.accumulate_steps + + if self.total_loss is None: + self.total_loss = paddle.zeros_like(output_tensor) + self.total_loss += output_tensor.detach() self.micro_batch_id += 1 return output_tensor @@ -245,7 +298,7 @@ def _load_micro_batch(self, cache_id): # No data input is required for other stages inputs = None - def _reduce_final_loss(self): + def _broadcast_final_loss(self): if self.is_last_stage: assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss" loss = self.total_loss.detach() diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index c508c88015cfd..e2c99edac1270 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -19,11 +19,13 @@ from paddle import _C_ops _hcg = None +_use_cache = False -def initialize_p2p_groups(hcg): - global _hcg +def initialize_p2p_groups(hcg, use_cache=True): + global _hcg, _use_cache _hcg = hcg + _use_cache = use_cache send_next_group, send_prev_group, recv_next_group, recv_prev_group = _hcg.get_p2p_groups( ) @@ -372,7 +374,7 @@ def recv_forward(): else: if not _send_recv_meta.has_recv_meta: _send_recv_meta.recv_meta(_hcg.recv_prev_group) - _send_recv_meta.has_recv_meta = True + _send_recv_meta.has_recv_meta = _use_cache input_tensor, _ = _p2p_helper( tensor_send_next=None, @@ -399,7 +401,7 @@ def send_forward(output_tensor): if not _send_recv_meta.has_send_meta: _send_recv_meta.set_send_message(output_tensor) _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group) - _send_recv_meta.has_send_meta = True + _send_recv_meta.has_send_meta = _use_cache _p2p_helper( tensor_send_next=output_tensor, diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 8bb4d82b72478..3fe7f90a5b3fd 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -249,6 +249,7 @@ def __bootstrap__(): 'npu_config_path', 'get_host_by_name_time', 
'hccl_check_nan', + 'min_loss_scaling', ] core.init_gflags(["--tryfromenv=" + ",".join(read_env_flags)]) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 6208b43c9e9e4..790eff04c3648 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -202,7 +202,7 @@ def forward(self, inputs): self._quantize_inputs = ImperativeQuantizeInputs(**kwargs) - self._quantize_outputs = ImperativeQuantizeOutputs() + self._quantize_outputs = ImperativeQuantizeOutputs(moving_rate) def quantize(self, model): """ @@ -413,6 +413,8 @@ def apply(self, model): "The model must be the instance of dygraph.Layer." for cur_name, cur_layer in model.named_sublayers(): + if '_act_preprocess' in cur_name: + continue if not self._is_target_layer(cur_layer): continue diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index c2d7a9bb4d517..01b54f8f13a9f 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -133,6 +133,7 @@ "pad2d": [["X"], ["Out"]], "flatten": [["X"], ["Out"]], "flatten2": [["X"], ["Out"]], + "unsqueeze2": [["X"], ["Out"]], } _conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose'] diff --git a/python/paddle/fluid/distributed/ps_instance.py b/python/paddle/fluid/distributed/ps_instance.py index 42033a0ada4ac..9254a4a136f77 100644 --- a/python/paddle/fluid/distributed/ps_instance.py +++ b/python/paddle/fluid/distributed/ps_instance.py @@ -156,5 +156,5 @@ def finalize(self): if __name__ == "__main__": - instance = PaddlePSInstance(1, 1, 2, 50) + instance = PaddlePSInstance(1, 2) instance.barrier_all() diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/error.py b/python/paddle/fluid/dygraph/dygraph_to_static/error.py index 913b7cec60227..ffcc8c95bbc81 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/error.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/error.py @@ -16,6 +16,8 @@ import six import sys import traceback +import linecache +import re from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map @@ -29,6 +31,9 @@ DISABLE_ERROR_ENV_NAME = "TRANSLATOR_DISABLE_NEW_ERROR" DEFAULT_DISABLE_NEW_ERROR = 0 +SOURCE_CODE_RANGE = 5 +BLANK_COUNT_BEFORE_FILE_STR = 4 + def attach_error_data(error, in_runtime=False): """ @@ -40,6 +45,7 @@ def attach_error_data(error, in_runtime=False): Returns: An error attached data about original source code information and traceback. """ + e_type, e_value, e_traceback = sys.exc_info() tb = traceback.extract_tb(e_traceback)[1:] @@ -82,12 +88,61 @@ def __init__(self, location, function_name, source_code): def formated_message(self): # self.source_code may be empty in some functions. # For example, decorator generated function - return ' File "{}", line {}, in {}\n\t{}'.format( + return ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n\t{}'.format( self.location.filepath, self.location.lineno, self.function_name, self.source_code.lstrip() if isinstance(self.source_code, str) else self.source_code) +class TraceBackFrameRange(OriginInfo): + """ + Traceback frame information. 
+ """ + + def __init__(self, location, function_name): + self.location = location + self.function_name = function_name + self.source_code = [] + blank_count = [] + begin_lineno = max(1, self.location.lineno - int(SOURCE_CODE_RANGE / 2)) + + for i in range(begin_lineno, begin_lineno + SOURCE_CODE_RANGE): + line = linecache.getline(self.location.filepath, i).rstrip('\n') + line_lstrip = line.lstrip() + self.source_code.append(line_lstrip) + if not line_lstrip: # empty line from source code + blank_count.append(-1) + else: + blank_count.append(len(line) - len(line_lstrip)) + + if i == self.location.lineno: + hint_msg = '~' * len(self.source_code[-1]) + ' <--- HERE' + self.source_code.append(hint_msg) + blank_count.append(blank_count[-1]) + linecache.clearcache() + # remove top and bottom empty line in source code + while len(self.source_code) > 0 and not self.source_code[0]: + self.source_code.pop(0) + blank_count.pop(0) + while len(self.source_code) > 0 and not self.source_code[-1]: + self.source_code.pop(-1) + blank_count.pop(-1) + + min_black_count = min([i for i in blank_count if i >= 0]) + for i in range(len(self.source_code)): + # if source_code[i] is empty line between two code line, dont add blank + if self.source_code[i]: + self.source_code[i] = ' ' * (blank_count[i] - min_black_count + + BLANK_COUNT_BEFORE_FILE_STR * 2 + ) + self.source_code[i] + + def formated_message(self): + msg = ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n'.format( + self.location.filepath, self.location.lineno, self.function_name) + # add empty line after range code + return msg + '\n'.join(self.source_code) + '\n' + + class ErrorData(object): """ Error data attached to an exception which is raised in un-transformed code. @@ -128,26 +183,34 @@ def create_message(self): return '\n'.join(message_lines) # Step2: Optimizes stack information with source code information of dygraph from user. - for filepath, lineno, funcname, code in self.origin_traceback: + whether_source_range = True + for filepath, lineno, funcname, code in self.origin_traceback[::-1]: loc = Location(filepath, lineno) - dygraph_func_info = self.origin_info_map.get(loc.line_location, None) if dygraph_func_info: - # TODO(liym27): more information to prompt users that this is the original information. - # Replaces trace stack information about transformed static code with original dygraph code. - traceback_frame = self.origin_info_map[loc.line_location] - else: - traceback_frame = TraceBackFrame(loc, funcname, code) - - message_lines.append(traceback_frame.formated_message()) + if whether_source_range: + traceback_frame = TraceBackFrameRange( + dygraph_func_info.location, + dygraph_func_info.function_name) + whether_source_range = False + else: + traceback_frame = TraceBackFrame( + dygraph_func_info.location, + dygraph_func_info.function_name, + dygraph_func_info.source_code) + # Two elements already exist in message_lines: "In transformed code:" and "", so insert in index 2 + message_lines.insert(2, traceback_frame.formated_message()) # Step3: Adds error message like "TypeError: dtype must be int32, but received float32". # NOTE: `format_exception` is a list, its length is 1 in most cases, but sometimes its length # is gather than 1, for example, the error_type is IndentationError. 
format_exception = traceback.format_exception_only(self.error_type, self.error_value) - error_message = [" " * 4 + line for line in format_exception] + error_message = [ + " " * BLANK_COUNT_BEFORE_FILE_STR + line + for line in format_exception + ] message_lines.extend(error_message) return '\n'.join(message_lines) @@ -162,6 +225,7 @@ def _simplify_error_value(self): 1. Need a more robust way because the code of start_trace may change. 2. Set the switch to determine whether to simplify error_value """ + assert self.in_runtime is True error_value_lines = str(self.error_value).split("\n") @@ -169,13 +233,46 @@ def _simplify_error_value(self): start_trace = "outputs = static_func(*inputs)" start_idx = error_value_lines_strip.index(start_trace) + error_value_lines = error_value_lines[start_idx + 1:] + error_value_lines_strip = error_value_lines_strip[start_idx + 1:] + + # use empty line to locate the bottom_error_message + empty_line_idx = error_value_lines_strip.index('') + bottom_error_message = error_value_lines[empty_line_idx + 1:] + + filepath = '' + error_from_user_code = [] + pattern = 'File "(?P.+)", line (?P.+), in (?P.+)' + for i in range(0, len(error_value_lines_strip), 2): + if error_value_lines_strip[i].startswith("File "): + re_result = re.search(pattern, error_value_lines_strip[i]) + tmp_filepath, lineno_str, function_name = re_result.groups() + code = error_value_lines_strip[i + 1] if i + 1 < len( + error_value_lines_strip) else '' + if i == 0: + filepath = tmp_filepath + if tmp_filepath == filepath: + error_from_user_code.append( + (tmp_filepath, int(lineno_str), function_name, code)) + + error_frame = [] + whether_source_range = True + for filepath, lineno, funcname, code in error_from_user_code[::-1]: + loc = Location(filepath, lineno) + if whether_source_range: + traceback_frame = TraceBackFrameRange(loc, funcname) + whether_source_range = False + else: + traceback_frame = TraceBackFrame(loc, funcname, code) + + error_frame.insert(0, traceback_frame.formated_message()) - error_value_str = '\n'.join(error_value_lines) + error_frame.extend(bottom_error_message) + error_value_str = '\n'.join(error_frame) self.error_value = self.error_type(error_value_str) def raise_new_exception(self): - # Raises the origin error if disable dygraph2static error module, if int(os.getenv(DISABLE_ERROR_ENV_NAME, DEFAULT_DISABLE_NEW_ERROR)): raise diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/ps_dispatcher.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/ps_dispatcher.py index 5f48ba6b2a725..74ded7c09967f 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/ps_dispatcher.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/ps_dispatcher.py @@ -66,7 +66,7 @@ class HashName(PSDispatcher): """ def __init__(self, pserver_endpoints): - super(self.__class__, self).__init__(pserver_endpoints) + super(HashName, self).__init__(pserver_endpoints) def _hash_block(self, block_str, total): return hash(block_str) % total @@ -106,7 +106,7 @@ class RoundRobin(PSDispatcher): """ def __init__(self, pserver_endpoints): - super(self.__class__, self).__init__(pserver_endpoints) + super(RoundRobin, self).__init__(pserver_endpoints) def dispatch(self, varlist): """ diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py index b2735727f6755..9246b8e44840c 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py +++ 
b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py @@ -382,6 +382,7 @@ def get_trainer_send_context(self): send_ctx = {} distibuted_varnames = get_sparse_tablenames(self.origin_main_program, True) + idx = 0 if not self.is_geo_mode(): for merged in self.merged_dense_pairs: @@ -401,9 +402,10 @@ def get_trainer_send_context(self): ctx = self.build_ctx(grad, self.grad_var_mapping, True, True, True, is_distributed) send_ctx[ctx.var_name()] = ctx + idx += 1 if self.is_async_mode(): - name, ctx = self._step_ctx() + name, ctx = self._step_ctx(idx) send_ctx[name] = ctx else: for pairs in self.origin_sparse_pairs: @@ -427,7 +429,8 @@ def get_trainer_send_context(self): param_ctx.is_distributed()) send_ctx[ctx.var_name()] = ctx - name, ctx = self._step_ctx() + idx += 1 + name, ctx = self._step_ctx(idx) send_ctx[name] = ctx return send_ctx @@ -435,6 +438,7 @@ def get_communicator_send_context(self): send_ctx = {} distibuted_varnames = get_sparse_tablenames(self.origin_main_program, True) + idx = 0 if self.is_geo_mode(): for pairs in self.merged_dense_pairs: @@ -451,7 +455,8 @@ def get_communicator_send_context(self): ctx = self.build_ctx(param, self.param_var_mapping, False, True, True, is_distributed) send_ctx[ctx.var_name()] = ctx - name, ctx = self._step_ctx() + idx += 1 + name, ctx = self._step_ctx(idx) send_ctx[name] = ctx else: for merged in self.merged_dense_pairs: @@ -469,8 +474,9 @@ def get_communicator_send_context(self): ctx = self.build_ctx(grad, self.grad_var_mapping, True, True, True, is_distributed) send_ctx[ctx.var_name()] = ctx + idx += 1 - name, ctx = self._step_ctx() + name, ctx = self._step_ctx(idx) send_ctx[name] = ctx return send_ctx diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py index 9a21a5a850db9..e2fb29c5439e1 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py @@ -412,11 +412,13 @@ def _minimize(self, sparse_table_index = 0 for num in range(len(losses)): loss = losses[num] + parameters = None + if parameter_list != None: + parameters = parameter_list[num] prog_id = str(id(loss.block.program)) # param_grads of program params_grads = sorted( - fluid.backward.append_backward(loss, parameter_list, - no_grad_set), + fluid.backward.append_backward(loss, parameters, no_grad_set), key=lambda x: x[0].name) flag_use_ps_gpu = strategy.get("use_ps_gpu", False) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 59dfec005d852..4216384b6f8b2 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -10977,6 +10977,7 @@ def slice(input, axes, starts, ends): ends_tensor = None if isinstance(axes, (list, tuple)): + axes = list(axes) if len(axes) == 0: raise ValueError( "Input axes should not be an empty list/tuple.") diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index eb3d559ddcde9..fc5c30684b279 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -5323,7 +5323,13 @@ def _accumulate_gradients_with_fuse(self, main_block, fp16, fused_size): "copy_data": False, "use_align": True, "dtype": grads[0].dtype, - self._op_role_key: self._op_role.Backward + self._op_role_key: self._op_role.Backward, + # On npu, the nan/inf check login is different with gpu. 
+ # If there are some not initialized sections in the fused var, + # and the value in those sections are nan/inf, it will trigger the nan/inf check. + # To avoid these problematic triggers, set constant is needed for npu + "set_constant": core.is_compiled_with_npu(), + "constant": float(0.0), }) offset += 1 # For the gradient_merged_fused_var, given a init value during the coalesce op diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index fb7f18fcc4ef7..2c001614d1bac 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -79,6 +79,8 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer_base) list(APPEND MIXED_DIST_TEST_OPS test_fleet_distributed_strategy) list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto) list(APPEND MIXED_DIST_TEST_OPS test_fleet_static_mp_layers) +list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner) +list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner_gpt) foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) list(REMOVE_ITEM TEST_OPS ${TEST_OP}) endforeach() @@ -206,6 +208,8 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) LIST(REMOVE_ITEM TEST_OPS test_dygraph_recompute) list(REMOVE_ITEM TEST_OPS test_parallel_class_center_sample) LIST(REMOVE_ITEM TEST_OPS test_parallel_margin_cross_entropy) + LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner) + LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt) elseif(WITH_GPU) if (${CUDNN_VERSION} VERSION_LESS 7100) LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op) @@ -854,6 +858,7 @@ set_tests_properties(test_multiprocess_dataloader_iterable_dataset_static PROPER set_tests_properties(test_lstm_cudnn_op PROPERTIES TIMEOUT 120) set_tests_properties(test_stack_op PROPERTIES TIMEOUT 120) set_tests_properties(test_bilinear_interp_v2_op PROPERTIES TIMEOUT 120) +set_tests_properties(test_svd_op PROPERTIES TIMEOUT 120) set_tests_properties(test_deformable_psroi_pooling PROPERTIES TIMEOUT 120) set_tests_properties(test_trilinear_interp_v2_op PROPERTIES TIMEOUT 120) set_tests_properties(test_imperative_static_runner_mnist PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py index c177b556b8665..6dd8c8e0766bf 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py @@ -98,6 +98,16 @@ def test_func(self): return +@paddle.jit.to_static +def func_error_in_runtime_with_empty_line(x): + x = fluid.dygraph.to_variable(x) + two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32") + + x = fluid.layers.reshape(x, shape=[1, two]) + + return x + + class TestFlags(unittest.TestCase): def setUp(self): self.reset_flags_to_default() @@ -218,7 +228,10 @@ def set_message(self): ['File "{}", line 35, in func_error_in_compile_time'.format(self.filepath), 'inner_func()', 'File "{}", line 28, in inner_func'.format(self.filepath), + 'def inner_func():', 'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")', + '<--- HERE', + 'return', ] def set_func_call(self): @@ -242,7 +255,11 @@ def set_message(self): self.expected_message = \ [ 'File "{}", line 46, in func_error_in_compile_time_2'.format(self.filepath), - 'x = fluid.layers.reshape(x, shape=[1, 2])' + 'def func_error_in_compile_time_2(x):', + 'x = fluid.dygraph.to_variable(x)', + 'x = fluid.layers.reshape(x, shape=[1, 
2])', + '<--- HERE', + 'return x' ] @@ -261,7 +278,10 @@ def set_exception_type(self): def set_message(self): self.expected_message = \ ['File "{}", line 91, in forward'.format(self.filepath), + '@paddle.jit.to_static', + 'def forward(self):', 'self.test_func()', + '<--- HERE' ] def set_func_call(self): @@ -283,7 +303,26 @@ def set_message(self): self.expected_message = \ [ 'File "{}", line 54, in func_error_in_runtime'.format(self.filepath), - 'x = fluid.layers.reshape(x, shape=[1, two])' + 'x = fluid.dygraph.to_variable(x)', + 'two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")', + 'x = fluid.layers.reshape(x, shape=[1, two])', + '<--- HERE', + 'return x' + ] + + +class TestErrorStaticLayerCallInRuntime2(TestErrorStaticLayerCallInRuntime): + def set_func(self): + self.func = func_error_in_runtime_with_empty_line + + def set_message(self): + self.expected_message = \ + [ + 'File "{}", line 106, in func_error_in_runtime_with_empty_line'.format(self.filepath), + 'two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")', + 'x = fluid.layers.reshape(x, shape=[1, two])', + '<--- HERE', + 'return x' ] @@ -318,7 +357,12 @@ def set_exception_type(self): def set_message(self): self.expected_message = \ ['File "{}", line 80, in forward'.format(self.filepath), - 'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")', + 'def forward(self, x):', + 'y = self._linear(x)', + 'z = fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")', + '<--- HERE', + 'out = fluid.layers.mean(y[z])', + 'return out' ] def set_func_call(self): @@ -329,7 +373,7 @@ def test_error(self): self._test_raise_new_exception() -# Situation 4: NotImplementedError +# # Situation 4: NotImplementedError class TestErrorInOther(unittest.TestCase): def test(self): paddle.disable_static() diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py index 524099c6ab05e..c4c1e565068b2 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py @@ -177,10 +177,13 @@ def test_pp_model(self): x_data = np.random.randint(0, vocab_size, size=[batch_size, length]) x = paddle.to_tensor(x_data) x.stop_gradient = True + + e_loss = model.eval_batch([x, x], True) loss = model.train_batch([x, x], optimizer, scheduler) - # TODO(shenliang03) add utest for loss - print("loss: ", loss) + # TODO(shenliang03) add utest for loss + if pp_id != 0: + np.testing.assert_allclose(loss.numpy(), e_loss.numpy()) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py index b60c3f77e0c28..b59fcd8d02e2f 100644 --- a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py +++ b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py @@ -66,7 +66,7 @@ def test_interp_base(self): def check_cost_info(self, cost_info): if core.is_compiled_with_cuda(): - self.assertEqual(cost_info.host_memory_bytes(), 16) + # self.assertEqual(cost_info.host_memory_bytes(), 16) self.assertGreater(cost_info.device_memory_bytes(), 0) self.assertGreaterEqual(cost_info.device_total_memory_bytes(), cost_info.device_memory_bytes()) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_slice_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_slice_mkldnn_op.py new file mode 
100644 index 0000000000000..caebcffd0e966 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_slice_mkldnn_op.py @@ -0,0 +1,199 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid.core as core +from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16 +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle + + +@OpTestTool.skip_if(core.is_compiled_with_cuda(), + "CUDA required dygraph so oneDNN UT must be skipped") +class TestSliceOneDNNOp(OpTest): + def setUp(self): + self.op_type = "slice" + self.config() + self.set_inputs() + self.outputs = {'Out': self.out} + self.attrs = { + 'axes': self.axes, + 'starts': self.starts, + 'ends': self.ends, + 'infer_flags': self.infer_flags, + 'use_mkldnn': True + } + self.set_attrs() + + def set_inputs(self): + self.inputs = {'Input': self.input} + + def set_attrs(self): + pass + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float32") + self.starts = [1, 0, 2] + self.ends = [3, 3, 4] + self.axes = [0, 1, 2] + self.infer_flags = [1, 1, 1] + self.out = self.input[1:3, 0:3, 2:4, :] + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['Input'], 'Out') + + +class TestSliceOneDNNOp1(TestSliceOneDNNOp): + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float32") + self.starts = [-3, 0, 2] + self.ends = [3, 100, -1] + self.axes = [0, 1, 2] + self.infer_flags = [1, 1, 1] + self.out = self.input[-3:3, 0:100, 2:-1, :] + + +class TestSliceOneDNNOp2(TestSliceOneDNNOp): + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float32") + self.starts = [-3, 0, 2] + self.ends = [3, 100, -1] + self.axes = [0, 1, 3] + self.infer_flags = [1, 1, 1] + self.out = self.input[-3:3, 0:100, :, 2:-1] + + +class TestSliceDecrease1AxisOneDNNOp(TestSliceOneDNNOp): + def set_attrs(self): + self.attrs['decrease_axis'] = self.decrease_axis + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float32") + self.starts = [1, 0, 2] + self.ends = [2, 3, 4] + self.axes = [0, 1, 2] + self.decrease_axis = [0] + self.infer_flags = [1, 1, 1] + self.out = self.input[1, 0:3, 2:4, :] + + +class TestSliceDecrease2AxesOneDNNOp(TestSliceDecrease1AxisOneDNNOp): + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float32") + self.starts = [1, 0, 2] + self.ends = [2, 1, 4] + self.axes = [0, 1, 2] + self.decrease_axis = [0, 1] + self.infer_flags = [1, 1, 1] + self.out = self.input[1, 0, 2:4, :] + + +class TestSliceDecrease3AxesOneDNNOp(TestSliceDecrease1AxisOneDNNOp): + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float32") + self.starts = [-1, 0, 2] + self.ends = [1000000, 1, 4] + self.axes = [0, 1, 2] + self.decrease_axis = [0, 1] + self.infer_flags = [1, 1, 1] + self.out = 
self.input[-1, 0, 2:4, :] + + +class TestSliceDecrease4AxesOneDNNOp(TestSliceDecrease1AxisOneDNNOp): + def config(self): + self.input = np.random.random([3, 4, 5, 7]).astype("float32") + self.starts = [0, 1, 2, 3] + self.ends = [1, 2, 3, 4] + self.axes = [0, 1, 2, 3] + self.decrease_axis = [0, 1, 2, 3] + self.infer_flags = [1, 1, 1] + self.out = self.input[0, 1, 2, 3:4] + + +class TestSlice5DOneDNNOp(TestSliceDecrease1AxisOneDNNOp): + def config(self): + self.input = np.random.random([3, 4, 5, 6, 7]).astype("float32") + self.starts = [-1] + self.ends = [1000000] + self.axes = [4] + self.decrease_axis = [4] + self.infer_flags = [1, 1, 1] + self.out = self.input[:, :, :, :, -1] + + +class TestSlice3DOneDNNOp(TestSliceDecrease1AxisOneDNNOp): + def config(self): + self.input = np.random.random([5, 4, 5]).astype("float32") + self.starts = [-1] + self.ends = [1000000] + self.axes = [2] + self.decrease_axis = [2] + self.infer_flags = [1, 1, 1] + self.out = self.input[:, :, -1] + + +# BF16 TESTS +def create_bf16_test_class(parent): + @OpTestTool.skip_if_not_cpu_bf16() + class TestSliceBF16OneDNNOp(parent): + def set_inputs(self): + self.dtype = np.uint16 + self.inputs = {'Input': convert_float_to_uint16(self.input)} + + def calculate_grads(self): + self.dout = self.out + self.dx = np.zeros(shape=self.input.shape) + + begin = [None] * self.input.ndim + end = [None] * self.input.ndim + + for i in range(len(self.axes)): + begin[self.axes[i]] = self.starts[i] + end[self.axes[i]] = self.ends[i] + self.dx[begin[0]:end[0], begin[1]:end[1], begin[2]:end[2], begin[3]: + end[3]] = self.dout + + def test_check_output(self): + self.check_output_with_place(core.CPUPlace()) + + def test_check_grad(self): + self.calculate_grads() + self.check_grad_with_place( + core.CPUPlace(), ["Input"], + "Out", + user_defined_grads=[self.dx], + user_defined_grad_outputs=[convert_float_to_uint16(self.dout)]) + + cls_name = "{0}_{1}".format(parent.__name__, "BF16") + TestSliceBF16OneDNNOp.__name__ = cls_name + globals()[cls_name] = TestSliceBF16OneDNNOp + + +create_bf16_test_class(TestSliceOneDNNOp) +create_bf16_test_class(TestSliceOneDNNOp1) +create_bf16_test_class(TestSliceDecrease1AxisOneDNNOp) +create_bf16_test_class(TestSliceDecrease2AxesOneDNNOp) +create_bf16_test_class(TestSliceDecrease3AxesOneDNNOp) +create_bf16_test_class(TestSliceDecrease4AxesOneDNNOp) + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/npu/test_conv2d_transpose_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_conv2d_transpose_op_npu.py new file mode 100644 index 0000000000000..8cb94cb98f1e5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_conv2d_transpose_op_npu.py @@ -0,0 +1,698 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
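The create_bf16_test_class factory above derives a BF16 variant from each FP32 slice case by subclassing, renaming, and registering the new class in globals() so that unittest discovery picks it up. A stripped-down sketch of the same pattern, with hypothetical class names unrelated to this patch:

    import unittest

    class TestFoo(unittest.TestCase):
        dtype = "float32"

        def test_dtype_is_set(self):
            self.assertIn(self.dtype, ("float32", "bfloat16"))

    def create_bf16_variant(parent):
        class Variant(parent):
            dtype = "bfloat16"

        Variant.__name__ = parent.__name__ + "_BF16"
        globals()[Variant.__name__] = Variant  # expose it to the test loader

    create_bf16_variant(TestFoo)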
+from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append("..") +import paddle +import paddle.nn as nn +import paddle.fluid.core as core +import paddle.fluid as fluid +from op_test import OpTest, skip_check_grad_ci + +from test_conv2d_transpose_op import conv2dtranspose_forward_naive + +paddle.enable_static() + + +@skip_check_grad_ci( + reason='''Inference only, it doesn't need to call check_grad.''') +class TestConv2DTransposeOp(OpTest): + def set_npu(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) + + def init_dtype(self): + self.dtype = np.float16 + + def init_data_format(self): + self.data_format = "NCHW" + + def setUp(self): + self.init_op_type() + self.init_dtype() + self.set_npu() + self.init_data_format() + self.output_padding = [] + self.pad = [0, 0] + self.padding_algorithm = "EXPLICIT" + self.init_test_case() + self.output_size = None + + input_ = np.random.random(self.input_size).astype(self.dtype) + filter_ = np.random.random(self.filter_size).astype(self.dtype) + + self.inputs = {'Input': input_, 'Filter': filter_} + self.attrs = { + 'strides': self.stride, + 'paddings': self.pad, + 'padding_algorithm': self.padding_algorithm, + 'groups': self.groups, + 'dilations': self.dilations, + 'use_cudnn': False, + 'is_test': False, + 'use_mkldnn': False, + 'data_format': self.data_format + } + if self.output_size is not None: + self.attrs['output_size'] = self.output_size + + if len(self.output_padding) > 0: + self.attrs['output_padding'] = self.output_padding + output = conv2dtranspose_forward_naive(input_, filter_, + self.attrs).astype(self.dtype) + + self.outputs = {'Output': output} + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-2) + + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + def init_op_type(self): + self.op_type = "conv2d_transpose" + + +class TestWithSymmetricPad_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + def init_dtype(self): + self.dtype = np.float32 + + +class TestWithSymmetricPad(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + +class TestWithAsymmetricPad_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 0, 1, 2] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + def init_dtype(self): + self.dtype = np.float32 + + +class TestWithAsymmetricPad(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 0, 1, 2] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + +class TestWithSAMEPad_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.stride = [2, 1] + self.dilations = [1, 2] + self.groups = 1 + self.input_size = [2, 3, 6, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 4, 3] + 
self.padding_algorithm = 'SAME' + + def init_dtype(self): + self.dtype = np.float32 + + +class TestWithSAMEPad(TestConv2DTransposeOp): + def init_test_case(self): + self.stride = [2, 1] + self.dilations = [1, 2] + self.groups = 1 + self.input_size = [2, 3, 6, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 4, 3] + self.padding_algorithm = 'SAME' + + +class TestWithVALIDPad_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + self.padding_algorithm = 'VALID' + + def init_dtype(self): + self.dtype = np.float32 + + +class TestWithVALIDPad(TestConv2DTransposeOp): + def init_test_case(self): + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + self.padding_algorithm = 'VALID' + + +class TestWithGroups_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 2 + self.input_size = [2, 4, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 3, 3, 3] + + def init_dtype(self): + self.dtype = np.float32 + + +class TestWithGroups(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 2 + self.input_size = [2, 4, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 3, 3, 3] + + +class TestWithStride_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [2, 2] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + def init_dtype(self): + self.dtype = np.float32 + + +class TestWithStride(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [2, 2] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + +class TestWithDilation_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.groups = 1 + self.dilations = [2, 2] + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + def init_dtype(self): + self.dtype = np.float32 + + +class TestWithDilation(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.groups = 1 + self.dilations = [2, 2] + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + +class TestWithEvenUpsample_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [2, 2] + self.stride = [2, 2] + self.groups = 1 + self.dilations = [1, 1] + self.output_size = [14, 14] + self.input_size = [2, 3, 7, 7] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 5, 5] + + def init_dtype(self): + self.dtype = np.float32 + + +class TestWithEvenUpsample(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [2, 2] + self.stride = [2, 2] + self.groups = 1 + self.dilations = [1, 1] + self.output_size = [14, 14] + self.input_size = [2, 3, 7, 7] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 5, 5] + + +class TestWithEvenUpsampleOutputPadding_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.pad 
= [2, 2] + self.stride = [2, 2] + self.groups = 1 + self.dilations = [1, 1] + self.output_padding = [1, 1] + self.input_size = [2, 3, 7, 7] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 5, 5] + + def init_dtype(self): + self.dtype = np.float32 + + +class TestWithEvenUpsampleOutputPadding(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [2, 2] + self.stride = [2, 2] + self.groups = 1 + self.dilations = [1, 1] + self.output_padding = [1, 1] + self.input_size = [2, 3, 7, 7] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 5, 5] + + +class Test_NHWC_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + def init_dtype(self): + self.dtype = np.float32 + + +class Test_NHWC(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + +class TestWithSymmetricPad_NHWC_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + def init_dtype(self): + self.dtype = np.float32 + + +class TestWithSymmetricPad_NHWC(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + +class TestWithAsymmetricPad_NHWC_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 0, 1, 2] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + def init_dtype(self): + self.dtype = np.float32 + + +class TestWithAsymmetricPad_NHWC(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 0, 1, 2] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + +class TestWithGroups_NHWC_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 2 + self.input_size = [2, 5, 5, 4] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 3, 3, 3] + self.data_format = 'NHWC' + + def init_dtype(self): + self.dtype = np.float32 + + +class TestWithGroups_NHWC(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 2 + self.input_size = [2, 5, 5, 4] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 3, 3, 3] + self.data_format = 'NHWC' + + +class TestWithStride_NHWC_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [2, 2] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 3] # NCHW + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + 
def init_dtype(self): + self.dtype = np.float32 + + +class TestWithStride_NHWC(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [2, 2] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [2, 5, 5, 3] # NCHW + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + +class TestWithDilation_NHWC_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.groups = 1 + self.dilations = [2, 2] + self.input_size = [2, 5, 5, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + def init_dtype(self): + self.dtype = np.float32 + + +class TestWithDilation_NHWC(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.groups = 1 + self.dilations = [2, 2] + self.input_size = [2, 5, 5, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 3, 3] + self.data_format = 'NHWC' + + +class TestWithEvenUpsample_NHWC_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [2, 2] + self.stride = [2, 2] + self.groups = 1 + self.dilations = [1, 1] + self.output_size = [14, 14] + self.input_size = [2, 7, 7, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 5, 5] + self.data_format = 'NHWC' + + def init_dtype(self): + self.dtype = np.float32 + + +class TestWithEvenUpsample_NHWC(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [2, 2] + self.stride = [2, 2] + self.groups = 1 + self.dilations = [1, 1] + self.output_size = [14, 14] + self.input_size = [2, 7, 7, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 5, 5] + self.data_format = 'NHWC' + + +class TestWithEvenUpsample_NHWC_output_padding_FP32(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [2, 2] + self.stride = [2, 2] + self.groups = 1 + self.dilations = [1, 1] + self.output_padding = [1, 1] + self.input_size = [2, 7, 7, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 5, 5] + self.data_format = 'NHWC' + + def init_dtype(self): + self.dtype = np.float32 + + +class TestWithEvenUpsample_NHWC_output_padding(TestConv2DTransposeOp): + def init_test_case(self): + self.pad = [2, 2] + self.stride = [2, 2] + self.groups = 1 + self.dilations = [1, 1] + self.output_padding = [1, 1] + self.input_size = [2, 7, 7, 3] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 6, 5, 5] + self.data_format = 'NHWC' + + +class TestConv2DTransposeAPI(unittest.TestCase): + def test_case1(self): + data1 = fluid.layers.data( + name='data1', shape=[3, 5, 5], dtype='float32') + data2 = fluid.layers.data( + name='data2', shape=[5, 5, 3], dtype='float32') + out1 = fluid.layers.conv2d_transpose( + input=data1, + groups=1, + num_filters=6, + filter_size=3, + data_format='NCHW') + out2 = fluid.layers.conv2d_transpose( + input=data2, + groups=1, + num_filters=6, + filter_size=3, + data_format='NHWC') + out3 = fluid.layers.conv2d_transpose( + input=data1, + groups=1, + num_filters=6, + filter_size=3, + padding=[[0, 0], [1, 1], [1, 1], [0, 0]], + data_format='NHWC') + out4 = fluid.layers.conv2d_transpose( + input=data1, + groups=3, + num_filters=6, + filter_size=3, + padding=[[0, 0], [0, 0], [2, 1], [0, 0]], + data_format='NCHW') + out5 = fluid.layers.conv2d_transpose( + input=data2, + groups=1, + num_filters=6, + filter_size=3, + padding='SAME', + data_format='NCHW') + out6 = fluid.layers.conv2d_transpose( + input=data1, + groups=1, + num_filters=6, 
+ filter_size=3, + padding='VALID', + data_format='NHWC') + out7 = fluid.layers.conv2d_transpose( + input=data1, + groups=1, + num_filters=6, + output_size=[7, 7], + padding=[0, 0], + data_format='NHWC') + + data1_np = np.random.random((2, 3, 5, 5)).astype("float32") + data2_np = np.random.random((2, 5, 5, 3)).astype("float32") + + place = core.NPUPlace(0) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + results = exe.run( + fluid.default_main_program(), + feed={"data1": data1_np, + "data2": data2_np}, + fetch_list=[out1, out2, out3, out4, out5, out6, out7], + return_numpy=True) + self.assertIsNotNone(results[0]) + self.assertIsNotNone(results[1]) + self.assertIsNotNone(results[2]) + self.assertIsNotNone(results[3]) + self.assertIsNotNone(results[4]) + self.assertIsNotNone(results[5]) + self.assertIsNotNone(results[6]) + + +class TestConv2DTransposeOpException(unittest.TestCase): + def test_exception(self): + data = fluid.layers.data(name='data', shape=[3, 5, 5], dtype="float32") + + def attr_data_format(): + out = fluid.layers.conv2d_transpose( + input=data, + groups=1, + num_filters=6, + filter_size=3, + data_format="NCDHW") + + self.assertRaises(ValueError, attr_data_format) + + def attr_padding_str(): + out = fluid.layers.conv2d_transpose( + input=data, + groups=1, + num_filters=6, + filter_size=3, + padding='Vald') + + self.assertRaises(ValueError, attr_padding_str) + + def attr_padding_list(): + out = fluid.layers.conv2d_transpose( + input=data, + groups=1, + num_filters=6, + filter_size=3, + padding=[[1, 1], [1, 1], [0, 0], [0, 0]]) + + self.assertRaises(ValueError, attr_padding_list) + + def attr_padding_with_data_format(): + out = fluid.layers.conv2d_transpose( + input=data, + groups=1, + num_filters=6, + filter_size=3, + padding=[[1, 1], [0, 0], [0, 0], [1, 1]], + data_format='NHWC') + + self.assertRaises(ValueError, attr_padding_with_data_format) + + error_input = fluid.layers.data( + name='error_data', shape=[1], dtype="float32") + + def error_input_size(): + out = fluid.layers.conv2d_transpose( + input=error_input, groups=1, num_filters=6, filter_size=3) + + self.assertRaises(ValueError, error_input_size) + + def error_groups(): + out = fluid.layers.conv2d_transpose( + input=data, + groups=0, + num_filters=6, + filter_size=3, + data_format='NHWC') + + self.assertRaises(ValueError, error_groups) + + +class TestConv2DTransposeRepr(unittest.TestCase): + def test_case(self): + paddle.disable_static(paddle.NPUPlace(0)) + x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.) + conv = nn.Conv2DTranspose(4, 6, (3, 3), output_padding=1, stride=2) + print(conv) + y_var = conv(x_var) + y_np = y_var.numpy() + self.assertIsNotNone(y_np) + paddle.enable_static() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/npu/test_elementwise_pow_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_elementwise_pow_op_npu.py index dea1828a6d75f..ce645f317d054 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_elementwise_pow_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_elementwise_pow_op_npu.py @@ -13,19 +13,71 @@ # limitations under the License. 
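The ComputeGrad helper introduced in the elementwise_pow diff below encodes the analytic gradients of z = x ** y under the test framework's mean-style upstream gradient (1 / out.size), then reduces over any broadcast axes. A minimal same-shape sketch of the core formulas, assuming equal shapes so no axes need reducing:

    import numpy as np

    x = np.random.uniform(1, 2, [11, 17])
    y = np.random.uniform(1, 2, [11, 17])
    out = np.power(x, y)
    dout = 1.0 / out.size                   # upstream gradient used by the helper
    dx = y * np.power(x, y - 1) * dout      # d(x**y)/dx
    dy = np.log(x) * np.power(x, y) * dout  # d(x**y)/dy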
from __future__ import print_function +import paddle.fluid as fluid +import paddle +from op_test import OpTest import numpy as np import unittest import sys sys.path.append("..") -from op_test import OpTest -import paddle -import paddle.fluid as fluid paddle.enable_static() SEED = 2021 +def ComputeGrad(x, y, out, axis): + grad = 1 / out.size + shape_x = x.shape + shape_y = y.shape + shape_out = out.shape + reduce_axes_x = [] + reduce_axes_y = [] + + if shape_x != shape_out: + if len(shape_x) < len(shape_out): + src_axis = axis + else: + src_axis = 0 + + for ax in range(len(shape_out)): + if (ax < src_axis or ax >= src_axis + len(shape_x)) or ( + shape_out[ax] > 1 and shape_x[ax - src_axis] == 1): + reduce_axes_x.append(ax) + + if shape_y != shape_out: + if len(shape_y) < len(shape_out): + src_axis = axis + else: + src_axis = 0 + + for ax in range(len(shape_out)): + if (ax < src_axis or ax >= src_axis + len(shape_y)) or ( + shape_out[ax] > 1 and shape_y[ax - src_axis] == 1): + reduce_axes_y.append(ax) + + if len(reduce_axes_x) > 0: + for i in reduce_axes_x: + x = np.expand_dims(x, axis=i) + + if len(reduce_axes_y) > 0: + for i in reduce_axes_y: + y = np.expand_dims(y, axis=i) + + dx = y * np.power(x, y - 1) * grad + dy = np.log(x) * np.power(x, y) * grad + + if len(reduce_axes_x) > 0: + for i, element in enumerate(reduce_axes_x): + dx = np.add.reduce(dx, element - i) + + if len(reduce_axes_y) > 0: + for i, element in enumerate(reduce_axes_y): + dy = np.add.reduce(dy, element - i) + + return dx, dy + + class TestElementwisePow(OpTest): def setUp(self): self.set_npu() @@ -33,17 +85,15 @@ def setUp(self): self.place = paddle.NPUPlace(0) self.init_dtype() - np.random.seed(SEED) - x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) - y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) - out = np.power(x, y) + self.init_input_output() + self.init_axis() self.inputs = { - 'X': OpTest.np_dtype_to_fluid_dtype(x), - 'Y': OpTest.np_dtype_to_fluid_dtype(y) + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) } - self.attrs = {} - self.outputs = {'Out': out} + self.attrs = {'axis': self.axis} + self.outputs = {'Out': self.out} def set_npu(self): self.__class__.use_npu = True @@ -54,44 +104,177 @@ def init_dtype(self): def test_check_output(self): self.check_output_with_place(self.place) - # TODO(ascendrc): Pow grad test - # def test_check_grad(self): - # if self.dtype == np.float16: - # return - # self.check_grad(['X'], 'Out') - # + def init_axis(self): + self.axis = -1 + def init_input_output(self): + np.random.seed(SEED) + self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + self.y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + self.out = np.power(self.x, self.y) + + def test_check_grad_normal(self): + if self.dtype == np.float16: + return + dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis) + self.check_grad_with_place( + self.place, ['X', 'Y'], 'Out', user_defined_grads=[dx, dy]) + + def test_check_grad_ingore_x(self): + _, dy = ComputeGrad(self.x, self.y, self.out, self.axis) + self.check_grad_with_place( + self.place, ['Y'], + 'Out', + no_grad_set=set("X"), + user_defined_grads=[dy]) + + def test_check_grad_ingore_y(self): + dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis) + self.check_grad_with_place( + self.place, ['X'], + 'Out', + no_grad_set=set("Y"), + user_defined_grads=[dx]) + + +class TestElementwisePowFp16(TestElementwisePow): + def init_input_output(self): + np.random.seed(SEED) + self.x = 
np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + self.y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + self.out = np.power(self.x, self.y) -class TestElementwisePowFp16(OpTest): - def setUp(self): - self.set_npu() - self.op_type = "elementwise_pow" - self.place = paddle.NPUPlace(0) + def set_npu(self): + self.__class__.use_npu = True + self.__class__.no_need_check_grad = True - self.init_dtype() - np.random.seed(SEED) - x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype) - y = np.random.uniform(1, 2, [3, 4]).astype(self.dtype) - out = np.power(x, y) + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-5) - self.inputs = { - 'X': OpTest.np_dtype_to_fluid_dtype(x), - 'Y': OpTest.np_dtype_to_fluid_dtype(y) - } - self.attrs = {} - self.outputs = {'Out': out} + +class TestElementwisePowDouble(TestElementwisePow): + def init_input_output(self): + np.random.seed(SEED) + self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + self.y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + self.out = np.power(self.x, self.y) def set_npu(self): self.__class__.use_npu = True self.__class__.no_need_check_grad = True def init_dtype(self): - self.dtype = np.float16 + self.dtype = np.float64 def test_check_output(self): self.check_output_with_place(self.place, atol=1e-5) +class TestElementwisePowOp_broadcast_0(TestElementwisePow): + def init_axis(self): + self.axis = 1 + + def init_input_output(self): + np.random.seed(SEED) + self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + self.y = np.random.uniform(1, 2, [1, 11, 17]).astype(self.dtype) + self.out = np.power(self.x, self.y) + + def test_check_grad_normal(self): + if self.dtype == np.float16: + return + dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis) + self.check_grad_with_place( + self.place, ['X', 'Y'], 'Out', user_defined_grads=[dx, dy]) + + def test_check_grad_ingore_x(self): + _, dy = ComputeGrad(self.x, self.y, self.out, self.axis) + self.check_grad_with_place( + self.place, ['Y'], + 'Out', + no_grad_set=set("X"), + user_defined_grads=[dy]) + + def test_check_grad_ingore_y(self): + dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis) + self.check_grad_with_place( + self.place, ['X'], + 'Out', + no_grad_set=set("Y"), + user_defined_grads=[dx]) + + +class TestElementwisePowOp_broadcast_1(TestElementwisePow): + def init_axis(self): + self.axis = 1 + + def init_input_output(self): + np.random.seed(SEED) + self.x = np.random.uniform(1, 2, [2, 100, 1]).astype(self.dtype) + self.y = np.random.uniform(1, 2, [100]).astype(self.dtype) + self.out = np.power(self.x, self.y.reshape(1, 100, 1)) + + def test_check_grad_normal(self): + if self.dtype == np.float16: + return + dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis) + self.check_grad_with_place( + self.place, ['X', 'Y'], 'Out', user_defined_grads=[dx, dy]) + + def test_check_grad_ingore_x(self): + _, dy = ComputeGrad(self.x, self.y, self.out, self.axis) + self.check_grad_with_place( + self.place, ['Y'], + 'Out', + no_grad_set=set("X"), + user_defined_grads=[dy]) + + def test_check_grad_ingore_y(self): + dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis) + self.check_grad_with_place( + self.place, ['X'], + 'Out', + no_grad_set=set("Y"), + user_defined_grads=[dx]) + + +class TestElementwisePowOp_broadcast_2(TestElementwisePow): + def init_axis(self): + self.axis = 0 + + def init_input_output(self): + np.random.seed(SEED) + self.x = np.random.uniform(0.1, 1, 
[100, 3, 1]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [100]).astype(self.dtype) + self.out = np.power(self.x, self.y.reshape(100, 1, 1)) + + def test_check_grad_normal(self): + if self.dtype == np.float16: + return + dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis) + self.check_grad_with_place( + self.place, ['X', 'Y'], 'Out', user_defined_grads=[dx, dy]) + + def test_check_grad_ingore_x(self): + _, dy = ComputeGrad(self.x, self.y, self.out, self.axis) + self.check_grad_with_place( + self.place, ['Y'], + 'Out', + no_grad_set=set("X"), + user_defined_grads=[dy]) + + def test_check_grad_ingore_y(self): + dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis) + self.check_grad_with_place( + self.place, ['X'], + 'Out', + no_grad_set=set("Y"), + user_defined_grads=[dx]) + + class TestElementwisePowNet(unittest.TestCase): def _test(self, run_npu=True): main_prog = paddle.static.Program() diff --git a/python/paddle/fluid/tests/unittests/npu/test_gather_nd_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_gather_nd_op_npu.py new file mode 100644 index 0000000000000..b124a54624171 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_gather_nd_op_npu.py @@ -0,0 +1,289 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append("..") +from op_test import OpTest +import paddle.fluid as fluid +import paddle + + +def gather_nd_grad(x, index): + dout_shape = index.shape[:-1] + x.shape[index.shape[-1]:] + numel = 1 + for i in dout_shape: + numel = numel * i + dout = np.full(dout_shape, 1. 
/ numel) + dx = np.full_like(x, 0) + + index = tuple(index.reshape(-1, index.shape[-1]).T) + np.add.at(dx, index, dout) + + return dx + + +def test_class1(op_type, typename): + class TestGatherNdOpWithEmptyIndex(OpTest): + #Index has empty element, which means copy entire tensor + + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "gather_nd" + xnp = np.random.random((5, 20)).astype(typename) + self.inputs = { + 'X': xnp, + 'Index': np.array([[], []]).astype("int32") + } + self.outputs = { + 'Out': np.vstack((xnp[np.newaxis, :], xnp[np.newaxis, :])) + } + + def set_npu(self): + self.__class__.use_npu = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if typename == "float16": + self.__class__.no_need_check_grad = True + else: + self.check_grad_with_place(self.place, ['X'], 'Out') + + cls_name = "{0}_{1}_1".format(op_type, typename) + TestGatherNdOpWithEmptyIndex.__name__ = cls_name + globals()[cls_name] = TestGatherNdOpWithEmptyIndex + + +def test_class2(op_type, typename): + class TestGatherNdOpWithIndex1(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "gather_nd" + xnp = np.random.random((5, 20)).astype(typename) + self.inputs = {'X': xnp, 'Index': np.array([1]).astype("int32")} + self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} + + def set_npu(self): + self.__class__.use_npu = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if typename == "float16": + self.__class__.no_need_check_grad = True + else: + self.check_grad_with_place(self.place, ['X'], 'Out') + + cls_name = "{0}_{1}_2".format(op_type, typename) + TestGatherNdOpWithIndex1.__name__ = cls_name + globals()[cls_name] = TestGatherNdOpWithIndex1 + + +def test_class3(op_type, typename): + class TestGatherNdOpWithLowIndex(OpTest): + #Index has low rank, X has high rank + + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "gather_nd" + xnp = np.random.uniform(0, 100, (10, 10)).astype(typename) + index = np.array([[1], [2]]).astype("int64") + + self.inputs = {'X': xnp, 'Index': index} + self.outputs = {'Out': xnp[tuple(index.T)]} + self.x_grad = gather_nd_grad(xnp, index) + + def set_npu(self): + self.__class__.use_npu = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if typename == "float16": + self.__class__.no_need_check_grad = True + else: + self.check_grad_with_place( + self.place, ['X'], 'Out', user_defined_grads=[self.x_grad]) + + cls_name = "{0}_{1}_3".format(op_type, typename) + TestGatherNdOpWithLowIndex.__name__ = cls_name + globals()[cls_name] = TestGatherNdOpWithLowIndex + + +def test_class4(op_type, typename): + class TestGatherNdOpIndex1(OpTest): + #Index has low rank, X has high rank + + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "gather_nd" + xnp = np.random.uniform(0, 100, (10, 10)).astype(typename) + index = np.array([1, 2]).astype("int64") + + self.inputs = {'X': xnp, 'Index': index} + + self.outputs = {'Out': xnp[tuple(index.T)]} + + def set_npu(self): + self.__class__.use_npu = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if typename == "float16": + self.__class__.no_need_check_grad = True + else: + self.check_grad_with_place(self.place, ['X'], 'Out') + + cls_name = 
"{0}_{1}_4".format(op_type, typename) + TestGatherNdOpIndex1.__name__ = cls_name + globals()[cls_name] = TestGatherNdOpIndex1 + + +def test_class5(op_type, typename): + class TestGatherNdOpWithSameIndexAsX(OpTest): + #Index has same rank as X's rank + + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "gather_nd" + xnp = np.random.uniform(0, 100, (10, 10)).astype(typename) + index = np.array([[1, 1], [2, 1]]).astype("int64") + + self.inputs = {'X': xnp, 'Index': index} + self.outputs = {'Out': xnp[tuple(index.T)]} #[25, 22] + + def set_npu(self): + self.__class__.use_npu = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if typename == "float16": + self.__class__.no_need_check_grad = True + else: + self.check_grad_with_place(self.place, ['X'], 'Out') + + cls_name = "{0}_{1}_5".format(op_type, typename) + TestGatherNdOpWithSameIndexAsX.__name__ = cls_name + globals()[cls_name] = TestGatherNdOpWithSameIndexAsX + + +def test_class6(op_type, typename): + class TestGatherNdOpWithHighRankSame(OpTest): + #Both Index and X have high rank, and Rank(Index) = Rank(X) + + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "gather_nd" + shape = (5, 2, 3, 1, 10) + xnp = np.random.rand(*shape).astype(typename) + index = np.vstack([np.random.randint( + 0, s, size=2) for s in shape]).T + + self.inputs = {'X': xnp, 'Index': index.astype("int32")} + self.outputs = {'Out': xnp[tuple(index.T)]} + + def set_npu(self): + self.__class__.use_npu = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if typename == "float16": + self.__class__.no_need_check_grad = True + else: + self.check_grad_with_place(self.place, ['X'], 'Out') + + cls_name = "{0}_{1}_6".format(op_type, typename) + TestGatherNdOpWithHighRankSame.__name__ = cls_name + globals()[cls_name] = TestGatherNdOpWithHighRankSame + + +def test_class7(op_type, typename): + class TestGatherNdOpWithHighRankDiff(OpTest): + #Both Index and X have high rank, Rank(Index) < Rank(X) + + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "gather_nd" + shape = (2, 3, 4, 1, 10) + xnp = np.random.rand(*shape).astype(typename) + index = np.vstack( + [np.random.randint( + 0, s, size=200) for s in shape]).T + index_re = index.reshape([20, 5, 2, 5]) + + self.inputs = {'X': xnp, 'Index': index_re.astype("int32")} + self.outputs = {'Out': xnp[tuple(index.T)].reshape([20, 5, 2])} + + def set_npu(self): + self.__class__.use_npu = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if typename == "float16": + self.__class__.no_need_check_grad = True + else: + self.check_grad_with_place(self.place, ['X'], 'Out') + + cls_name = "{0}_{1}_7".format(op_type, typename) + TestGatherNdOpWithHighRankDiff.__name__ = cls_name + globals()[cls_name] = TestGatherNdOpWithHighRankDiff + + +class TestGatherNdAPI(unittest.TestCase): + def test_imperative(self): + paddle.disable_static() + input_1 = np.array([[1, 2], [3, 4], [5, 6]]) + index_1 = np.array([[1]]) + input = fluid.dygraph.to_variable(input_1) + index = fluid.dygraph.to_variable(index_1) + output = paddle.fluid.layers.gather(input, index) + output_np = output.numpy() + expected_output = np.array([3, 4]) + self.assertTrue(np.allclose(output_np, expected_output)) + paddle.enable_static() + + +for _typename in {'float16', 'float32'}: + 
test_class1('gather_nd', _typename) + test_class2('gather_nd', _typename) + test_class3('gather_nd', _typename) + test_class4('gather_nd', _typename) + test_class5('gather_nd', _typename) + test_class6('gather_nd', _typename) + test_class7('gather_nd', _typename) + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/npu/test_label_smooth_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_label_smooth_op_npu.py new file mode 100644 index 0000000000000..6e5b4c012053f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_label_smooth_op_npu.py @@ -0,0 +1,126 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid + +paddle.enable_static() +SEED = 2021 + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestLabelSmoothOp(OpTest): + def setUp(self): + self.set_npu() + self.op_type = "label_smooth" + self.place = paddle.NPUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + + self.set_inputs() + self.set_attrs() + self.set_outputs() + + def calc_out(self, label, epsilon, dist=None): + label_dim = label.shape[-1] + y = (1 - epsilon) * label + if dist is not None: + y += epsilon * dist + else: + y += epsilon / label_dim + return y.astype(self.dtype) + + def set_inputs(self): + batch_size, label_dim = 10, 12 + x = np.zeros((batch_size, label_dim)).astype(self.dtype) + nonzero_index = np.random.randint(label_dim, size=(batch_size)) + x[np.arange(batch_size), nonzero_index] = 1 + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + + def set_attrs(self): + epsilon = 0.1 + self.attrs = {"epsilon": epsilon} + + def set_outputs(self): + dist = None if 'PriorDist' not in self.inputs else self.inputs[ + 'PriorDist'] + out = self.calc_out(self.inputs['X'], self.attrs['epsilon'], dist) + self.outputs = {'Out': out} + + def set_npu(self): + self.__class__.use_npu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if self.dtype == np.float16: + return + self.check_grad_with_place(self.place, ['X'], 'Out') + + +class TestLabelSmoothOpWithPriorDist(TestLabelSmoothOp): + def set_inputs(self): + super(TestLabelSmoothOpWithPriorDist, self).set_inputs() + label_dim = self.inputs['X'].shape[-1] + dist = np.random.random((1, label_dim)).astype(self.dtype) + self.inputs['PriorDist'] = dist + + +class TestLabelSmoothOp3D(TestLabelSmoothOp): + def set_inputs(self): + super(TestLabelSmoothOp3D, self).set_inputs() + self.inputs['X'].reshape([2, -1, self.inputs['X'].shape[-1]]) + + +class TestLabelSmoothOpWithPriorDist3D(TestLabelSmoothOpWithPriorDist): + def set_inputs(self): + super(TestLabelSmoothOpWithPriorDist3D, self).set_inputs() + 
self.inputs['X'].reshape([2, -1, self.inputs['X'].shape[-1]]) + + +class TestLabelSmoothOpFP16(TestLabelSmoothOp): + def init_dtype(self): + self.dtype = np.float16 + + +class TestLabelSmoothOpWithPriorDistFP16(TestLabelSmoothOpWithPriorDist): + def init_dtype(self): + self.dtype = np.float16 + + +class TestLabelSmoothOp3DFP16(TestLabelSmoothOp3D): + def init_dtype(self): + self.dtype = np.float16 + + +class TestLabelSmoothOpWithPriorDist3DFP16(TestLabelSmoothOpWithPriorDist3D): + def init_dtype(self): + self.dtype = np.float16 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/npu/test_log_softmax_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_log_softmax_op_npu.py new file mode 100644 index 0000000000000..e8b680d1ddc1b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_log_softmax_op_npu.py @@ -0,0 +1,95 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import core +import paddle.nn.functional as F +from test_log_softmax import ref_log_softmax, ref_log_softmax_grad +paddle.enable_static() +np.random.seed(10) + + +class TestLogSoftmaxNPUOp(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "log_softmax" + self.dtype = np.float32 + self.shape = [2, 3, 4, 5] + self.axis = -1 + self.set_attrs() + self.set_dtype() + x = np.random.uniform(0.1, 1., self.shape).astype(self.dtype) + out = np.apply_along_axis(ref_log_softmax, self.axis, x) + self.x_grad = ref_log_softmax_grad(x, self.axis) + self.inputs = {'X': x} + self.outputs = {'Out': out} + self.attrs = {'axis': self.axis} + + def set_npu(self): + self.__class__.use_npu = True + self.__class__.no_need_check_grad = True + + def set_attrs(self): + pass + + def set_dtype(self): + pass + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + pass + + +def test_class(op_type, typename): + class TestLogSoftmaxShape(TestLogSoftmaxNPUOp): + def set_attrs(self): + self.shape = [12, 10] + + def set_dtype(self): + self.dtype = typename + + cls_name = "{0}_{1}_1".format(op_type, typename) + TestLogSoftmaxShape.__name__ = cls_name + globals()[cls_name] = TestLogSoftmaxShape + + +def test_class2(op_type, typename): + class TestLogSoftmaxAxis(TestLogSoftmaxNPUOp): + def set_attrs(self): + self.axis = 0 + + def set_dtype(self): + self.dtype = typename + + cls_name = "{0}_{1}_2".format(op_type, typename) + + TestLogSoftmaxAxis.__name__ = cls_name + globals()[cls_name] = TestLogSoftmaxAxis + + +for _typename in {'float32'}: + test_class("logsoftmax", _typename) + test_class2("logsoftmax", _typename) +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/npu/test_shard_index_op.py 
b/python/paddle/fluid/tests/unittests/npu/test_shard_index_op.py new file mode 100644 index 0000000000000..ce7e962624a46 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_shard_index_op.py @@ -0,0 +1,84 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import math +import sys +sys.path.append("..") +from op_test import OpTest +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle.fluid.framework as framework +from paddle.fluid.framework import Program, program_guard +import paddle +paddle.enable_static() +SEED = 2021 + + +def common_setup(self, index_num, nshards, shard_id, ignore_value): + self.__class__.use_npu = True + self.__class__.op_type = "shard_index" + + self.op_type = 'shard_index' + x_lod = [[i for i in range(10)]] + N = sum(x_lod[0]) + x = [np.random.randint(0, index_num - 1) for i in range(N)] + x = np.array(x).astype('int32').reshape([N, 1]) + + shard_size = (index_num + nshards - 1) // nshards + out = np.zeros(shape=x.shape).astype('int32') + for i in range(N): + if x[i] // shard_size == shard_id: + out[i] = x[i] % shard_size + else: + out[i] = ignore_value + + self.inputs = {'X': (x, x_lod)} + self.attrs = { + 'index_num': index_num, + 'nshards': nshards, + 'shard_id': shard_id, + 'ignore_value': ignore_value + } + self.outputs = {'Out': (out, x_lod)} + + +class TestShardIndexShardId0Op(OpTest): + def setUp(self): + common_setup(self, 20, 2, 0, -1) + + def test_check_output(self): + return self.check_output_with_place(place=paddle.NPUPlace(0)) + + +class TestShardIndexShardId1Op(TestShardIndexShardId0Op): + def setUp(self): + common_setup(self, 20, 2, 1, -1) + + +class TestShardIndexIgnoreValueOp(TestShardIndexShardId0Op): + def setUp(self): + common_setup(self, 20, 2, 0, -2) + + +class TestShardIndexNotEvenlyDividedOp(TestShardIndexShardId0Op): + def setUp(self): + common_setup(self, 15, 2, 1, -1) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/npu/test_strided_slice_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_strided_slice_op_npu.py index 2f0fa697cb0d9..1260017da939c 100755 --- a/python/paddle/fluid/tests/unittests/npu/test_strided_slice_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_strided_slice_op_npu.py @@ -56,11 +56,11 @@ def strided_slice_native_forward(input, axes, starts, ends, strides): return result -@skip_check_grad_ci( - reason='''forward only, it doesn't need to call check_grad.''') class TestStridedSliceOp(OpTest): def setUp(self): self.initTestCase() + self.set_npu() + self.place = paddle.NPUPlace(0) self.op_type = 'strided_slice' self.output = strided_slice_native_forward( self.input, self.axes, self.starts, self.ends, self.strides) @@ -75,12 +75,17 @@ def setUp(self): 'infer_flags': self.infer_flags } + def set_npu(self): + self.__class__.use_npu = True + def test_check_output(self): - place = 
paddle.NPUPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['Input'], 'Out') def initTestCase(self): - self.input = np.random.rand(10) + self.input = np.random.rand(100) self.axes = [0] self.starts = [2] self.ends = [7] @@ -283,12 +288,12 @@ def initTestCase(self): self.infer_flags = [1, 1, 1, 1, 1] -@skip_check_grad_ci( - reason='''forward only, it doesn't need to call check_grad.''') class TestStridedSliceOp_starts_ListTensor(OpTest): def setUp(self): + self.place = paddle.NPUPlace(0) self.op_type = "strided_slice" self.config() + self.set_npu() starts_tensor = [] for index, ele in enumerate(self.starts): @@ -305,6 +310,9 @@ def setUp(self): 'infer_flags': self.infer_flags } + def set_npu(self): + self.__class__.use_npu = True + def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float64") self.starts = [1, 0, 2] @@ -318,16 +326,18 @@ def config(self): self.starts_infer = [1, 10, 2] def test_check_output(self): - place = paddle.NPUPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + self.check_grad_with_place(self.place, ['Input'], 'Out') -@skip_check_grad_ci( - reason='''forward only, it doesn't need to call check_grad.''') class TestStridedSliceOp_ends_ListTensor(OpTest): def setUp(self): + self.place = paddle.NPUPlace(0) self.op_type = "strided_slice" self.config() + self.set_npu() ends_tensor = [] for index, ele in enumerate(self.ends): @@ -344,6 +354,9 @@ def setUp(self): 'infer_flags': self.infer_flags } + def set_npu(self): + self.__class__.use_npu = True + def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float64") self.starts = [1, 0, 0] @@ -357,16 +370,19 @@ def config(self): self.ends_infer = [3, 1, 4] def test_check_output(self): - place = paddle.NPUPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + self.check_grad_with_place(self.place, ['Input'], 'Out') -@skip_check_grad_ci( - reason='''forward only, it doesn't need to call check_grad.''') class TestStridedSliceOp_starts_Tensor(OpTest): def setUp(self): + self.place = paddle.NPUPlace(0) self.op_type = "strided_slice" self.config() + self.set_npu() + self.inputs = { 'Input': self.input, "StartsTensor": np.array( @@ -381,6 +397,9 @@ def setUp(self): 'infer_flags': self.infer_flags, } + def set_npu(self): + self.__class__.use_npu = True + def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float64") self.starts = [1, 0, 2] @@ -392,16 +411,19 @@ def config(self): self.input, self.axes, self.starts, self.ends, self.strides) def test_check_output(self): - place = paddle.NPUPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + self.check_grad_with_place(self.place, ['Input'], 'Out') -@skip_check_grad_ci( - reason='''forward only, it doesn't need to call check_grad.''') class TestStridedSliceOp_ends_Tensor(OpTest): def setUp(self): + self.place = paddle.NPUPlace(0) self.op_type = "strided_slice" self.config() + self.set_npu() + self.inputs = { 'Input': self.input, "EndsTensor": np.array( @@ -416,6 +438,9 @@ def setUp(self): 'infer_flags': self.infer_flags, } + def set_npu(self): + self.__class__.use_npu = True + def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float64") self.starts = [1, 0, 2] @@ -427,20 +452,23 @@ 
def config(self): self.input, self.axes, self.starts, self.ends, self.strides) def test_check_output(self): - place = paddle.NPUPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + self.check_grad_with_place(self.place, ['Input'], 'Out') -@skip_check_grad_ci( - reason='''forward only, it doesn't need to call check_grad.''') class TestStridedSliceOp_listTensor_Tensor(OpTest): def setUp(self): + self.place = paddle.NPUPlace(0) + self.op_type = "strided_slice" + self.set_npu() self.config() + ends_tensor = [] for index, ele in enumerate(self.ends): ends_tensor.append(("x" + str(index), np.ones( (1)).astype('int32') * ele)) - self.op_type = "strided_slice" self.inputs = { 'Input': self.input, @@ -457,6 +485,9 @@ def setUp(self): 'infer_flags': self.infer_flags, } + def set_npu(self): + self.__class__.use_npu = True + def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float64") self.starts = [1, 0, 2] @@ -468,16 +499,19 @@ def config(self): self.input, self.axes, self.starts, self.ends, self.strides) def test_check_output(self): - place = paddle.NPUPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + self.check_grad_with_place(self.place, ['Input'], 'Out') -@skip_check_grad_ci( - reason='''forward only, it doesn't need to call check_grad.''') class TestStridedSliceOp_strides_Tensor(OpTest): def setUp(self): + self.place = paddle.NPUPlace(0) self.op_type = "strided_slice" + self.set_npu() self.config() + self.inputs = { 'Input': self.input, "StridesTensor": np.array( @@ -492,6 +526,9 @@ def setUp(self): 'infer_flags': self.infer_flags, } + def set_npu(self): + self.__class__.use_npu = True + def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float64") self.starts = [1, -1, 2] @@ -503,8 +540,10 @@ def config(self): self.input, self.axes, self.starts, self.ends, self.strides) def test_check_output(self): - place = paddle.NPUPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + self.check_grad_with_place(self.place, ['Input'], 'Out') # Test python API diff --git a/python/paddle/fluid/tests/unittests/npu/test_update_loss_scaling_min_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_update_loss_scaling_min_op_npu.py new file mode 100644 index 0000000000000..18e2db7f6b1d9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_update_loss_scaling_min_op_npu.py @@ -0,0 +1,76 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
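The update_loss_scaling case added below exercises the FLAGS_min_loss_scaling floor: an inf in the input triggers a decrease of the loss scale by decr_ratio, and the result is clamped at the flag value, which is how the expected LossScaling of 1639.0 arises. A rough check of that arithmetic, under this reading of the flag:

    prev_loss_scaling = 2048.0
    decr_ratio = 0.8
    min_loss_scaling = 1639.0                 # FLAGS_min_loss_scaling set in the test

    shrunk = prev_loss_scaling * decr_ratio   # 1638.4, below the floor
    expected = max(shrunk, min_loss_scaling)  # 1639.0, the value asserted by the test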
+ +import unittest +import numpy as np +import sys +import os +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.contrib.mixed_precision.amp_nn as amp_nn +from test_update_loss_scaling_op_npu import TestUpdateLossScalingOpBad + +paddle.enable_static() +SEED = 2021 + + +class TestUpdateLossScalingOpMinLossScalingBad(TestUpdateLossScalingOpBad): + def setUp(self): + self.set_npu() + self.op_type = "update_loss_scaling" + self.place = paddle.NPUPlace(0) + + self.init() + fluid.core.globals()['FLAGS_min_loss_scaling'] = 1639 + found_inf = np.array([True], dtype=np.bool) + x = np.random.random((1024, 1024)).astype(self.dtype) + i = np.random.randint(0, 1024, 1) + j = np.random.randint(0, 1024, 1) + x[i[0]][j[0]] = np.inf + + self.inputs = { + 'X': [('x0', x)], + 'FoundInfinite': found_inf, + 'PrevLossScaling': self.prev_loss_scaling, + 'InGoodSteps': self.num_good_steps, + 'InBadSteps': self.num_bad_steps + } + + self.outputs = { + 'Out': [('out0', np.zeros_like(x))], + 'LossScaling': np.array([1639.0]).astype(self.dtype), + 'OutGoodSteps': self.zero_steps, + 'OutBadSteps': self.zero_steps + } + + def init(self): + self.incr_ratio = 2.0 + self.decr_ratio = 0.8 + self.dtype = np.float32 + self.prev_loss_scaling = np.array([2048]).astype(self.dtype) + self.num_good_steps = np.array([999], dtype=np.int32) + self.num_bad_steps = np.array([1], dtype=np.int32) + self.zero_steps = np.array([0], dtype=np.int32) + self.attrs = { + 'incr_every_n_steps': 1000, + 'decr_every_n_nan_or_inf': 2, + 'incr_ratio': self.incr_ratio, + 'decr_ratio': self.decr_ratio, + } + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py new file mode 100755 index 0000000000000..f1049084cfb79 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py @@ -0,0 +1,948 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
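In the auto-parallel partitioner test that follows, the MP checks assert that the column-parallel weight is split along its second axis and the row-parallel weight along its first, across nrank = 4 ranks. A small sketch of the shape relation that check_tensor_split verifies, using the MLP sizes from the test (the concrete shapes here are illustrative):

    nrank = 4
    serial_shape = (1024, 4 * 1024)         # linear_0.w_0 in the serial program
    dist_shape = (1024, 4 * 1024 // nrank)  # column-parallel shard held by one rank
    axis = 1
    assert dist_shape[axis] == serial_shape[axis] // nrank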
+ +from __future__ import print_function + +import unittest +import unittest.mock +from io import StringIO +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.static as static +import paddle.nn.functional as F +import paddle.utils as utils +import paddle.tensor as tensor +from paddle.fluid import layers +from paddle.nn.layer.transformer import _convert_param_attr_to_list +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program +from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr +from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix +from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed.auto_parallel.context import set_default_distributed_context +from paddle.distributed import fleet +from paddle.distributed.auto_parallel.partitioner import Partitioner +from paddle.distributed.auto_parallel.utils import _get_comm_group +from paddle.distributed.auto_parallel.process import new_process_group + +paddle.enable_static() +_global_parallel_stratergy = None +_global_process_mesh = None +ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) + + +def get_programs(annotated_func): + train_program = static.Program() + start_program = static.Program() + dist_context = DistributedContext() + global _global_process_mesh + dist_context.set_process_mesh(_global_process_mesh) + train_program, start_program = annotated_func(train_program, start_program) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + + rank_id = 3 + dist_strategy = fleet.DistributedStrategy() + partitioner = Partitioner(dist_strategy, dist_context, rank_id) + test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog = partitioner.transpile_forward( + complete_train_program, start_program) + + return complete_train_program, start_program, test_auto_parallel_dist_main_prog, test_auto_parallel_dist_startup_prog, dist_context + + +def is_all_parameters_shape_equal(prog1, prog2): + + params1 = prog1.all_parameters() + params2 = prog2.all_parameters() + params1.sort(key=lambda x: x.name) + params2.sort(key=lambda x: x.name) + shape1 = [tensor.shape for tensor in params1] + shape2 = [tensor.shape for tensor in params2] + + if len(shape1) != len(shape2): + return False + for i in range(len(shape1)): + if shape1[i] != shape2[i]: + return False + return True + + +def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit): + + for i in range(len(varnames1)): + var1 = prog1.global_block().var(varnames1[i]) + var2 = prog2.global_block().var(varnames2[i]) + if var1.shape[axis] != (var2.shape[axis] // nsplit): + return False + + return True + + +def initialization_check(mode, dist_context, dist_startup_prog, + serial_startup_prog, var_need_broadcast): + if 'mp' in mode: + mp_parallel_axis, process_mesh = dist_context._get_model_parallel_info() + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, mp_parallel_axis, + 3) + mp_ring_id = new_process_group(group_ranks).id + broadcast_ops = [ + op for op in dist_startup_prog.global_block().ops + if (op.type == "c_broadcast" and op.desc.attr("ring_id") == + mp_ring_id) + ] + broadcast_varnames = sorted( + [op.desc.output_arg_names()[0] for op in broadcast_ops]) + if broadcast_varnames != var_need_broadcast: + return False + + if 'dp' in mode: + dp_parallel_axis, process_mesh = 
dist_context._get_data_parallel_info() + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, dp_parallel_axis, + 3) + dp_ring_id = new_process_group(group_ranks).id + nparam = len(serial_startup_prog.all_parameters()) + nbroadcast_dp = len([ + op for op in dist_startup_prog.global_block().ops + if (op.type == "c_broadcast" and op.desc.attr("ring_id") == + dp_ring_id) + ]) + if nparam != nbroadcast_dp: + return False + + if "dp" in mode and 'mp' in mode: + nbroadcast = len([ + op for op in dist_startup_prog.global_block().ops + if op.type == "c_broadcast" + ]) + if len(var_need_broadcast) + nbroadcast_dp != nbroadcast: + return False + + return True + + +class MLPLayer(nn.Layer): + def __init__(self, + hidden_size=1024, + intermediate_size=4 * 1024, + dropout_ratio=0.1, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range)) + bias_attr = None + + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") + + def forward(self, input): + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) + else: + auto.shard_tensor( + self.linear0.weight, _global_process_mesh, + dim_mapping=[-1, -1]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, + dim_mapping=[-1, -1]) + + out = self.norm(input) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = self.linear1(out) + out = self.dropout(out) + + return out + + +def mlp_pretrain_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + hidden_size = 1024 + sequence_len = 512 + input = static.data( + name="input", + shape=[batch_size, sequence_len, hidden_size], + dtype='float32') + + if _global_parallel_stratergy == "dp": + auto.shard_tensor( + input, _global_process_mesh, dim_mapping=[0, -1, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + input, _global_process_mesh, dim_mapping=[0, -1, -1]) + + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) + out = mlp(input) + return train_program, start_program + + +class TestMLPAutoPartitioner(unittest.TestCase): + def test_mlp_dp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[0, 1, 2, 3], parent=ROOT_MESH) + + serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( + mlp_pretrain_forward) + + # parameter should not be partitioned + self.assertTrue( + is_all_parameters_shape_equal(serial_main_prog, dist_main_prog)) + self.assertTrue( + is_all_parameters_shape_equal(serial_startup_prog, + dist_startup_prog)) + + # op in main prog 
should be the same + serial_ops = serial_main_prog.global_block().ops + dist_ops = dist_main_prog.global_block().ops + serial_ops = [op.type for op in serial_ops] + dist_ops = [op.type for op in dist_ops] + self.assertTrue(serial_ops == dist_ops) + + # parameter initialization + var_need_broadcast = [] + self.assertTrue( + initialization_check(_global_parallel_stratergy, dist_context, + dist_startup_prog, serial_startup_prog, + var_need_broadcast)) + + def test_mlp_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[0, 1, 2, 3], parent=ROOT_MESH) + serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( + mlp_pretrain_forward) + + # param should be partition + nrank = 4 + # col parallel + weights = ['linear_0.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 1, nrank)) + weights = ['linear_0.b_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + # row parallel + weights = ['linear_1.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + weights = ['linear_1.b_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, 1)) + + # row and col allreduce + dist_ops = dist_main_prog.global_block().ops + dist_ops = [op.type for op in dist_ops] + ref_ops = [ + 'layer_norm', 'c_identity', 'matmul', 'elementwise_add', 'gelu', + 'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout' + ] + self.assertTrue(dist_ops == ref_ops) + + # parameter initialization + var_need_broadcast = sorted( + ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0']) + self.assertTrue( + initialization_check(_global_parallel_stratergy, dist_context, + dist_startup_prog, serial_startup_prog, + var_need_broadcast)) + + def test_mlp_dp_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp_mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( + mlp_pretrain_forward) + + # param should be partition + nrank = 4 + # col parallel + weights = ['linear_0.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 1, nrank)) + weights = ['linear_0.b_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + # row parallel + weights = ['linear_1.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + weights = ['linear_1.b_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, 1)) + + # row and col allreduce + dist_ops = dist_main_prog.global_block().ops + dist_ops = [op.type for op in dist_ops] + ref_ops = [ + 'layer_norm', 'c_identity', 'matmul', 'elementwise_add', 'gelu', + 'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout' + ] + self.assertTrue(dist_ops == ref_ops) + + # parameter initialization + var_need_broadcast = sorted( + ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0']) + self.assertTrue( + initialization_check(_global_parallel_stratergy, dist_context, + dist_startup_prog, serial_startup_prog, + var_need_broadcast)) + + +class 
AttentionLayer(nn.Layer): + def __init__(self, + hidden_size=1024, + sequence_len=512, + intermediate_size=4 * 1024, + num_heads=16, + dropout_ratio=0.1, + initializer_range=0.02): + super(AttentionLayer, self).__init__() + self.hidden_size = hidden_size + self.sequence_len = sequence_len + self.embed_dim = self.hidden_size + self.kdim = self.embed_dim + self.vdim = self.embed_dim + self.num_heads = num_heads + self.head_dim = self.embed_dim // self.num_heads + assert self.head_dim * self.num_heads == self.embed_dim, \ + "embed_dim must be divisible by num_heads" + self.dropout_ratio = dropout_ratio + self.initializer_range = initializer_range + self.training = True + self.attn_mask = None + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range)) + bias_attr = None + + self.q_proj = nn.Linear( + self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.k_proj = nn.Linear( + self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.v_proj = nn.Linear( + self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.out_proj = nn.Linear( + self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) + + def forward(self, input): + if _global_parallel_stratergy == "dp": + auto.shard_tensor( + input, _global_process_mesh, dim_mapping=[0, -1, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + input, _global_process_mesh, dim_mapping=[0, -1, -1]) + + q = self.q_proj(input) + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + + k = self.k_proj(input) + v = self.v_proj(input) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + + # scale dot product attention + product = layers.matmul( + x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) + + if self.attn_mask is not None: + product = product + self.attn_mask + + weights = F.softmax(product) + + if self.dropout_ratio: + weights = F.dropout( + weights, + self.dropout_ratio, + training=self.training, + mode="upscale_in_train") + + out = tensor.matmul(weights, v) + + # combine heads + out = tensor.transpose(out, perm=[0, 2, 1, 3]) + out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.out_proj(out) + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.out_proj.weight, _global_process_mesh, + dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.out_proj.weight, _global_process_mesh, + dim_mapping=[1, -1]) + + return out + + +def attn_pretrain_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + 
hidden_size = 1024 + sequence_len = 512 + input = static.data( + name="query", + shape=[batch_size, sequence_len, hidden_size], + dtype='float32') + attn = AttentionLayer( + hidden_size=hidden_size, + sequence_len=sequence_len, + intermediate_size=4 * hidden_size, + num_heads=16, + dropout_ratio=0.1, + initializer_range=0.02) + out = attn(input) + + return train_program, start_program + + +class TestAttentionAutoPartitioner(unittest.TestCase): + def test_attn_dp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[0, 1, 2, 3], parent=ROOT_MESH) + + serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( + attn_pretrain_forward) + # parameter should not be partitioned + self.assertTrue( + is_all_parameters_shape_equal(serial_main_prog, dist_main_prog)) + self.assertTrue( + is_all_parameters_shape_equal(serial_startup_prog, + dist_startup_prog)) + + # op in main prog should be the same + serial_ops = serial_main_prog.global_block().ops + dist_ops = dist_main_prog.global_block().ops + serial_ops = [op.type for op in serial_ops] + dist_ops = [op.type for op in dist_ops] + self.assertTrue(serial_ops == dist_ops) + + # parameter initialization + var_need_broadcast = [] + self.assertTrue( + initialization_check(_global_parallel_stratergy, dist_context, + dist_startup_prog, serial_startup_prog, + var_need_broadcast)) + + def test_attn_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[0, 1, 2, 3], parent=ROOT_MESH) + + serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( + attn_pretrain_forward) + + # param should be partition + nrank = 4 + # col parallel + weights = ['linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 1, nrank)) + weights = ['linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + # row parallel + weights = ['linear_3.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + weights = ['linear_3.b_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, 1)) + + # row and col allreduce + dist_ops = dist_main_prog.global_block().ops + dist_ops = [op.type for op in dist_ops] + ref_ops = [ + 'c_identity', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', + 'c_identity', 'matmul', 'elementwise_add', 'c_identity', 'matmul', + 'elementwise_add', 'reshape2', 'transpose2', 'reshape2', + 'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2', + 'transpose2', 'reshape2', 'matmul', 'c_allreduce_sum', + 'elementwise_add' + ] + self.assertTrue(dist_ops == ref_ops) + + # parameter initialization + var_need_broadcast = ['linear_3.b_0'] + self.assertTrue( + initialization_check(_global_parallel_stratergy, dist_context, + dist_startup_prog, serial_startup_prog, + var_need_broadcast)) + + def test_attn_dp_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp_mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + + serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, 
dist_context = get_programs( + attn_pretrain_forward) + + # param should be partition + nrank = 4 + # col parallel + weights = ['linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 1, nrank)) + weights = ['linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + # row parallel + weights = ['linear_3.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + weights = ['linear_3.b_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, 1)) + + # row and col allreduce + dist_ops = dist_main_prog.global_block().ops + dist_ops = [op.type for op in dist_ops] + ref_ops = [ + 'c_identity', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', + 'c_identity', 'matmul', 'elementwise_add', 'c_identity', 'matmul', + 'elementwise_add', 'reshape2', 'transpose2', 'reshape2', + 'transpose2', 'matmul', 'softmax', 'dropout', 'matmul_v2', + 'transpose2', 'reshape2', 'matmul', 'c_allreduce_sum', + 'elementwise_add' + ] + self.assertTrue(dist_ops == ref_ops) + + # parameter initialization + var_need_broadcast = ['linear_3.b_0'] + self.assertTrue( + initialization_check(_global_parallel_stratergy, dist_context, + dist_startup_prog, serial_startup_prog, + var_need_broadcast)) + + +class DecoderLayer(nn.Layer): + def __init__(self, + vocab_size=32768, + hidden_size=1024, + sequence_len=512, + max_position_embeddings=512, + intermediate_size=4 * 1024, + num_heads=16, + dropout_ratio=0.1, + initializer_range=0.02): + super(DecoderLayer, self).__init__() + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.max_position_embeddings = max_position_embeddings + self.sequence_len = sequence_len + self.embed_dim = self.hidden_size + self.kdim = self.embed_dim + self.vdim = self.embed_dim + self.num_heads = num_heads + self.dropout_ratio = dropout_ratio + self.initializer_range = initializer_range + self.training = True + self.attn_mask = None + + self.head_dim = self.embed_dim // self.num_heads + assert self.head_dim * self.num_heads == self.embed_dim, \ + "embed_dim must be divisible by num_heads" + self.word_embeddings = nn.Embedding( + self.vocab_size, + self.hidden_size, + weight_attr=paddle.ParamAttr( + name="word_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=self.initializer_range))) + self.position_embeddings = nn.Embedding( + self.max_position_embeddings, + self.hidden_size, + weight_attr=paddle.ParamAttr( + name="pos_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=self.initializer_range))) + + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=self.initializer_range)) + bias_attr = None + self.q_proj = nn.Linear( + self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.k_proj = nn.Linear( + self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.v_proj = nn.Linear( + self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.out_proj = nn.Linear( + self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) + + intermediate_size = 4 * self.hidden_size + d_model = self.hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=self.initializer_range)) + bias_attr = None + self.linear0 = nn.Linear( + d_model, dim_feedforward, 
weight_attr, bias_attr=bias_attr) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + self.dropout1 = nn.Dropout(self.dropout_ratio) + self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train") + self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train") + + def forward(self, input_ids, position_ids): + if _global_parallel_stratergy == "dp": + auto.shard_tensor( + input_ids, _global_process_mesh, dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + input_ids, _global_process_mesh, dim_mapping=[0, -1]) + + input_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.word_embeddings.weight, + _global_process_mesh, + dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.word_embeddings.weight, + _global_process_mesh, + dim_mapping=[1, -1]) + + embeddings = input_embeddings + position_embeddings + embeddings = self.dropout1(embeddings) + + # Pre-norm + target = self.norm(embeddings) + + # The following is the attention part + q = self.q_proj(target) + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + + k = self.k_proj(target) + v = self.v_proj(target) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + + # scale dot product attention + product = layers.matmul( + x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) + + if self.attn_mask is not None: + product = product + self.attn_mask + + weights = F.softmax(product) + + if self.dropout_ratio: + weights = F.dropout( + weights, + self.dropout_ratio, + training=self.training, + mode="upscale_in_train") + + out = tensor.matmul(weights, v) + + # combine heads + out = tensor.transpose(out, perm=[0, 2, 1, 3]) + out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.out_proj(out) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.out_proj.weight, _global_process_mesh, + dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.out_proj.weight, _global_process_mesh, + dim_mapping=[1, -1]) + else: + auto.shard_tensor( + self.out_proj.weight, + _global_process_mesh, + dim_mapping=[-1, -1]) + + # Add residual + residual = embeddings + self.dropout2(out) + + # Pre-norm + out0 = self.norm(residual) + + # The following is the MLP part + out1 = self.linear0(out0) + out2 = F.gelu(out1, approximate=True) + out3 = self.linear1(out2) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( 
+ self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) + + # Add residual + final = residual + self.dropout3(out3) + return final + + +def decoder_pretrain_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + hidden_size = 1024 + sequence_len = 512 + input_ids = static.data( + name="input_ids", shape=[batch_size, sequence_len], dtype='int64') + position_ids = static.data( + name="position_ids", + shape=[batch_size, sequence_len], + dtype='int64') + decoder = DecoderLayer( + vocab_size=32768, + hidden_size=hidden_size, + sequence_len=sequence_len, + max_position_embeddings=512, + intermediate_size=4 * hidden_size, + num_heads=16, + dropout_ratio=0.1, + initializer_range=0.02) + out = decoder(input_ids, position_ids) + + return train_program, start_program + + +class TestDecoderLayerPartitioner(unittest.TestCase): + def test_decoder_dp_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp_mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( + decoder_pretrain_forward) + + # param should be partition + nrank = 4 + # col parallel + weights = [ + 'linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0', 'linear_4.w_0' + ] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 1, nrank)) + weights = [ + 'linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0', 'linear_4.b_0' + ] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + # row parallel + weights = ['word_embeddings', 'linear_3.w_0', 'linear_5.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + weights = [ + 'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0', + 'layer_norm_0.w_0', 'linear_5.b_0' + ] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, 1)) + + # row and col allreduce + dist_ops = dist_main_prog.global_block().ops + dist_ops = [op.type for op in dist_ops] + ref_ops = [ + 'c_embedding', 'c_allreduce_sum', 'lookup_table_v2', + 'elementwise_add', 'dropout', 'layer_norm', 'c_identity', 'matmul', + 'elementwise_add', 'reshape2', 'transpose2', 'c_identity', 'matmul', + 'elementwise_add', 'c_identity', 'matmul', 'elementwise_add', + 'reshape2', 'transpose2', 'reshape2', 'transpose2', 'matmul', + 'softmax', 'dropout', 'matmul_v2', 'transpose2', 'reshape2', + 'matmul', 'c_allreduce_sum', 'elementwise_add', 'dropout', + 'elementwise_add', 'layer_norm', 'c_identity', 'matmul', + 'elementwise_add', 'gelu', 'matmul', 'c_allreduce_sum', + 'elementwise_add', 'dropout', 'elementwise_add' + ] + self.assertTrue(dist_ops == ref_ops) + + # parameter initialization + var_need_broadcast = sorted([ + 'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0', + 'layer_norm_0.w_0', 'linear_5.b_0' + ]) + self.assertTrue( + initialization_check(_global_parallel_stratergy, dist_context, + dist_startup_prog, serial_startup_prog, + var_need_broadcast)) + + def 
test_decoder_noparallel(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "None" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( + decoder_pretrain_forward) + + # param should be partition + nrank = 1 + # col parallel + weights = [ + 'linear_0.w_0', 'linear_1.w_0', 'linear_2.w_0', 'linear_4.w_0' + ] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 1, nrank)) + weights = [ + 'linear_0.b_0', 'linear_1.b_0', 'linear_2.b_0', 'linear_4.b_0' + ] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + # row parallel + weights = ['word_embeddings', 'linear_3.w_0', 'linear_5.w_0'] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, nrank)) + weights = [ + 'linear_3.b_0', 'pos_embeddings', 'layer_norm_0.b_0', + 'layer_norm_0.w_0', 'linear_5.b_0' + ] + self.assertTrue( + check_tensor_split(dist_main_prog, weights, serial_main_prog, + weights, 0, 1)) + + # row and col allreduce + dist_ops = dist_main_prog.global_block().ops + dist_ops = [op.type for op in dist_ops] + ref_ops = [ + 'lookup_table_v2', 'lookup_table_v2', 'elementwise_add', 'dropout', + 'layer_norm', 'matmul', 'elementwise_add', 'reshape2', 'transpose2', + 'matmul', 'elementwise_add', 'matmul', 'elementwise_add', + 'reshape2', 'transpose2', 'reshape2', 'transpose2', 'matmul', + 'softmax', 'dropout', 'matmul_v2', 'transpose2', 'reshape2', + 'matmul', 'elementwise_add', 'dropout', 'elementwise_add', + 'layer_norm', 'matmul', 'elementwise_add', 'gelu', 'matmul', + 'elementwise_add', 'dropout', 'elementwise_add' + ] + self.assertTrue(dist_ops == ref_ops) + dist_ops = dist_startup_prog.global_block().ops + dist_ops = [op.type for op in dist_ops] + ref_ops = [ + 'gaussian_random', 'gaussian_random', 'gaussian_random', + 'fill_constant', 'gaussian_random', 'fill_constant', + 'gaussian_random', 'fill_constant', 'gaussian_random', + 'fill_constant', 'gaussian_random', 'fill_constant', + 'gaussian_random', 'fill_constant', 'fill_constant', 'fill_constant' + ] + self.assertTrue(dist_ops == ref_ops) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py new file mode 100755 index 0000000000000..b02c5f8a84f32 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py @@ -0,0 +1,857 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
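For reference, a plain-Python illustration (not part of this patch) of the shape relation that `check_tensor_split()` asserts throughout the partitioner tests above, using the MLP dimensions from those tests (hidden size 1024, intermediate size 4 * 1024, nrank 4). `is_split` is a hypothetical helper written only for this note, not an API of the test suite:

    def is_split(dist_shape, serial_shape, axis, nsplit):
        # A partitioned tensor keeps the serial shape except along the split
        # axis, where it holds 1/nsplit of the serial extent.
        return dist_shape[axis] == serial_shape[axis] // nsplit

    nrank = 4
    # column parallel: linear_0.w_0 is split along axis 1, its bias along axis 0
    assert is_split((1024, 1024), (1024, 4096), axis=1, nsplit=nrank)
    assert is_split((1024,), (4096,), axis=0, nsplit=nrank)
    # row parallel: linear_1.w_0 is split along axis 0, its bias is replicated
    assert is_split((1024, 1024), (4096, 1024), axis=0, nsplit=nrank)
    assert is_split((1024,), (1024,), axis=0, nsplit=1)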
+ +from __future__ import print_function + +import collections +import math +import unittest + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import paddle.tensor as tensor +import paddle.utils as utils +from paddle.fluid import layers +from paddle.fluid.framework import in_dygraph_mode +from paddle.nn.layer.transformer import _convert_param_attr_to_list +from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer +from paddle.distributed import fleet +import paddle.static as static +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program +from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr +from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed.auto_parallel.partitioner import Partitioner +from paddle.distributed.auto_parallel.utils import _get_comm_group +from paddle.distributed.auto_parallel.process import new_process_group + +paddle.enable_static() +ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) +_global_parallel_stratergy = None +_global_process_mesh = None + + +def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit): + + for i in range(len(varnames1)): + var1 = prog1.global_block().var(varnames1[i] + '@GRAD') + var2 = prog2.global_block().var(varnames2[i]) + if var1.shape[axis] != (var2.shape[axis] // nsplit): + return False + + return True + + +class MultiHeadAttention(nn.Layer): + """ + Attention mapps queries and a set of key-value pairs to outputs, and + Multi-Head Attention performs multiple parallel attention to jointly attending + to information from different representation subspaces. + """ + + Cache = collections.namedtuple("Cache", ["k", "v"]) + StaticCache = collections.namedtuple("StaticCache", ["k", "v"]) + + def __init__(self, + embed_dim, + num_heads, + dropout=0., + kdim=None, + vdim=None, + need_weights=False, + weight_attr=None, + bias_attr=None, + topo=None, + fuse=False): + super(MultiHeadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.need_weights = need_weights + self.fuse = fuse + + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + if topo is None or topo.mp_info.size == 1: + if self.fuse: + assert self.kdim == embed_dim + assert self.vdim == embed_dim + self.qkv_proj = nn.Linear( + embed_dim, 3 * embed_dim, weight_attr, bias_attr=bias_attr) + else: + self.q_proj = nn.Linear( + embed_dim, embed_dim, weight_attr, bias_attr=bias_attr) + self.k_proj = nn.Linear( + self.kdim, embed_dim, weight_attr, bias_attr=bias_attr) + self.v_proj = nn.Linear( + self.vdim, embed_dim, weight_attr, bias_attr=bias_attr) + self.out_proj = nn.Linear( + embed_dim, embed_dim, weight_attr, bias_attr=bias_attr) + + def _fuse_prepare_qkv(self, query): + mix_layer = self.qkv_proj(query) + mix_layer = paddle.reshape_(mix_layer, + [0, 0, self.num_heads, 3 * self.head_dim]) + mix_layer = paddle.transpose(mix_layer, [0, 2, 1, 3]) + q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1) + return q, k, v + + def _prepare_qkv(self, query, key, value, use_cache=False, cache=None): + r""" + Prapares linear projected queries, keys and values for usage of subsequnt + multiple parallel 
attention. If `cache` is not None, using cached results + to reduce redundant calculations. + """ + q = self.q_proj(query) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + + if isinstance(cache, self.StaticCache): + # for encoder-decoder attention in inference and has cached + k, v = cache.k, cache.v + else: + k, v = self.compute_kv(key, value) + + if isinstance(cache, self.Cache): + # for decoder self-attention in inference + k = tensor.concat([cache.k, k], axis=2) + v = tensor.concat([cache.v, v], axis=2) + if use_cache is True: + cache = self.Cache(k, v) + + return (q, k, v) if use_cache is False else (q, k, v, cache) + + def compute_kv(self, key, value): + r""" + Applies linear projection on input keys and values, then splits heads + (reshape and transpose) to get keys and values from different representation + subspaces. The results are used as key-values pairs for subsequent multiple + parallel attention. + It is part of calculations in multi-head attention, and is provided as + a method to pre-compute and prefetch these results, thus we can use them + to construct cache for inference. + """ + k = self.k_proj(key) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + v = self.v_proj(value) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + return k, v + + def gen_cache(self, key, value=None, type=Cache): + """ + Generates cache for `forward` usage in inference accroding to arguments. + The generated cache is an instance of `MultiHeadAttention.Cache` or an + instance of `MultiHeadAttention.StaticCache`. + """ + if type == MultiHeadAttention.StaticCache: # static_kv + k, v = self.compute_kv(key, value) + return self.StaticCache(k, v) + elif value is None: # incremental_state + k = layers.fill_constant_batch_size_like( + input=key, + shape=[-1, self.num_heads, 0, self.head_dim], + dtype=key.dtype, + value=0) + v = layers.fill_constant_batch_size_like( + input=key, + shape=[-1, self.num_heads, 0, self.head_dim], + dtype=key.dtype, + value=0) + return self.Cache(k, v) + else: + # incremental_state with initial value, mainly for usage like UniLM + return self.Cache(key, value) + + def forward(self, + query, + key, + value, + attn_mask=None, + use_cache=False, + cache=None): + r""" + Applies multi-head attention to map queries and a set of key-value pairs + to outputs. 
+        """
+        key = query if key is None else key
+        value = query if value is None else value
+        # compute q, k, v
+        if use_cache is False:
+            if self.fuse:
+                q, k, v = self._fuse_prepare_qkv(query)
+            else:
+                q, k, v = self._prepare_qkv(query, key, value, use_cache, cache)
+        else:
+            q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
+                                               cache)
+        # scale dot product attention
+        product = layers.matmul(
+            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
+
+        if attn_mask is not None:
+            product = product + attn_mask
+
+        weights = F.softmax(product)
+        if self.dropout:
+            weights = F.dropout(
+                weights,
+                self.dropout,
+                training=self.training,
+                mode="upscale_in_train")
+
+        out = tensor.matmul(weights, v)
+
+        # combine heads
+        out = tensor.transpose(out, perm=[0, 2, 1, 3])
+        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+
+        # project to output
+        out = self.out_proj(out)
+
+        if _global_parallel_stratergy == "mp":
+            auto.shard_tensor(
+                self.out_proj.weight, _global_process_mesh,
+                dim_mapping=[0, -1])
+        elif _global_parallel_stratergy == "dp_mp":
+            auto.shard_tensor(
+                self.out_proj.weight, _global_process_mesh,
+                dim_mapping=[1, -1])
+
+        outs = [out]
+        if self.need_weights:
+            outs.append(weights)
+        if use_cache:
+            outs.append(cache)
+        return out if len(outs) == 1 else tuple(outs)
+
+
+class TransformerDecoder(nn.Layer):
+    """
+    TransformerDecoder is a stack of N decoder layers.
+    """
+
+    def __init__(self,
+                 decoder_layers,
+                 num_layers,
+                 norm=None,
+                 hidden_size=None,
+                 topo=None):
+        super(TransformerDecoder, self).__init__()
+
+        self.topo = topo
+        self.num_layers = num_layers
+        self.layers = decoder_layers
+        self.norm = norm
+        if norm == "LayerNorm":
+            self.norm = nn.LayerNorm(hidden_size)
+        elif norm is not None:
+            raise ValueError("Only support LayerNorm")
+        self.checkpoints = []
+
+    def forward(self,
+                tgt,
+                memory,
+                tgt_mask=None,
+                memory_mask=None,
+                use_cache=False,
+                cache=None):
+        r"""
+        Applies a stack of N Transformer decoder layers on inputs. If `norm` is
+        provided, also applies layer normalization on the output of the last
+        decoder layer.
+        """
+        output = tgt
+        new_caches = []
+        self.checkpoints = []
+
+        for i, mod in enumerate(self.layers):
+            if cache is None:
+                if use_cache:
+                    output, new_cache = mod(output,
+                                            memory,
+                                            tgt_mask=tgt_mask,
+                                            use_cache=use_cache,
+                                            cache=cache)
+                    new_caches.append(new_cache)
+                else:
+                    output = mod(output,
+                                 memory,
+                                 tgt_mask=tgt_mask,
+                                 use_cache=use_cache,
+                                 cache=cache)
+
+            else:
+                output, new_cache = mod(output,
+                                        memory,
+                                        tgt_mask=tgt_mask,
+                                        use_cache=use_cache,
+                                        cache=cache[i])
+                new_caches.append(new_cache)
+            self.checkpoints.append(output.name)
+
+        if self.norm is not None:
+            output = self.norm(output)
+        return output if use_cache is False else (output, new_caches)
+
+    def gen_cache(self, memory, do_zip=False):
+        r"""
+        Generates cache for `forward` usage. The generated cache is a list, and
+        each element in it is a tuple( :code:`(incremental_cache, static_cache)` )
+        produced by `TransformerDecoderLayer.gen_cache`. See `TransformerDecoderLayer.gen_cache`
+        for more details. If `do_zip` is True, apply `zip` on these tuples to get
+        a list with two elements.
+        """
+        cache = [layer.gen_cache(memory) for layer in self.layers]
+        if do_zip:
+            cache = list(zip(*cache))
+        return cache
+
+
+class TransformerDecoderLayer(nn.Layer):
+    """
+    The transformer decoder layer.
+    It contains multi-head attention and some linear layers.
+ """ + + def __init__(self, + d_model, + nhead, + dim_feedforward, + dropout=0.1, + activation="gelu", + attn_dropout=None, + act_dropout=None, + normalize_before=True, + weight_attr=None, + bias_attr=None, + topo=None): + self._config = locals() + self._config.pop("self") + self._config.pop("__class__", None) # py3 + + super(TransformerDecoderLayer, self).__init__() + attn_dropout = dropout if attn_dropout is None else attn_dropout + act_dropout = dropout if act_dropout is None else act_dropout + self.normalize_before = normalize_before + + weight_attrs = _convert_param_attr_to_list(weight_attr, 3) + bias_attrs = _convert_param_attr_to_list(bias_attr, 3) + + self.self_attn = MultiHeadAttention( + d_model, + nhead, + dropout=attn_dropout, + weight_attr=weight_attrs[0], + bias_attr=bias_attrs[0], + topo=topo) + if topo is None or topo.mp_info.size == 1: + self.linear1 = nn.Linear( + d_model, + dim_feedforward, + weight_attrs[2], + bias_attr=bias_attrs[2]) + self.linear2 = nn.Linear( + dim_feedforward, + d_model, + weight_attrs[2], + bias_attr=bias_attrs[2]) + + self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5) + self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5) + self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train") + self.dropout2 = nn.Dropout(act_dropout, mode="upscale_in_train") + self.activation = getattr(F, activation) + + def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None): + residual = tgt + + if self.normalize_before: + tgt = self.norm1(tgt) + + if use_cache is False: + tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache) + else: + tgt, incremental_cache = self.self_attn(tgt, tgt, tgt, tgt_mask, + use_cache, cache) + tgt = residual + self.dropout1(tgt) + if not self.normalize_before: + tgt = self.norm1(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm2(tgt) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1]) + + # tgt = self.dropout2( + # self.linear2(F.gelu( + # self.linear1(tgt), approximate=True))) + tgt = self.linear1(tgt) + tgt = F.gelu(tgt, approximate=True) + tgt = self.dropout2(self.linear2(tgt)) + tgt = residual + tgt + + if not self.normalize_before: + tgt = self.norm2(tgt) + + return tgt if use_cache is False else (tgt, incremental_cache) + + def gen_cache(self, memory): + incremental_cache = self.self_attn.gen_cache( + memory, type=self.self_attn.Cache) + return incremental_cache + + +class GPTEmbeddings(nn.Layer): + """ + Include embeddings from word, position and token_type embeddings + """ + + def __init__(self, + vocab_size, + hidden_size=768, + hidden_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + topo=None): + super(GPTEmbeddings, self).__init__() + if topo is None or topo.mp_info.size == 1: + self.word_embeddings = nn.Embedding( + vocab_size, + hidden_size, + weight_attr=paddle.ParamAttr( + name="word_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range))) + self.position_embeddings = nn.Embedding( + max_position_embeddings, + hidden_size, + 
weight_attr=paddle.ParamAttr( + name="pos_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range))) + + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, input_ids, position_ids=None): + if position_ids is None: + ones = paddle.ones_like(input_ids, dtype="int64") + seq_length = paddle.cumsum(ones, axis=-1) + position_ids = seq_length - ones + + input_embedings = self.word_embeddings(input_ids) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.word_embeddings.weight, + _global_process_mesh, + dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.word_embeddings.weight, + _global_process_mesh, + dim_mapping=[1, -1]) + + position_embeddings = self.position_embeddings(position_ids) + embeddings = input_embedings + position_embeddings + embeddings = self.dropout(embeddings) + return embeddings + + +class GPTModel(nn.Layer): + """ + The base model of gpt. + """ + + def __init__(self, + vocab_size, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + pad_token_id=0, + topo=None): + super(GPTModel, self).__init__() + + self.pad_token_id = pad_token_id + self.initializer_range = initializer_range + self.topo = topo + self.hidden_size = hidden_size + self.vocab_size = vocab_size + + self.pipline_mode = topo is not None and topo.pp_info.size > 1 + if self.pipline_mode: + self.layer_per_stage = num_hidden_layers // self.topo.pp_info.size + + self.embeddings = GPTEmbeddings( + vocab_size, hidden_size, hidden_dropout_prob, + max_position_embeddings, type_vocab_size, self.initializer_range, + topo) + + decoder_layers = nn.LayerList() + for i in range(num_hidden_layers): + DecoderLayer = TransformerDecoderLayer + decoder_layers.append( + DecoderLayer( + d_model=hidden_size, + nhead=num_attention_heads, + dim_feedforward=intermediate_size, + dropout=hidden_dropout_prob, + activation=hidden_act, + attn_dropout=attention_probs_dropout_prob, + act_dropout=hidden_dropout_prob, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Normal( + mean=0.0, std=self.initializer_range)), + bias_attr=None, + topo=topo)) + + Decoder = TransformerDecoder + + self.decoder = Decoder( + decoder_layers, + num_hidden_layers, + norm="LayerNorm", + hidden_size=hidden_size, + topo=topo) + + self.checkpoints = [] + + def forward(self, + input_ids, + position_ids=None, + attention_mask=None, + use_cache=False, + cache=None): + self.checkpoints = [] + if attention_mask is None: + length = paddle.shape(input_ids)[1] + # Use bool mask + attention_mask = paddle.tensor.tril( + paddle.ones( + (length, length), + dtype=self.embeddings.word_embeddings.weight.dtype)) + if position_ids is None: + past_length = 0 + if cache is not None: + past_length = paddle.shape(cache[0].k)[-2] + position_ids = paddle.arange( + past_length, + paddle.shape(input_ids)[-1] + past_length, + dtype='int64') + position_ids = position_ids.unsqueeze(0) + # .expand_as(input_ids) + position_ids = paddle.fluid.layers.expand_as(position_ids, + input_ids) + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids) + + # TODO, use registered buffer + causal_mask = paddle.tensor.triu( + paddle.ones((paddle.shape(input_ids)[-1], + paddle.shape(input_ids)[-1])) * -1e9, + diagonal=1) + + if attention_mask is not None: + 
attention_mask = attention_mask + causal_mask + else: + attention_mask = causal_mask + + # The tensor returned by triu not in static graph. + attention_mask.stop_gradient = True + + encoder_outputs = self.decoder( + embedding_output, + memory=None, + tgt_mask=attention_mask, + use_cache=use_cache, + cache=cache) + self.checkpoints.extend(self.decoder.checkpoints) + return encoder_outputs + + +class GPTForPretraining(nn.Layer): + """ + The pretraining model of GPT. + It returns some logits and cached_kvs. + """ + + def __init__(self, gpt): + super(GPTForPretraining, self).__init__() + self.gpt = gpt + self.share_param = False + self.weight = self.gpt.embeddings.word_embeddings.weight + if not self.share_param: + self.weight = self.create_parameter(shape=self.weight.shape) + + def parallel_matmul(self, lm_output, logit_weights, parallel_output, topo): + if topo is not None and topo.mp_info.size > 1: + input_parallel = paddle.distributed.collective._c_identity( + lm_output, group=None) + + logits = paddle.matmul( + input_parallel, logit_weights, transpose_y=True) + + if parallel_output: + return logits + + return paddle.distributed.collective._c_concat(logits, group=None) + else: + logits = paddle.matmul(lm_output, logit_weights, transpose_y=True) + return logits + + def forward(self, + input_ids, + position_ids=None, + attention_mask=None, + masked_positions=None, + use_cache=False, + cache=None): + outputs = self.gpt(input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + use_cache=use_cache, + cache=cache) + if use_cache: + encoder_outputs, cached_kvs = outputs[:2] + else: + encoder_outputs = outputs + logits = self.parallel_matmul(encoder_outputs, self.weight, True, + self.gpt.topo) + + if use_cache: + return logits, cached_kvs + else: + return logits + + +class GPTPretrainingCriterion(nn.Layer): + """ + Criterion for GPT. + It calculates the final loss. 
+ """ + + def __init__(self, topo=None): + super(GPTPretrainingCriterion, self).__init__() + if topo is None or topo.mp_info.size == 1: + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none") + else: + self.loss_func = paddle.distributed.collective._c_softmax_with_cross_entropy + + def forward(self, prediction_scores, masked_lm_labels, loss_mask): + masked_lm_loss = self.loss_func(prediction_scores, + masked_lm_labels.unsqueeze(2)) + + loss_mask = loss_mask.reshape([-1]) + masked_lm_loss = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask) + loss = masked_lm_loss / loss_mask.sum() + return loss + + +def gpt_pretrain_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 16 + sequence_len = 512 + input_ids = static.data( + name="input_ids", shape=[batch_size, sequence_len], dtype='int64') + position_ids = static.data( + name="position_ids", + shape=[batch_size, sequence_len], + dtype='int64') + attention_mask = static.data( + name="attention_mask", + shape=[batch_size, 1, sequence_len, sequence_len], + dtype='float64') + labels = static.data( + name="labels", shape=[batch_size, sequence_len], dtype='int64') + loss_mask = static.data( + name="loss_mask", shape=[batch_size, sequence_len], dtype='float64') + + if _global_parallel_stratergy == "dp": + auto.shard_tensor( + input_ids, _global_process_mesh, dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + input_ids, _global_process_mesh, dim_mapping=[0, -1]) + + gpt = GPTModel( + vocab_size=32768, + hidden_size=768, + num_hidden_layers=2, + num_attention_heads=12, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=1024, + type_vocab_size=16, + initializer_range=0.02, + pad_token_id=0, + topo=None) + + model = GPTForPretraining(gpt) + + preds = model(input_ids, position_ids, attention_mask) + + criterion = GPTPretrainingCriterion() + + loss = criterion(preds, labels, loss_mask) + + return train_program, start_program, loss + + +class TestGPTPartitioner(unittest.TestCase): + def test_gpt_dp_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp_mp" + global _global_process_mesh + + _global_process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + + train_program = static.Program() + start_program = static.Program() + dist_context = DistributedContext() + dist_context.set_process_mesh(_global_process_mesh) + train_program, start_program, loss = gpt_pretrain_forward(train_program, + start_program) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + rank_id = 3 + dist_strategy = fleet.DistributedStrategy() + partitioner = Partitioner(dist_strategy, dist_context, rank_id) + auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward( + complete_train_program, start_program) + dist_params_grads = partitioner.apply_backward( + loss, complete_train_program, start_program, + auto_parallel_main_prog, auto_parallel_startup_prog) + optimizer = paddle.fluid.optimizer.AdamOptimizer( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, + auto_parallel_main_prog, + auto_parallel_startup_prog) + + nrank = 4 + # col parallel + weights = [ + 'linear_0.w_0', + 'linear_6.w_0', + 'linear_10.w_0', + ] + self.assertTrue( + 
check_tensor_split(auto_parallel_main_prog, weights, + complete_train_program, weights, 1, nrank)) + + # row parallel + weights = ['word_embeddings', 'linear_9.w_0', 'linear_11.w_0'] + self.assertTrue( + check_tensor_split(auto_parallel_main_prog, weights, + complete_train_program, weights, 0, nrank)) + + weights = ['pos_embeddings', 'layer_norm_0.b_0', 'layer_norm_4.w_0'] + self.assertTrue( + check_tensor_split(auto_parallel_main_prog, weights, + complete_train_program, weights, 0, 1)) + + all_params = sorted( + [param.name for param in start_program.all_parameters()]) + allreduce_grads = [ + 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', + 'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2', + 'layer_norm_7.tmp_2', 'layer_norm_8.tmp_2' + ] + mp_parallel_axis, process_mesh = dist_context._get_model_parallel_info() + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, mp_parallel_axis, + 3) + mp_ring_id = new_process_group(group_ranks).id + dp_parallel_axis, process_mesh = dist_context._get_data_parallel_info() + group_ranks = _get_comm_group(process_mesh.process_group, + process_mesh.topology, dp_parallel_axis, + 3) + dp_ring_id = new_process_group(group_ranks).id + tensor_parallel_allreduce_vars = sorted([ + op.desc.output_arg_names()[0].split("@")[0] + for op in auto_parallel_main_prog.global_block().ops + if (op.type == "c_allreduce_sum" and op.attr('op_role') == 1 and + op.desc.attr("ring_id") == mp_ring_id) + ]) + data_parallel_allreduce_vars = sorted([ + op.desc.output_arg_names()[0].split("@")[0] + for op in auto_parallel_main_prog.global_block().ops + if (op.type == "c_allreduce_sum" and op.desc.attr("ring_id") == + dp_ring_id) + ]) + + self.assertTrue(all_params == data_parallel_allreduce_vars) + self.assertTrue(allreduce_grads == tensor_parallel_allreduce_vars) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_flatten2_op.py b/python/paddle/fluid/tests/unittests/test_flatten2_op.py index a3c12a5fc01c3..42b43cc46a69b 100644 --- a/python/paddle/fluid/tests/unittests/test_flatten2_op.py +++ b/python/paddle/fluid/tests/unittests/test_flatten2_op.py @@ -17,6 +17,7 @@ import unittest import numpy as np import paddle.fluid as fluid +import paddle from op_test import OpTest @@ -69,6 +70,20 @@ def init_test_case(self): self.new_shape = (36, 16) +class TestStaticFlattenInferShapePythonAPI(unittest.TestCase): + def execute_api(self, x, axis=1): + return fluid.layers.flatten(x, axis=axis) + + def test_static_api(self): + paddle.enable_static() + main_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, paddle.static.Program()): + x = paddle.static.data( + name="x", shape=[-1, 3, -1, -1], dtype='float32') + out = self.execute_api(x, axis=2) + self.assertTrue((-1, -1) == out.shape) + + class TestFlatten2OpError(unittest.TestCase): def test_errors(self): with fluid.program_guard(fluid.Program(), fluid.Program()): diff --git a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py index f87b732d1b2cc..9093050d6d5c6 100644 --- a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py +++ b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py @@ -201,6 +201,20 @@ def test_static_api(self): self.assertTrue((2, 3, 16) == fetch_out[0].shape) +class TestStaticFlattenInferShapePythonAPI(unittest.TestCase): + def execute_api(self, x, 
start_axis=0, stop_axis=-1): + return paddle.flatten(x, start_axis, stop_axis) + + def test_static_api(self): + paddle.enable_static() + main_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, paddle.static.Program()): + x = paddle.static.data( + name="x", shape=[-1, 3, -1, -1], dtype='float32') + out = self.execute_api(x, start_axis=2, stop_axis=3) + self.assertTrue((-1, 3, -1) == out.shape) + + class TestStaticInplaceFlattenPythonAPI(TestStaticFlattenPythonAPI): def execute_api(self, x, start_axis=0, stop_axis=-1): return x.flatten_(start_axis, stop_axis) diff --git a/python/paddle/fluid/tests/unittests/test_gru_op.py b/python/paddle/fluid/tests/unittests/test_gru_op.py index 3ec943ef2e04a..7740cc0b03b49 100644 --- a/python/paddle/fluid/tests/unittests/test_gru_op.py +++ b/python/paddle/fluid/tests/unittests/test_gru_op.py @@ -18,7 +18,7 @@ import numpy as np import math import functools -from op_test import OpTest +from op_test import OpTest, skip_check_grad_ci from paddle.fluid.tests.unittests.test_lstm_op import ACTIVATION from paddle import fluid from paddle.fluid import Program, program_guard @@ -106,6 +106,9 @@ class TestGRUOp(OpTest): def set_confs(self): pass + def set_is_test(self): + self.is_test = False + def setUp(self): self.op_type = "gru" self.lod = [[2, 4, 3]] @@ -118,6 +121,7 @@ def setUp(self): self.dtype = 'float64' self.origin_mode = False self.set_confs() + self.set_is_test() T = sum(self.lod[0]) N = len(self.lod[0]) @@ -153,7 +157,8 @@ def setUp(self): 'activation': self.act_state, 'gate_activation': self.act_gate, 'is_reverse': self.is_reverse, - 'origin_mode': self.origin_mode + 'origin_mode': self.origin_mode, + 'is_test': self.is_test } def test_check_output(self): @@ -229,6 +234,21 @@ def set_confs(self): self.origin_mode = True +class TestGRUOpInference(TestGRUOp): + def set_is_test(self): + self.is_test = True + + def test_check_output(self): + new_outputs = {} + new_outputs['Hidden'] = self.outputs['Hidden'] + self.outputs = new_outputs + super(TestGRUOpInference, self).test_check_output() + + # avoid checking gradient + def test_check_grad(self): + pass + + class TestGruOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): diff --git a/python/paddle/fluid/tests/unittests/test_lstm_op.py b/python/paddle/fluid/tests/unittests/test_lstm_op.py index 185255439cc26..fff5fef29221e 100644 --- a/python/paddle/fluid/tests/unittests/test_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_lstm_op.py @@ -16,7 +16,7 @@ import unittest import numpy as np -from op_test import OpTest +from op_test import OpTest, skip_check_grad_ci from paddle import fluid from paddle.fluid.layers import lstm as LSTM from paddle.fluid.layers import fill_constant @@ -212,10 +212,14 @@ def test_pre_cell_type(): class TestLstmOp(OpTest): + def set_is_test(self): + self.is_test = False + def set_lod(self): self.lod = [[2, 3, 2]] def set_argument(self): + self.set_is_test() self.set_lod() self.D = 16 @@ -269,7 +273,8 @@ def setUp(self): 'is_reverse': self.is_reverse, 'gate_activation': self.act_gate, 'cell_activation': self.act_cell, - 'candidate_activation': self.act_cand + 'candidate_activation': self.act_cand, + 'is_test': self.is_test } def test_check_output(self): @@ -302,6 +307,15 @@ def set_lod(self): self.lod = [[2, 0, 4]] +class TestLstmOpInference(TestLstmOp): + def set_is_test(self): + self.is_test = True + + # avoid checking gradient + def test_check_grad(self): + pass + + class TestLstmOpError(unittest.TestCase): 
def test_errors(self): with program_guard(Program(), Program()): diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py index f69993c52ae5d..a80dc87525ab8 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_slice_op.py @@ -705,7 +705,7 @@ def test_axis_less_than_zero(self): np_slice = x_arr[:, :, 0:1] self.assertTrue(np.array_equal(pp_slice, np_slice)) - pp_slice = paddle.slice(x, [-100, ], [0], [1]) + pp_slice = paddle.slice(x, (-100, ), [0], [1]) np_slice = x_arr[0:1] self.assertTrue(np.array_equal(pp_slice, np_slice)) diff --git a/python/paddle/fluid/tests/unittests/test_static_save_load_bf16.py b/python/paddle/fluid/tests/unittests/test_static_save_load_bf16.py index 8d665a1746816..bc8c3cc5b23e5 100644 --- a/python/paddle/fluid/tests/unittests/test_static_save_load_bf16.py +++ b/python/paddle/fluid/tests/unittests/test_static_save_load_bf16.py @@ -81,10 +81,13 @@ def test_ptb_rnn_cpu_bfloat16(self): y_data = np.arange(1, 13).reshape(4, 3).astype('int64') x_data = x_data.reshape((-1, num_steps, 1)) y_data = y_data.reshape((-1, 1)) + #TODO investigate initializing model with "float32" instead of "uint16" as it was before + # slice_op PR(datatypes in model graph are different than datatypes during runtime because of that) init_hidden_data = np.zeros( - (num_layers, batch_size, hidden_size), dtype='float32') + (num_layers, batch_size, hidden_size), dtype='uint16') init_cell_data = np.zeros( - (num_layers, batch_size, hidden_size), dtype='float32') + (num_layers, batch_size, hidden_size), dtype='uint16') + fetch_list = [static_loss, static_last_hidden, static_last_cell] out = exe.run(fluid.default_main_program(), feed={ diff --git a/python/paddle/fluid/tests/unittests/test_svd_op.py b/python/paddle/fluid/tests/unittests/test_svd_op.py new file mode 100644 index 0000000000000..c2d712b3d7e65 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_svd_op.py @@ -0,0 +1,292 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
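The new test_svd_op.py below derives all of its expected values from NumPy and then verifies that the factors reconstruct the input. A minimal NumPy sketch of that reference check (illustration only, not part of the patch; the 6x3 shape is arbitrary):

import numpy as np

x = np.random.rand(6, 3)
# Economy-size SVD: u is (6, 3), s is (3,), vh is (3, 3).
u, s, vh = np.linalg.svd(x, full_matrices=False)
# U * diag(S) * VH should recover x up to numerical error, which is what
# TestSvdOp.test_svd_forward asserts through paddle.linalg.svd.
assert np.allclose(u @ np.diag(s) @ vh, x)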
+ +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.fluid.core as core +from op_test import OpTest, skip_check_grad_ci +from gradient_checker import grad_check +from decorator_helper import prog_scope + + +class TestSvdOp(OpTest): + def setUp(self): + paddle.enable_static() + self.generate_input() + self.generate_output() + self.op_type = "svd" + assert (hasattr(self, "_output_data")) + self.inputs = {"X": self._input_data} + self.attrs = {'full_matrices': self.get_full_matrices_option()} + self.outputs = { + "U": self._output_data[0], + "S": self._output_data[1], + "VH": self._output_data[2] + } + + def generate_input(self): + """ return an input_data and input_shape + """ + self._input_shape = (100, 1) + self._input_data = np.random.random(self._input_shape).astype("float64") + + def get_full_matrices_option(self): + return False + + def generate_output(self): + assert (hasattr(self, "_input_data")) + self._output_data = np.linalg.svd(self._input_data) + + def test_check_output(self): + self.check_output(no_check_set=['U', 'VH']) + + def test_svd_forward(self): + """ u matmul diag(s) matmul vt must reconstruct X + """ + single_input = self._input_data.reshape( + [-1, self._input_shape[-2], self._input_shape[-1]])[0] + paddle.disable_static() + dy_x = paddle.to_tensor(single_input) + dy_u, dy_s, dy_vt = paddle.linalg.svd(dy_x) + dy_out_x = dy_u.matmul(paddle.diag(dy_s)).matmul(dy_vt) + if (paddle.abs(dy_out_x - dy_x) < 1e-7).all(): + ... + else: + print("EXPECTED:\n", dy_x) + print("GOT :\n", dy_out_x) + raise RuntimeError("Check SVD Failed") + paddle.enable_static() + + def check_S_grad(self): + self.check_grad(['X'], ['S'], numeric_grad_delta=0.001) + + def check_U_grad(self): + self.check_grad(['X'], ['U'], numeric_grad_delta=0.001) + + def check_V_grad(self): + self.check_grad(['X'], ['VH'], numeric_grad_delta=0.001) + + def test_check_grad(self): + """ + remember that the input matrix must be a full rank matrix, otherwise the gradient will be stochastic because of the (n-k) freedom vectors of u / v + """ + self.check_S_grad() + self.check_U_grad() + self.check_V_grad() + + +class TestSvdCheckGrad2(TestSvdOp): + # NOTE(xiongkun03): because we want to construct some full rank matrices, + # we can't specify matrices whose numel() > 100 + + no_need_check_grad = True + + def generate_input(self): + """ return a deterministic matrix, the range matrix; + vander matrix must be a full rank matrix. + """ + self._input_shape = (5, 5) + self._input_data = np.vander( + [2, 3, 4, 5, 6]).astype("float64").reshape(self._input_shape) + + +class TestSvdNormalMatrixSmall(TestSvdCheckGrad2): + def generate_input(self): + """ small matrix SVD. + """ + self._input_shape = (1, 1) + self._input_data = np.random.random(self._input_shape).astype("float64") + + +class TestSvdNormalMatrix6x3(TestSvdCheckGrad2): + def generate_input(self): + """ return a deterministic matrix, the range matrix; + vander matrix must be a full rank matrix. + """ + self._input_shape = (6, 3) + self._input_data = np.array( + [[1.0, 2.0, 3.0], [0.0, 1.0, 5.0], [0.0, 0.0, 6.0], + [2.0, 4.0, 9.0], [3.0, 6.0, 8.0], + [3.0, 1.0, 0.0]]).astype("float64") + + +class TestSvdNormalMatrix3x6(TestSvdCheckGrad2): + def generate_input(self): + """ return a deterministic matrix, the range matrix; + vander matrix must be a full rank matrix.
+ """ + self._input_shape = (3, 6) + self._input_data = np.array( + [[1.0, 2.0, 3.0], [0.0, 1.0, 5.0], [0.0, 0.0, 6.0], + [2.0, 4.0, 9.0], [3.0, 6.0, 8.0], + [3.0, 1.0, 0.0]]).astype("float64") + self._input_data = self._input_data.transpose((-1, -2)) + + +class TestSvdNormalMatrix6x3Batched(TestSvdOp): + def generate_input(self): + self._input_shape = (10, 6, 3) + self._input_data = np.array( + [[1.0, 2.0, 3.0], [0.0, 1.0, 5.0], [0.0, 0.0, 6.0], + [2.0, 4.0, 9.0], [3.0, 6.0, 8.0], + [3.0, 1.0, 0.0]]).astype("float64") + self._input_data = np.stack([self._input_data] * 10, axis=0) + + def test_svd_forward(self): + """ test_svd_forward not support batched input, so disable this test. + """ + pass + + +class TestSvdNormalMatrix3x6Batched(TestSvdOp): + def generate_input(self): + """ return a deterministic matrix, the range matrix; + vander matrix must be a full rank matrix. + """ + self._input_shape = (10, 3, 6) + self._input_data = np.array( + [[1.0, 2.0, 3.0], [0.0, 1.0, 5.0], [0.0, 0.0, 6.0], + [2.0, 4.0, 9.0], [3.0, 6.0, 8.0], + [3.0, 1.0, 0.0]]).astype("float64") + self._input_data = self._input_data.transpose((-1, -2)) + self._input_data = np.stack([self._input_data] * 10, axis=0) + + def test_svd_forward(self): + """ test_svd_forward not support batched input, so disable this test. + """ + pass + + +class TestSvdNormalMatrix3x3x3x6Batched(TestSvdOp): + def generate_input(self): + """ return a deterministic matrix, the range matrix; + vander matrix must be a full rank matrix. + """ + self._input_shape = (3, 3, 3, 6) + self._input_data = np.array( + [[1.0, 2.0, 3.0], [0.0, 1.0, 5.0], [0.0, 0.0, 6.0], + [2.0, 4.0, 9.0], [3.0, 6.0, 8.0], + [3.0, 1.0, 0.0]]).astype("float64") + self._input_data = self._input_data.transpose((-1, -2)) + self._input_data = np.stack( + [self._input_data, self._input_data, self._input_data], axis=0) + self._input_data = np.stack( + [self._input_data, self._input_data, self._input_data], axis=0) + + def test_svd_forward(self): + """ test_svd_forward not support batched input, so disable this test. + """ + pass + + +@skip_check_grad_ci(reason="'check_grad' on large inputs is too slow, " + + "however it is desirable to cover the forward pass") +class TestSvdNormalMatrixBig(TestSvdOp): + def generate_input(self): + """ big matrix SVD. + + """ + self._input_shape = (2, 200, 300) + self._input_data = np.random.random(self._input_shape).astype("float64") + + def test_svd_forward(self): + """ test_svd_forward not support batched input, so disable this test. + """ + pass + + def test_check_grad(self): + pass + + +class TestSvdNormalMatrixBig2(TestSvdOp): + def generate_input(self): + """ big matrix SVD. 
+ """ + self._input_shape = (1, 100) + self._input_data = np.random.random(self._input_shape).astype("float64") + + +class TestSvdNormalMatrixFullMatrices(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + def tearDown(self): + paddle.enable_static() + + def test_full_matrices(self): + mat_shape = (2, 3) + mat = np.random.random(mat_shape).astype("float64") + x = paddle.to_tensor(mat) + u, s, vh = paddle.linalg.svd(x, full_matrices=True) + assert (u.shape == [2, 2]) + assert (vh.shape == [3, 3]) + x_recover = u.matmul(paddle.diag(s)).matmul(vh[0:2]) + if ((paddle.abs(x_recover - x) > 1e-4).any()): + raise RuntimeError("mat can't be recovered\n") + + +class TestSvdFullMatriceGrad(TestSvdNormalMatrix6x3): + def get_full_matrices_option(self): + return True + + def test_svd_forward(self): + """ test_svd_forward not support full matrices, so disable this test. + """ + pass + + def test_check_grad(self): + """ + remember the input matrix must be the full rank matrix, otherwise the gradient will stochatic because the u / v 's (n-k) freedom vectors + """ + self.check_S_grad() + #self.check_U_grad() // don't check U grad, because U have freedom vector + self.check_V_grad() + + +class TestSvdAPI(unittest.TestCase): + def test_dygraph(self): + paddle.disable_static() + a = np.random.rand(5, 5) + x = paddle.to_tensor(a) + u, s, vh = paddle.linalg.svd(x) + gt_u, gt_s, gt_vh = np.linalg.svd(a, full_matrices=False) + self.assertTrue(np.allclose(s, gt_s)) + + def test_static(self): + paddle.enable_static() + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for place in places: + with fluid.program_guard(fluid.Program(), fluid.Program()): + a = np.random.rand(5, 5) + x = paddle.fluid.data( + name="input", shape=[5, 5], dtype='float64') + u, s, vh = paddle.linalg.svd(x) + exe = fluid.Executor(place) + gt_u, gt_s, gt_vh = np.linalg.svd(a, full_matrices=False) + fetches = exe.run(fluid.default_main_program(), + feed={"input": a}, + fetch_list=[s]) + self.assertTrue(np.allclose(fetches[0], gt_s)) + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 416f125caa2b6..c94316c748243 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -946,6 +946,57 @@ def test_tensor_str_shape_with_zero(self): self.assertEqual(a_str, expected) paddle.enable_static() + def test_tensor_str_linewidth(self): + paddle.disable_static(paddle.CPUPlace()) + paddle.seed(2021) + x = paddle.rand([128]) + paddle.set_printoptions( + precision=4, threshold=1000, edgeitems=3, linewidth=80) + a_str = str(x) + + expected = '''Tensor(shape=[128], dtype=float32, place=CPUPlace, stop_gradient=True, + [0.3759, 0.0278, 0.2489, 0.3110, 0.9105, 0.7381, 0.1905, 0.4726, 0.2435, + 0.9142, 0.3367, 0.7243, 0.7664, 0.9915, 0.2921, 0.1363, 0.8096, 0.2915, + 0.9564, 0.9972, 0.2573, 0.2597, 0.3429, 0.2484, 0.9579, 0.7003, 0.4126, + 0.4274, 0.0074, 0.9686, 0.9910, 0.0144, 0.6564, 0.2932, 0.7114, 0.9301, + 0.6421, 0.0538, 0.1273, 0.5771, 0.9336, 0.6416, 0.1832, 0.9311, 0.7702, + 0.7474, 0.4479, 0.3382, 0.5579, 0.0444, 0.9802, 0.9874, 0.3038, 0.5640, + 0.2408, 0.5489, 0.8866, 0.1006, 0.5881, 0.7560, 0.7928, 0.8604, 0.4670, + 0.9285, 0.1482, 0.4541, 0.1307, 0.6221, 0.4902, 0.1147, 0.4415, 0.2987, + 0.7276, 0.2077, 0.7551, 0.9652, 0.4369, 0.2282, 0.0047, 0.2934, 0.4308, + 0.4190, 
0.1442, 0.3650, 0.3056, 0.6535, 0.1211, 0.8721, 0.7408, 0.4220, + 0.5937, 0.3123, 0.9198, 0.0275, 0.5338, 0.4622, 0.7521, 0.3609, 0.4703, + 0.1736, 0.8976, 0.7616, 0.3756, 0.2416, 0.2907, 0.3246, 0.4305, 0.5717, + 0.0735, 0.0361, 0.5534, 0.4399, 0.9260, 0.6525, 0.3064, 0.4573, 0.9210, + 0.8269, 0.2424, 0.7494, 0.8945, 0.7098, 0.8078, 0.4707, 0.5715, 0.7232, + 0.4678, 0.5047])''' + + self.assertEqual(a_str, expected) + paddle.enable_static() + + def test_tensor_str_linewidth2(self): + paddle.disable_static(paddle.CPUPlace()) + paddle.seed(2021) + x = paddle.rand([128]) + paddle.set_printoptions(precision=4, linewidth=160, sci_mode=True) + a_str = str(x) + + expected = '''Tensor(shape=[128], dtype=float32, place=CPUPlace, stop_gradient=True, + [3.7587e-01, 2.7798e-02, 2.4891e-01, 3.1097e-01, 9.1053e-01, 7.3811e-01, 1.9045e-01, 4.7258e-01, 2.4354e-01, 9.1415e-01, 3.3666e-01, 7.2428e-01, + 7.6640e-01, 9.9146e-01, 2.9215e-01, 1.3625e-01, 8.0957e-01, 2.9153e-01, 9.5642e-01, 9.9718e-01, 2.5732e-01, 2.5973e-01, 3.4292e-01, 2.4841e-01, + 9.5794e-01, 7.0029e-01, 4.1260e-01, 4.2737e-01, 7.3788e-03, 9.6863e-01, 9.9102e-01, 1.4416e-02, 6.5640e-01, 2.9318e-01, 7.1136e-01, 9.3008e-01, + 6.4209e-01, 5.3849e-02, 1.2730e-01, 5.7712e-01, 9.3359e-01, 6.4155e-01, 1.8320e-01, 9.3110e-01, 7.7021e-01, 7.4736e-01, 4.4793e-01, 3.3817e-01, + 5.5794e-01, 4.4412e-02, 9.8023e-01, 9.8735e-01, 3.0376e-01, 5.6397e-01, 2.4082e-01, 5.4893e-01, 8.8659e-01, 1.0065e-01, 5.8812e-01, 7.5600e-01, + 7.9280e-01, 8.6041e-01, 4.6701e-01, 9.2852e-01, 1.4821e-01, 4.5410e-01, 1.3074e-01, 6.2210e-01, 4.9024e-01, 1.1466e-01, 4.4154e-01, 2.9868e-01, + 7.2758e-01, 2.0766e-01, 7.5508e-01, 9.6522e-01, 4.3688e-01, 2.2823e-01, 4.7394e-03, 2.9342e-01, 4.3083e-01, 4.1902e-01, 1.4416e-01, 3.6500e-01, + 3.0560e-01, 6.5350e-01, 1.2115e-01, 8.7206e-01, 7.4081e-01, 4.2203e-01, 5.9372e-01, 3.1230e-01, 9.1979e-01, 2.7486e-02, 5.3383e-01, 4.6224e-01, + 7.5211e-01, 3.6094e-01, 4.7034e-01, 1.7355e-01, 8.9763e-01, 7.6165e-01, 3.7557e-01, 2.4157e-01, 2.9074e-01, 3.2458e-01, 4.3049e-01, 5.7171e-01, + 7.3509e-02, 3.6087e-02, 5.5341e-01, 4.3993e-01, 9.2601e-01, 6.5248e-01, 3.0640e-01, 4.5727e-01, 9.2104e-01, 8.2688e-01, 2.4243e-01, 7.4937e-01, + 8.9448e-01, 7.0981e-01, 8.0783e-01, 4.7065e-01, 5.7154e-01, 7.2319e-01, 4.6777e-01, 5.0465e-01])''' + + self.assertEqual(a_str, expected) + paddle.enable_static() + def test_print_tensor_dtype(self): paddle.disable_static(paddle.CPUPlace()) a = paddle.rand([1]) diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index 2492caff2f91e..584c418675726 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -31,5 +31,6 @@ 'rnn', 'fusion_lstm', 'softmax_with_cross_entropy', + 'svd', 'class_center_sample', ] diff --git a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py index 15ba331e9de5a..29374a9796504 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py @@ -41,6 +41,7 @@ 'elementwise_min', 'elementwise_mul', 'elementwise_sub', + 'elementwise_pow', 'filter_by_instag', 'fused_elemwise_activation', 'fused_emb_seq_pool', diff --git 
a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py index 929a9696d1c12..2b3383239a0ce 100644 --- a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py @@ -46,6 +46,7 @@ 'cudnn_lstm', \ 'rnn', \ 'lgamma', \ + 'svd', \ 'matrix_power', \ ] diff --git a/python/paddle/fluid/tests/unittests/xpu/test_label_smooth_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_label_smooth_op_xpu.py new file mode 100644 index 0000000000000..5a827c1beb291 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_label_smooth_op_xpu.py @@ -0,0 +1,64 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle +import numpy as np +import sys +sys.path.append("..") +from op_test_xpu import XPUOpTest + +paddle.enable_static() + + +class TestLabelSmoothOp(XPUOpTest): + def config(self): + self.op_type = "label_smooth" + self.epsilon = 0.1 + self.use_xpu = True + batch_size, self.label_dim = 10, 12 + self.label = np.zeros((batch_size, self.label_dim)).astype("float32") + nonzero_index = np.random.randint(self.label_dim, size=(batch_size)) + self.label[np.arange(batch_size), nonzero_index] = 1 + + def setUp(self): + self.config() + smoothed_label = (1 - self.epsilon + ) * self.label + self.epsilon / self.label_dim + self.inputs = {'X': self.label} + self.attrs = {'epsilon': self.epsilon} + self.outputs = {'Out': smoothed_label} + + def test_check_output(self): + if not paddle.is_compiled_with_xpu(): + return + self.check_output_with_place(paddle.XPUPlace(0), atol=1e-6) + + def test_check_grad(self): + return + + +class TestLabelSmoothOp3D(TestLabelSmoothOp): + def setUp(self): + super(TestLabelSmoothOp3D, self).setUp() + self.inputs['X'] = self.inputs['X'].reshape( + [2, -1, self.inputs['X'].shape[-1]]) + self.outputs['Out'] = self.outputs['Out'].reshape(self.inputs['X'] + .shape) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/transpiler/ps_dispatcher.py b/python/paddle/fluid/transpiler/ps_dispatcher.py index 63fc36efc29c7..7bdd50c5523f4 100644 --- a/python/paddle/fluid/transpiler/ps_dispatcher.py +++ b/python/paddle/fluid/transpiler/ps_dispatcher.py @@ -48,7 +48,7 @@ def dispatch(self, varlist): class HashName(PSDispatcher): """ - :api_attr: Static Graph + :api_attr: Static Graph Hash variable names to several endpoints using python "hash()" function. @@ -90,7 +90,7 @@ def dispatch(self, varlist): class RoundRobin(PSDispatcher): """ - :api_attr: Static Graph + :api_attr: Static Graph Distribute variables to several endpoints using RondRobin method. 
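The RoundRobin change in the next hunk swaps super(self.__class__, self) for an explicit super(RoundRobin, self). A minimal sketch of why the original spelling is unsafe (class names here are hypothetical, not from the patch): super() bound to self.__class__ resolves against the runtime type, so any subclass re-enters its parent's __init__ indefinitely.

class Dispatcher(object):
    def __init__(self, endpoints):
        self.endpoints = endpoints


class RoundRobinLike(Dispatcher):
    def __init__(self, endpoints):
        # Buggy pattern: self.__class__ is the runtime type, not RoundRobinLike.
        super(self.__class__, self).__init__(endpoints)


class CustomDispatcher(RoundRobinLike):
    pass


try:
    # super(CustomDispatcher, self) resolves to RoundRobinLike, so
    # RoundRobinLike.__init__ keeps calling itself.
    CustomDispatcher(["127.0.0.1:6170"])
except RecursionError:
    print("use super(RoundRobinLike, self) instead of super(self.__class__, self)")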
@@ -110,7 +110,7 @@ class RoundRobin(PSDispatcher): """ def __init__(self, pserver_endpoints): - super(self.__class__, self).__init__(pserver_endpoints) + super(RoundRobin, self).__init__(pserver_endpoints) def dispatch(self, varlist): """ diff --git a/python/paddle/hapi/model_summary.py b/python/paddle/hapi/model_summary.py index 45db83f9141a7..8d581f38e9b01 100644 --- a/python/paddle/hapi/model_summary.py +++ b/python/paddle/hapi/model_summary.py @@ -147,6 +147,8 @@ def forward(self, inputs): input_size = [] for key in input.keys(): input_size.append(tuple(input[key].shape)) + elif isinstance(input, paddle.fluid.framework.Variable): + input_size = tuple(input.shape) + else: raise ValueError( "Input is not tensor, list, tuple and dict, unable to determine input_size, please input input_size." diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py index ec6b7aa9e3d82..236150eef9479 100644 --- a/python/paddle/linalg.py +++ b/python/paddle/linalg.py @@ -16,10 +16,12 @@ from .tensor.linalg import norm # noqa: F401 from .tensor.linalg import matrix_power # noqa: F401 from .tensor import inverse as inv # noqa: F401 +from .tensor.linalg import svd __all__ = [ 'cholesky', #noqa 'norm', 'inv', + 'svd', 'matrix_power' ] diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 040bec2f67b9e..375375c8604de 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -45,6 +45,7 @@ from .linalg import histogram # noqa: F401 from .linalg import mv # noqa: F401 from .linalg import matrix_power # noqa: F401 +from .linalg import svd # noqa: F401 from .logic import equal # noqa: F401 from .logic import greater_equal # noqa: F401 from .logic import greater_than # noqa: F401 @@ -223,6 +224,7 @@ 'histogram', 'mv', 'matrix_power', + 'svd', 'abs', 'acos', 'all', diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 74d9876cddd5c..40dfd32b50a05 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -943,6 +943,70 @@ def __check_input(x, vec): return out +def svd(x, full_matrices=False, name=None): + r""" + Computes the singular value decomposition of one + matrix or batches of matrices. + Args: + x (Tensor): The input tensor. Its shape should be `[..., N, M]`, + where ... is zero or more batch dimensions. N and M can be arbitrary + positive numbers. Note that if x is a singular matrix, the gradient is numerically + unstable. The data type of x should be float32 or float64. + + full_matrices(bool): A flag to control the behavior of svd. + If full_matrices = True, svd op will compute full U and V matrices, + which means shape of U is `[..., N, N]`, shape of V is `[..., M, M]`. + If full_matrices = False, svd op will use an economical method to store U and V, + which means shape of U is `[..., N, K]`, shape of V is `[..., M, K]` + + Returns: + Tensor: Tensor U, the shape of U is controlled by full_matrices flag. + Tensor: Tensor S, the singular values of X. The shape of S is [..., K] + Tensor: Tensor VH, the conjugate transpose of V. The shape of V is controlled by full_matrices flag.
+ + import paddle + import numpy as np + + x = paddle.to_tensor([[1.0, 2.0], [1.0, 3.0], [4.0, 6.0]]).astype('float64') + x = x.reshape([3, 2]) + u, s, vt = paddle.linalg.svd(x) + print (u) + print (s) + print (vt) + + #U = [[ 0.27364809, -0.21695147 ], + # [ 0.37892198, -0.87112408 ], + # [ 0.8840446 , 0.44053933 ]] + + #S = [8.14753743, 0.78589688] + + #VT= [[ 0.51411221, 0.85772294], + # [ 0.85772294, -0.51411221]] + + # one can verify : U * S * VT = X ; + # U * UH = I ; + # V * VH = I + """ + + if in_dygraph_mode(): + return _C_ops.svd(x, 'full_matrices', full_matrices) + check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'svd') + check_type(full_matrices, 'full_matrices', bool, 'svd') + helper = LayerHelper('svd', **locals()) + u = helper.create_variable_for_type_inference(dtype=x.dtype) + vh = helper.create_variable_for_type_inference(dtype=x.dtype) + s = helper.create_variable_for_type_inference(dtype=x.dtype) + attrs = dict() + attrs['full_matrices'] = full_matrices + helper.append_op( + type='svd', + inputs={'X': [x]}, + outputs={'U': u, + 'VH': vh, + 'S': s}, + attrs=attrs, ) + return u, s, vh + + def matrix_power(x, n, name=None): r""" Computes the n-th power of a square matrix or a batch of square matrices. diff --git a/python/paddle/tensor/to_string.py b/python/paddle/tensor/to_string.py index e42bb8f95f21c..f640882893034 100644 --- a/python/paddle/tensor/to_string.py +++ b/python/paddle/tensor/to_string.py @@ -34,15 +34,18 @@ class PrintOptions(object): def set_printoptions(precision=None, threshold=None, edgeitems=None, - sci_mode=None): + sci_mode=None, + linewidth=None): """Set the printing options for Tensor. NOTE: The function is similar with numpy.set_printoptions() Args: precision (int, optional): Number of digits of the floating number, default 8. threshold (int, optional): Total number of elements printed, default 1000. - edgeitems (int, optional): Number of elements in summary at the begining and end of each dimension, defalt 3. + edgeitems (int, optional): Number of elements in summary at the beginning and ending of each dimension, default 3. sci_mode (bool, optional): Format the floating number with scientific notation or not, default False. + linewidth (int, optional): Number of characters per line, default 80. + Returns: None.
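The set_printoptions change above threads a new linewidth option through to the tensor printer; the actual wrapping happens in the _format_tensor changes further down. A minimal usage sketch mirroring the new test_var_base cases earlier in this patch (printed values depend on the random seed):

import paddle

paddle.disable_static()
paddle.seed(2021)
x = paddle.rand([128])

# Wrap the printed element rows at 160 characters instead of the default 80,
# and switch to scientific notation.
paddle.set_printoptions(precision=4, linewidth=160, sci_mode=True)
print(x)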
@@ -82,15 +85,18 @@ def set_printoptions(precision=None, check_type(edgeitems, 'edgeitems', (int), 'set_printoptions') DEFAULT_PRINT_OPTIONS.edgeitems = edgeitems kwargs['edgeitems'] = edgeitems + if linewidth is not None: + check_type(linewidth, 'linewidth', (int), 'set_printoptions') + DEFAULT_PRINT_OPTIONS.linewidth = linewidth + kwargs['linewidth'] = linewidth if sci_mode is not None: check_type(sci_mode, 'sci_mode', (bool), 'set_printoptions') DEFAULT_PRINT_OPTIONS.sci_mode = sci_mode kwargs['sci_mode'] = sci_mode - #TODO(zhiqiu): support linewidth core.set_printoptions(**kwargs) -def _to_sumary(var): +def _to_summary(var): edgeitems = DEFAULT_PRINT_OPTIONS.edgeitems # Handle tensor of shape contains 0, like [0, 2], [3, 0, 3] @@ -109,9 +115,9 @@ def _to_sumary(var): if var.shape[0] > 2 * edgeitems: begin = [x for x in var[:edgeitems]] end = [x for x in var[(-1 * edgeitems):]] - return np.stack([_to_sumary(x) for x in (begin + end)]) + return np.stack([_to_summary(x) for x in (begin + end)]) else: - return np.stack([_to_sumary(x) for x in var]) + return np.stack([_to_summary(x) for x in var]) def _format_item(np_var, max_width=0, signed=False): @@ -140,6 +146,7 @@ def _format_item(np_var, max_width=0, signed=False): def _get_max_width(var): + # return max_width for a scalar max_width = 0 signed = False for item in list(var.flatten()): @@ -151,15 +158,30 @@ def _get_max_width(var): return max_width, signed -def _format_tensor(var, sumary, indent=0, max_width=0, signed=False): +def _format_tensor(var, summary, indent=0, max_width=0, signed=False): + """ + Format a tensor + + Args: + var(Tensor): The tensor to be formatted. + summary(bool): Do summary or not. If true, some elements will not be printed, and be replaced with "...". + indent(int): The indent of each line. + max_width(int): The max width of each elements in var. + signed(bool): Print +/- or not. + """ edgeitems = DEFAULT_PRINT_OPTIONS.edgeitems + linewidth = DEFAULT_PRINT_OPTIONS.linewidth if len(var.shape) == 0: # currently, shape = [], i.e., scaler tensor is not supported. # If it is supported, it should be formatted like this. 
return _format_item(var, max_width, signed) elif len(var.shape) == 1: - if sumary and var.shape[0] > 2 * edgeitems: + item_length = max_width + 2 + items_per_line = (linewidth - indent) // item_length + items_per_line = max(1, items_per_line) + + if summary and var.shape[0] > 2 * edgeitems: items = [ _format_item(item, max_width, signed) for item in list(var)[:edgeitems] @@ -171,21 +193,26 @@ def _format_tensor(var, sumary, indent=0, max_width=0, signed=False): items = [ _format_item(item, max_width, signed) for item in list(var) ] - s = ', '.join(items) + lines = [ + items[i:i + items_per_line] + for i in range(0, len(items), items_per_line) + ] + s = (',\n' + ' ' * + (indent + 1)).join([', '.join(line) for line in lines]) return '[' + s + ']' else: # recursively handle all dimensions - if sumary and var.shape[0] > 2 * edgeitems: + if summary and var.shape[0] > 2 * edgeitems: vars = [ - _format_tensor(x, sumary, indent + 1, max_width, signed) + _format_tensor(x, summary, indent + 1, max_width, signed) for x in var[:edgeitems] ] + ['...'] + [ - _format_tensor(x, sumary, indent + 1, max_width, signed) + _format_tensor(x, summary, indent + 1, max_width, signed) for x in var[(-1 * edgeitems):] ] else: vars = [ - _format_tensor(x, sumary, indent + 1, max_width, signed) + _format_tensor(x, summary, indent + 1, max_width, signed) for x in var ] @@ -211,14 +238,14 @@ def to_string(var, prefix='Tensor'): for dim in var.shape: size *= dim - sumary = False + summary = False if size > DEFAULT_PRINT_OPTIONS.threshold: - sumary = True + summary = True - max_width, signed = _get_max_width(_to_sumary(np_var)) + max_width, signed = _get_max_width(_to_summary(np_var)) data = _format_tensor( - np_var, sumary, indent=indent, max_width=max_width, signed=signed) + np_var, summary, indent=indent, max_width=max_width, signed=signed) return _template.format( prefix=prefix, diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index f90ff0c99af95..037601cd083c2 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -662,6 +662,12 @@ def _get_param_from_state_dict(state_dict): np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0) def test_summary_input(self): + paddle.enable_static() + mymodel = MyModel() + input_data = paddle.rand([1, 20]) + paddle.summary(mymodel, input=input_data) + paddle.disable_static() + rnn = paddle.nn.SimpleRNN(16, 32, 2, direction='bidirectional') input_data = paddle.rand([4, 23, 16]) paddle.summary(rnn, input=input_data) diff --git a/python/paddle/vision/transforms/transforms.py b/python/paddle/vision/transforms/transforms.py index c09748913f9da..1a3dbd68066a7 100644 --- a/python/paddle/vision/transforms/transforms.py +++ b/python/paddle/vision/transforms/transforms.py @@ -309,7 +309,14 @@ class ToTensor(BaseTransform): data_format (str, optional): Data format of output tensor, should be 'HWC' or 'CHW'. Default: 'CHW'. keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. - + + Shape: + - img(PIL.Image|np.ndarray): The input image with shape (H x W x C). + - output(np.ndarray): A tensor with shape (C x H x W) or (H x W x C) according option data_format. + + Returns: + A callable object of ToTensor. + Examples: .. code-block:: python @@ -368,6 +375,13 @@ class Resize(BaseTransform): - "lanczos": cv2.INTER_LANCZOS4 keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. 
+ Shape: + - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C). + - output(PIL.Image|np.ndarray|Paddle.Tensor): A resized image. + + Returns: + A callable object of Resize. + Examples: .. code-block:: python @@ -422,6 +436,13 @@ class RandomResizedCrop(BaseTransform): - "lanczos": cv2.INTER_LANCZOS4 keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + Shape: + - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C). + - output(PIL.Image|np.ndarray|Paddle.Tensor): A cropped image. + + Returns: + A callable object of RandomResizedCrop. + Examples: .. code-block:: python @@ -503,6 +524,13 @@ class CenterCrop(BaseTransform): size (int|list|tuple): Target size of output image, with (height, width) shape. keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + Shape: + - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C). + - output(PIL.Image|np.ndarray|Paddle.Tensor): A cropped image. + + Returns: + A callable object of CenterCrop. + Examples: .. code-block:: python @@ -537,6 +565,13 @@ class RandomHorizontalFlip(BaseTransform): prob (float, optional): Probability of the input data being flipped. Should be in [0, 1]. Default: 0.5 keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + Shape: + - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C). + - output(PIL.Image|np.ndarray|Paddle.Tensor): A horizontally flipped image. + + Returns: + A callable object of RandomHorizontalFlip. + Examples: .. code-block:: python @@ -571,6 +606,13 @@ class RandomVerticalFlip(BaseTransform): prob (float, optional): Probability of the input data being flipped. Default: 0.5 keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + Shape: + - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C). + - output(PIL.Image|np.ndarray|Paddle.Tensor): A vertically flipped image. + + Returns: + A callable object of RandomVerticalFlip. + Examples: .. code-block:: python @@ -579,7 +621,7 @@ class RandomVerticalFlip(BaseTransform): from PIL import Image from paddle.vision.transforms import RandomVerticalFlip - transform = RandomVerticalFlip(224) + transform = RandomVerticalFlip() fake_img = Image.fromarray((np.random.rand(300, 320, 3) * 255.).astype(np.uint8)) @@ -612,7 +654,14 @@ class Normalize(BaseTransform): 'CHW'. Default: 'CHW'. to_rgb (bool, optional): Whether to convert to rgb. Default: False. keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. - + + Shape: + - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C). + - output(PIL.Image|np.ndarray|Paddle.Tensor): A normalized array or tensor. + + Returns: + A callable object of Normalize. + Examples: .. code-block:: python @@ -665,7 +714,15 @@ class Transpose(BaseTransform): Args: order (list|tuple, optional): Target order of input data. Default: (2, 0, 1). keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. - + + Shape: + - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C). + - output(np.ndarray|Paddle.Tensor): A transposed array or tensor. If input + is a PIL.Image, output will be converted to np.ndarray automatically. + + Returns: + A callable object of Transpose. + Examples: .. code-block:: python @@ -707,6 +764,13 @@ class BrightnessTransform(BaseTransform): non negative number.
0 gives the original image keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + Shape: + - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C). + - output(PIL.Image|np.ndarray|Paddle.Tensor): An image with a transform in brightness. + + Returns: + A callable object of BrightnessTransform. + Examples: .. code-block:: python @@ -743,6 +807,13 @@ class ContrastTransform(BaseTransform): non negative number. 0 gives the original image keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + Shape: + - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C). + - output(PIL.Image|np.ndarray|Paddle.Tensor): An image with a transform in contrast. + + Returns: + A callable object of ContrastTransform. + Examples: .. code-block:: python @@ -781,6 +852,13 @@ class SaturationTransform(BaseTransform): non negative number. 0 gives the original image keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + Shape: + - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C). + - output(PIL.Image|np.ndarray|Paddle.Tensor): An image with a transform in saturation. + + Returns: + A callable object of SaturationTransform. + Examples: .. code-block:: python @@ -817,6 +895,13 @@ class HueTransform(BaseTransform): between 0 and 0.5, 0 gives the original image keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + Shape: + - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C). + - output(PIL.Image|np.ndarray|Paddle.Tensor): An image with a transform in hue. + + Returns: + A callable object of HueTransform. + Examples: .. code-block:: python @@ -860,6 +945,13 @@ class ColorJitter(BaseTransform): Chosen uniformly from [-hue, hue]. Should have 0<= hue <= 0.5. keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. + Shape: + - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C). + - output(PIL.Image|np.ndarray|Paddle.Tensor): A color jittered image. + + Returns: + A callable object of ColorJitter. + Examples: .. code-block:: python @@ -938,7 +1030,14 @@ class RandomCrop(BaseTransform): pad_if_needed (boolean|optional): It will pad the image if smaller than the desired size to avoid raising an exception. Default: False. keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. - + + Shape: + - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C). + - output(PIL.Image|np.ndarray|Paddle.Tensor): A randomly cropped image. + + Returns: + A callable object of RandomCrop. + Examples: .. code-block:: python @@ -1040,7 +1139,14 @@ class Pad(BaseTransform): padding ``[1, 2, 3, 4]`` with 2 elements on both sides in symmetric mode will result in ``[2, 1, 1, 2, 3, 4, 4, 3]``. keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. - + + Shape: + - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C). + - output(PIL.Image|np.ndarray|Paddle.Tensor): A padded image. + + Returns: + A callable object of Pad. + Examples: .. code-block:: python @@ -1113,7 +1219,14 @@ class RandomRotation(BaseTransform): Origin is the upper left corner. Default is the center of the image. keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. - + + Shape: + - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C).
+ - output(PIL.Image|np.ndarray|Paddle.Tensor): A rotated image. + + Returns: + A callable object of RandomRotation. + Examples: .. code-block:: python @@ -1180,11 +1293,15 @@ class Grayscale(BaseTransform): Args: num_output_channels (int): (1 or 3) number of channels desired for output image keys (list[str]|tuple[str], optional): Same as ``BaseTransform``. Default: None. - + + Shape: + - img(PIL.Image|np.ndarray|Paddle.Tensor): The input image with shape (H x W x C). + - output(PIL.Image|np.ndarray|Paddle.Tensor): Grayscale version of the input image. + - If output_channels == 1 : returned image is single channel + - If output_channels == 3 : returned image is 3 channel with r == g == b + Returns: - CV Image: Grayscale version of the input. - - If output_channels == 1 : returned image is single channel - - If output_channels == 3 : returned image is 3 channel with r == g == b + A callable object of Grayscale. Examples: diff --git a/tools/get_pr_ut.py b/tools/get_pr_ut.py index 9333797839349..bd67d68c13111 100644 --- a/tools/get_pr_ut.py +++ b/tools/get_pr_ut.py @@ -328,7 +328,7 @@ def get_pr_ut(self): if f_judge.endswith('.md'): ut_list.append('md_placeholder') onlyCommentsFilesOrXpu.append(f_judge) - elif 'tests/unittests/xpu' in f_judge or 'tests/unittests/npu' in f_judge: + elif 'tests/unittests/xpu' in f_judge or 'tests/unittests/npu' in f_judge or 'op_npu.cc' in f_judge: ut_list.append('xpu_npu_placeholder') onlyCommentsFilesOrXpu.append(f_judge) elif f_judge.endswith(('.h', '.cu', '.cc', 'py')):