diff --git a/cmake/operators.cmake b/cmake/operators.cmake index cdc39161bde25..2c010a1e6297f 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -185,6 +185,7 @@ function(op_library TARGET) list(REMOVE_ITEM hip_srcs "cholesky_op.cu") list(REMOVE_ITEM hip_srcs "matrix_rank_op.cu") list(REMOVE_ITEM hip_srcs "svd_op.cu") + list(REMOVE_ITEM hip_srcs "eigh_op.cu") list(REMOVE_ITEM hip_srcs "multinomial_op.cu") list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu") hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS} diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index 0c66622ed7b9a..f4660751b582a 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -216,6 +216,7 @@ void HogwildWorker::TrainFiles() { // how to accumulate fetched values here device_reader_->Start(); int cur_batch; + int batch_cnt = 0; while ((cur_batch = device_reader_->Next()) > 0) { for (auto &op : ops_) { bool need_skip = false; @@ -230,13 +231,26 @@ void HogwildWorker::TrainFiles() { } } + if (need_dump_field_) { + DumpField(*thread_scope_, dump_mode_, dump_interval_); + } + if (need_dump_param_ && thread_id_ == 0) { + DumpParam(*thread_scope_, batch_cnt); + } + total_ins_num += cur_batch; + ++batch_cnt; PrintFetchVars(); thread_scope_->DropKids(); } timeline.Pause(); VLOG(3) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec() << " seconds, ins_num: " << total_ins_num; + + if (need_dump_field_ || need_dump_param_) { + writer_.Flush(); + } + #if defined PADDLE_WITH_PSCORE if (thread_barrier_) { paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement(); diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index c0ccc196348a5..2a022ea4bb9ef 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -214,7 +214,7 @@ void MultiTrainer::Finalize() { if (need_dump_field_ || need_dump_param_) { FinalizeDumpEnv(); } -#ifdef PADDLE_WITH_HETERPS + for (size_t i = 0; i < need_merge_var_names_.size(); i++) { Variable* root_var = root_scope_->FindVar(need_merge_var_names_[i]); if (root_var == nullptr) { @@ -222,7 +222,11 @@ void MultiTrainer::Finalize() { } LoDTensor* root_tensor = root_var->GetMutable(); +#ifdef PADDLE_WITH_HETERPS for (size_t j = 0; j < places_.size(); j++) { +#else + for (int j = 1; j < thread_num_; j++) { +#endif Scope* cur_thread_scope = workers_[j]->GetThreadScope(); Variable* thread_var = cur_thread_scope->FindVar(need_merge_var_names_[i]); @@ -246,8 +250,8 @@ void MultiTrainer::Finalize() { _ForEachDataType_(MergeCallback); } } +#ifdef PADDLE_WITH_HETERPS MergeDenseParam(); - #endif root_scope_->DropKids(); } diff --git a/paddle/fluid/inference/tensorrt/convert/gather_op.cc b/paddle/fluid/inference/tensorrt/convert/gather_op.cc index 346a8bffa00e3..e7b82388b6ab8 100644 --- a/paddle/fluid/inference/tensorrt/convert/gather_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/gather_op.cc @@ -41,33 +41,27 @@ class GatherOpConverter : public OpConverter { std::string input_name = op_desc.Input("X").front(); std::string index_name = op_desc.Input("Index").front(); std::string output_name = op_desc.Output("Out").front(); - const auto input_tensor = engine_->GetITensor(input_name); const auto index_tensor = engine_->GetITensor(index_name); - const int axis = 0; + int axis = 0; + if (op_desc.HasAttr("axis")) 
{ + axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis")); + } - auto layer = TRT_ENGINE_ADD_LAYER(engine_, Gather, *input_tensor, - *index_tensor, axis); + auto reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *index_tensor); - auto odim = layer->getOutput(0)->getDimensions(); + nvinfer1::Dims index_shape{}; + index_shape.nbDims = 1; + index_shape.d[0] = -1; - auto reshape_layer = - TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *layer->getOutput(0)); + reshape_layer->setReshapeDimensions(index_shape); - nvinfer1::Dims target_shape{}; - target_shape.nbDims = odim.nbDims - 1; - for (int i = 0; i < axis; ++i) { - target_shape.d[i] = odim.d[i]; - } - target_shape.d[axis] = 0; - for (int i = axis + 1; i < target_shape.nbDims; ++i) { - target_shape.d[i] = odim.d[i + 1]; - } - - reshape_layer->setReshapeDimensions(target_shape); + auto layer = TRT_ENGINE_ADD_LAYER(engine_, Gather, *input_tensor, + *reshape_layer->getOutput(0), axis); + layer->setNbElementWiseDims(0); - RreplenishLayerAndOutput(reshape_layer, "gather", {output_name}, test_mode); + RreplenishLayerAndOutput(layer, "gather", {output_name}, test_mode); } }; diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index ac280dd160776..75f5616f7584f 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -362,9 +362,15 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } if (op_type == "gather") { - if (!with_dynamic_shape) return false; - - if (with_dynamic_shape) { + auto gather_inputs = desc.Inputs(); + if (gather_inputs.find("Axis") != gather_inputs.end()) { + if (desc.Input("Axis").size() >= 1) { + return false; + } + } + if (!with_dynamic_shape) { + return false; + } else { auto* block = desc.Block(); auto* x_var_desc = block->FindVar(desc.Input("X")[0]); const auto x_shape = x_var_desc->GetShape(); @@ -373,13 +379,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, return false; } } - - auto inputs = desc.InputArgumentNames(); - for (auto& input : inputs) { - if (input == "Axis" && desc.Input("Axis").size() > 0) return false; - } - // current not support axis from input, use default 0 - if (desc.GetAttrIfExists("axis")) return false; } if (op_type == "gather_nd") { @@ -1085,13 +1084,16 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, #if IS_TRT_VERSION_GE(7000) if (op_type == "tile") { // Paddle-TRT does not support the input tensors. - auto inputs = desc.InputArgumentNames(); - for (auto& input : inputs) { - if (input == "repeat_times_tensor" && - desc.Input("repeat_times_tensor").size() > 0) + auto tile_inputs = desc.Inputs(); + if (tile_inputs.find("repeat_times_tensor") != tile_inputs.end()) { + if (desc.Input("repeat_times_tensor").size() >= 1) { return false; - if (input == "RepeatTimes" && desc.Input("RepeatTimes").size() > 0) + } + } + if (tile_inputs.find("RepeatTimes") != tile_inputs.end()) { + if (desc.Input("RepeatTimes").size() >= 1) { return false; + } } if (with_dynamic_shape) return false; if (!with_dynamic_shape && !desc.HasAttr("repeat_times")) return false; diff --git a/paddle/fluid/operators/eigh_op.cc b/paddle/fluid/operators/eigh_op.cc new file mode 100644 index 0000000000000..b3056bd43ba53 --- /dev/null +++ b/paddle/fluid/operators/eigh_op.cc @@ -0,0 +1,167 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/eigh_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class EighOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigh"); + OP_INOUT_CHECK(ctx->HasOutput("Eigenvalues"), "Output", "Eigenvalues", + "Eigh"); + OP_INOUT_CHECK(ctx->HasOutput("Eigenvectors"), "Output", "Eigenvectors", + "Eigh"); + + auto input_dim = ctx->GetInputDim("X"); + auto rank = input_dim.size(); + + PADDLE_ENFORCE_GE(rank, 2, + platform::errors::InvalidArgument( + "The Input(X) should have at least 2 dimensions." + "But received a %d dimension tensor.", + rank)); + PADDLE_ENFORCE_EQ( + input_dim[rank - 2], input_dim[rank - 1], + platform::errors::InvalidArgument( + "Eigh op is designed for square matrix, consequently" + "inner-most 2 dimensions of Input(X) should be symmetric." + "But received X's shape[-2] = %d and shape[-1] = %d.", + input_dim[rank - 2], input_dim[rank - 1])); + + std::vector values_dim; + if (rank > 2) { + for (auto i = 0; i < rank - 1; i++) { + values_dim.emplace_back(input_dim[i]); + } + } else { + values_dim = {input_dim[1]}; + } + + ctx->SetOutputDim("Eigenvalues", framework::make_ddim(values_dim)); + ctx->SetOutputDim("Eigenvectors", input_dim); + } +}; + +class EignOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor), Hermitian or real symmetric matrices." + "Its shape should be [*, N, N] where * is zero or" + "more batch dimensions. The data type is float32 ," + "float64, complex64, complex128."); + AddOutput("Eigenvalues", + "(Tensor), The eigenvalues in ascending order." + "The data type is float32 or float64."); + AddOutput( + "Eigenvectors", + "(Tensor), The column is the normalized eigenvector " + "corresponding to the eigenvalue. The data type is the same as ``X``."); + AddAttr( + "UPLO", + "(string, default 'L'), 'L' represents the lower triangular matrix," + "'U' represents the upper triangular matrix.") + .SetDefault("L"); + AddComment(R"DOC( +Eigh Operator. + +Computes the eigenvalues and eigenvectors of a complex Hermitian + (conjugate symmetric) or a real symmetric matrix. 
+ +)DOC"); + } +}; + +class EighGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Eigenvalues"), "Input", "Eigenvalues", + "EighGrad"); + OP_INOUT_CHECK(ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors", + "EighGrad"); + OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvalues")), + "Input", "Eigenvalues@GRAD", "EighGrad"); + OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvectors")), + "Input", "Eigenvectors@GRAD", "EighGrad"); + auto dims = ctx->GetInputDim("Eigenvectors"); + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Eigenvectors")), + ctx.device_context()); + } +}; + +template +class EighGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType(this->ForwardOpType() + "_grad"); + op->SetInput("Eigenvalues", this->Output("Eigenvalues")); + op->SetInput("Eigenvectors", this->Output("Eigenvectors")); + op->SetInput(framework::GradVarName("Eigenvalues"), + this->OutputGrad("Eigenvalues")); + op->SetInput(framework::GradVarName("Eigenvectors"), + this->OutputGrad("Eigenvectors")); + op->SetAttrMap(this->Attrs()); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(eigh, ops::EighOp, ops::EignOpMaker, + ops::EighGradOpMaker, + ops::EighGradOpMaker); +REGISTER_OPERATOR(eigh_grad, ops::EighGradOp); + +REGISTER_OP_CPU_KERNEL( + eigh, ops::EighKernel, + ops::EighKernel, + ops::EighKernel>, + ops::EighKernel>); + +REGISTER_OP_CPU_KERNEL( + eigh_grad, + ops::EighGradKernel, + ops::EighGradKernel, + ops::EighGradKernel>, + ops::EighGradKernel>); diff --git a/paddle/fluid/operators/eigh_op.cu b/paddle/fluid/operators/eigh_op.cu new file mode 100644 index 0000000000000..cfc9eba450959 --- /dev/null +++ b/paddle/fluid/operators/eigh_op.cu @@ -0,0 +1,53 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/eigh_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class EighGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto input_var = ctx.Input("X"); + auto output_w_var = ctx.Output("Eigenvalues"); + auto output_v_var = ctx.Output("Eigenvectors"); + std::string lower = ctx.Attr("UPLO"); + bool is_lower = (lower == "L"); + math::MatrixEighFunctor functor; + functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + eigh, ops::EighGPUKernel, ops::EighGPUKernel, + ops::EighGPUKernel>, + ops::EighGPUKernel>); + +REGISTER_OP_CUDA_KERNEL( + eigh_grad, + ops::EighGradKernel, + ops::EighGradKernel, + ops::EighGradKernel>, + ops::EighGradKernel>); diff --git a/paddle/fluid/operators/eigh_op.h b/paddle/fluid/operators/eigh_op.h new file mode 100644 index 0000000000000..0af38d44e5457 --- /dev/null +++ b/paddle/fluid/operators/eigh_op.h @@ -0,0 +1,80 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/eigen_values_vectors.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +using EigenTensor = framework::EigenTensor; +template +using EigenVector = framework::EigenVector; + +template +class EighKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto input_var = ctx.Input("X"); + auto output_w_var = ctx.Output("Eigenvalues"); + auto output_v_var = ctx.Output("Eigenvectors"); + std::string lower = ctx.Attr("UPLO"); + bool is_lower = (lower == "L"); + math::MatrixEighFunctorCPU functor; + functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true); + } +}; + +template +class EighGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& x_grad = *ctx.Output(framework::GradVarName("X")); + x_grad.mutable_data(ctx.GetPlace()); + auto& output_w_var = *ctx.Input("Eigenvalues"); + auto& output_v_var = *ctx.Input("Eigenvectors"); + auto& output_w_grad = + *ctx.Input(framework::GradVarName("Eigenvalues")); + auto& output_v_grad = + *ctx.Input(framework::GradVarName("Eigenvectors")); + + auto& dims = output_v_var.dims(); + const int m = dims[dims.size() - 1]; + auto dito = + math::DeviceIndependenceTensorOperations( + ctx); + auto tV = dito.Transpose(dito.Conj(output_v_var)); + auto W = dito.Sub_(dito.Unsqueeze(output_w_var, -2), + dito.Unsqueeze(output_w_var, -1)); + Tensor result = dito.Matmul(tV, output_v_grad); + result.mutable_data(dims, ctx.GetPlace()); + std::vector out_shape = framework::vectorize(dims); + auto constant = dito.Fill(out_shape, 0.5); + result = dito.Sub(result, dito.Conj(dito.Transpose(result))); + result = dito.Mul(result, constant); + result = dito.Div_(result, W); + result = dito.DiagFill(m, m, m, 0, output_w_grad, result); + x_grad = dito.Matmul(output_v_var, dito.Matmul(result, tV)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/group_norm_op.cc b/paddle/fluid/operators/group_norm_op.cc index e076444626e6a..2c7fd8f4173ea 100644 --- a/paddle/fluid/operators/group_norm_op.cc +++ b/paddle/fluid/operators/group_norm_op.cc @@ -66,6 +66,12 @@ class GroupNormOp : public framework::OperatorWithKernel { "The Attr(groups) of Op(group_norm) must be " "greater than or equal to 1. But received: groups is [%s].", groups)); + PADDLE_ENFORCE_EQ( + channel_num % groups, 0, + platform::errors::InvalidArgument( + "Expected number of channels in input to be divisible by " + "num_groups, but got input channel is %d and num_groups is %d", + channel_num, groups)); if (ctx->HasInput("Scale")) { PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/operators/group_norm_op.cu b/paddle/fluid/operators/group_norm_op.cu index f199bfeb9443b..e029c84090af1 100644 --- a/paddle/fluid/operators/group_norm_op.cu +++ b/paddle/fluid/operators/group_norm_op.cu @@ -144,7 +144,8 @@ class GroupNormKernel const int C = (data_layout == DataLayout::kNCHW ? x_dims[1] : x_dims[x_dims.size() - 1]); - const int group_size = (C - 1) / groups + 1; + const int group_size = C / groups; + const int W = (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1] : x_dims[x_dims.size() - 2]); @@ -314,7 +315,7 @@ class GroupNormGradKernel const int C = (data_layout == DataLayout::kNCHW ? 
x_dims[1] : x_dims[x_dims.size() - 1]); - const int group_size = (C - 1) / groups + 1; + const int group_size = C / groups; const int W = (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1] : x_dims[x_dims.size() - 2]); diff --git a/paddle/fluid/operators/group_norm_op.h b/paddle/fluid/operators/group_norm_op.h index f2388699e266f..9cb451235f152 100644 --- a/paddle/fluid/operators/group_norm_op.h +++ b/paddle/fluid/operators/group_norm_op.h @@ -52,7 +52,7 @@ class GroupNormKernel : public framework::OpKernel { const int C = (data_layout == DataLayout::kNCHW ? x_dims[1] : x_dims[x_dims.size() - 1]); - const int group_size = (C - 1) / groups + 1; + const int group_size = C / groups; y->mutable_data(ctx.GetPlace()); mean->mutable_data(ctx.GetPlace()); @@ -100,7 +100,7 @@ class GroupNormKernel : public framework::OpKernel { int imid; for (imid = 0; imid < imsize - (imsize % M); imid += M, iter_x_data += M) { - // TODO(gaoxiang) :Because AVX/AVX2/AVX512 can not directly used + // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used // in template class/function, before we complete high // performance cpu vector extension, temporarily unrolling // loop to get high precision and performance @@ -138,7 +138,7 @@ class GroupNormKernel : public framework::OpKernel { int imid; for (imid = 0; imid < imsize - (imsize % M); imid += M, iter_x_data += M * C) { - // TODO(gaoxiang) :Because AVX/AVX2/AVX512 can not directly used + // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used // in template class/function, before we complete high // performance cpu vector extension, temporarily unrolling // loop to get high precision and performance @@ -236,7 +236,7 @@ class GroupNormGradKernel : public framework::OpKernel { const int C = (data_layout == DataLayout::kNCHW ? x_dims[1] : x_dims[x_dims.size() - 1]); - const int group_size = (C - 1) / groups + 1; + const int group_size = C / groups; d_x->mutable_data(ctx.GetPlace()); math::SetConstant set_zero; diff --git a/paddle/fluid/operators/index_select_op_npu.cc b/paddle/fluid/operators/index_select_op_npu.cc index 8df6c4e5d9ea7..b624d03cc8555 100644 --- a/paddle/fluid/operators/index_select_op_npu.cc +++ b/paddle/fluid/operators/index_select_op_npu.cc @@ -21,12 +21,12 @@ namespace operators { template class IndexSelectNPUKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto *x = ctx.Input("X"); - auto *index = ctx.Input("Index"); + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* index = ctx.Input("Index"); auto dim = ctx.Attr("dim"); - auto *out = ctx.Output("Out"); + auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); auto stream = @@ -43,7 +43,104 @@ class IndexSelectNPUKernel : public framework::OpKernel { } }; -// todo: add class 'IndexSelectGradNPUKernel' here. 
+template +class IndexSelectGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x_grad = ctx.Output(framework::GradVarName("X")); + auto* index = ctx.Input("Index"); + auto* out_grad = + ctx.Input(framework::GradVarName("Out")); + + auto stream = + ctx.template device_context() + .stream(); + + auto x_dims = x_grad->dims(); + auto out_dims = out_grad->dims(); + + int dim = ctx.Attr("dim"); + if (dim < 0) { + dim += out_dims.size(); + } + + Tensor casted_index; + if (index->type() != framework::proto::VarType::INT32) { + casted_index.mutable_data(index->dims(), ctx.GetPlace()); + const auto& cast_runner = NpuOpRunner("Cast", {*index}, {casted_index}, + {{"dst_type", ACL_INT32}}); + cast_runner.Run(stream); + } else { + casted_index.ShareDataWith(*index); + } + + if (dim == 0) { + x_grad->mutable_data(ctx.GetPlace()); + const auto& zeros_runner = NpuOpRunner("ZerosLike", {*x_grad}, {*x_grad}); + zeros_runner.Run(stream); + + NpuOpRunner runner; + runner.SetType("UnsortedSegmentSum") + .AddInput(*out_grad) + .AddInput(casted_index) + .AddInput(std::vector{x_dims[dim]}) + .AddOutput(*x_grad); + runner.Run(stream); + } else { + Tensor transed_out_grad; + std::vector in_trans_perm; + in_trans_perm.push_back(dim); + for (int i = 0; i < out_dims.size(); ++i) { + if (i == dim) continue; + in_trans_perm.push_back(i); + } + framework::DDim transed_out_dims(out_dims); + for (size_t i = 0; i < in_trans_perm.size(); ++i) { + transed_out_dims[i] = out_dims[in_trans_perm[i]]; + } + transed_out_grad.mutable_data(transed_out_dims, ctx.GetPlace()); + framework::NPUAttributeMap in_trans_attr = {{"perm", in_trans_perm}}; + + const auto& in_trans_runner = NpuOpRunner( + "TransposeD", {*out_grad}, {transed_out_grad}, in_trans_attr); + in_trans_runner.Run(stream); + + Tensor sum_out; + framework::DDim sum_dims(x_dims); + sum_dims[0] = x_dims[dim]; + auto idx = 1; + for (int i = 0; i < x_dims.size(); ++i) { + if (i == dim) continue; + sum_dims[idx++] = x_dims[i]; + } + sum_out.mutable_data(sum_dims, ctx.GetPlace()); + const auto& zeros_runner = NpuOpRunner("ZerosLike", {sum_out}, {sum_out}); + zeros_runner.Run(stream); + + NpuOpRunner runner; + runner.SetType("UnsortedSegmentSum") + .AddInput(transed_out_grad) + .AddInput(casted_index) + .AddInput(std::vector{x_dims[dim]}) + .AddOutput(sum_out); + runner.Run(stream); + + std::vector out_trans_perm; + for (int i = 1; i < 1 + dim; ++i) { + out_trans_perm.push_back(i); + } + out_trans_perm.push_back(0); + for (int i = 1 + dim; i < x_dims.size(); ++i) { + out_trans_perm.push_back(i); + } + framework::NPUAttributeMap out_trans_attr = {{"perm", out_trans_perm}}; + x_grad->mutable_data(ctx.GetPlace()); + const auto& out_trans_runner = + NpuOpRunner("TransposeD", {sum_out}, {*x_grad}, out_trans_attr); + out_trans_runner.Run(stream); + } + } +}; } // namespace operators } // namespace paddle @@ -54,4 +151,8 @@ REGISTER_OP_NPU_KERNEL( ops::IndexSelectNPUKernel, ops::IndexSelectNPUKernel, ops::IndexSelectNPUKernel); -// todo: register npu index_select_grad kernel here. 
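For reference, the backward kernel added above scatter-adds rows of the output gradient back to the rows of X selected by Index, using TransposeD to bring the selected axis to the front and UnsortedSegmentSum to accumulate duplicate indices. A minimal NumPy sketch of that rule, for comparison only — the function name and shapes here are illustrative assumptions, not part of the kernel:

import numpy as np

def index_select_grad_ref(out_grad, index, x_shape, dim=0):
    # Scatter-add rows of out_grad back to the positions named in `index`;
    # duplicate indices accumulate, matching UnsortedSegmentSum semantics.
    x_grad = np.zeros(x_shape, dtype=out_grad.dtype)
    x_grad_t = np.moveaxis(x_grad, dim, 0)      # view into x_grad (mirrors the TransposeD step)
    out_grad_t = np.moveaxis(out_grad, dim, 0)
    for i, idx in enumerate(index):
        x_grad_t[idx] += out_grad_t[i]
    return x_grad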
+REGISTER_OP_NPU_KERNEL( + index_select_grad, + ops::IndexSelectGradNPUKernel, + ops::IndexSelectGradNPUKernel, + ops::IndexSelectGradNPUKernel); diff --git a/paddle/fluid/operators/math/eigen_values_vectors.h b/paddle/fluid/operators/math/eigen_values_vectors.h new file mode 100644 index 0000000000000..4e2d180e33628 --- /dev/null +++ b/paddle/fluid/operators/math/eigen_values_vectors.h @@ -0,0 +1,314 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "Eigen/Core" +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/operators/svd_helper.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/dynload/cusolver.h" +#endif // PADDLE_WITH_CUDA + +namespace paddle { +namespace operators { +namespace math { + +template +using EigenTensor = framework::EigenTensor; + +template +using InputMatrixMap = Eigen::Map< + const Eigen::Matrix>; + +template +using OutputMatrixMap = Eigen::Map< + Eigen::Matrix>; + +template +inline void ComputeFloatEigenvaluesAndVectors(ValueType *x_data, + ValueType *eigenvalues_data, + ValueType *eigenvectors_data, + int batches, int rows, int cols, + bool has_vectors) { + int stride = rows * cols; + for (int i = 0; i < batches; i++) { + auto m = InputMatrixMap(x_data + i * stride, rows, cols); + auto eigenvalues = + OutputMatrixMap(eigenvalues_data + i * rows, 1, rows); + auto eigenvectors = + OutputMatrixMap(eigenvectors_data + i * stride, rows, cols); + + Eigen::SelfAdjointEigenSolver> + eigen_solver(m, has_vectors ? Eigen::ComputeEigenvectors + : Eigen::EigenvaluesOnly); + PADDLE_ENFORCE_EQ( + eigen_solver.info(), Eigen::Success, + platform::errors::InvalidArgument( + "Self Adjoint Eigen decomposition is not successful. " + "The %d-th input matrice might not be not be positive definite.", + i)); + + eigenvalues = eigen_solver.eigenvalues().transpose(); + if (has_vectors) { + eigenvectors = eigen_solver.eigenvectors().transpose(); + } + } +} + +template +inline void ComputeComplexEigenvaluesAndVectors(T *x_data, + ValueType *eigenvalues_data, + T *eigenvectors_data, + int batches, int rows, int cols, + bool has_vectors) { + using Complex = std::complex; + Complex *input = reinterpret_cast(x_data); + Complex *eigenvectors_data_ = reinterpret_cast(eigenvectors_data); + + int stride = rows * cols; + for (int i = 0; i < batches; i++) { + auto m = InputMatrixMap(input + i * stride, rows, cols); + auto eigenvalues = + OutputMatrixMap(eigenvalues_data + i * rows, 1, rows); + auto eigenvectors = + OutputMatrixMap(eigenvectors_data_ + i * stride, rows, cols); + + Eigen::SelfAdjointEigenSolver< + Eigen::Matrix> + eigen_solver(m, has_vectors ? Eigen::ComputeEigenvectors + : Eigen::EigenvaluesOnly); + PADDLE_ENFORCE_EQ( + eigen_solver.info(), Eigen::Success, + platform::errors::InvalidArgument( + "Self Adjoint Eigen decomposition is not successful. 
" + "The %d-th input matrice might not be not be positive definite.", + i)); + + eigenvalues = eigen_solver.eigenvalues().transpose(); + if (has_vectors) { + eigenvectors = eigen_solver.eigenvectors().transpose(); + } + } +} + +inline int64_t GetBatchSize(framework::DDim dims) { + int64_t batch_size = 1; + auto dim_size = dims.size(); + for (int i = 0; i < dim_size - 2; i++) { + batch_size *= dims[i]; + } + return batch_size; +} + +// Calculates the eigenvalues ​​and eigenvectors of Hermitian or real +// symmetric matrices, and uses the variable has_vectors to +// control whether to return the eigenvectors. +template +struct MatrixEighFunctorCPU { + public: + void operator()(const framework::ExecutionContext &ctx, const Tensor &input, + Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, + bool has_vectors) { + auto dims = input.dims(); + auto output_value_dim = eigen_values->dims(); + + int64_t batch_size = 1; + int dim_size = dims.size(); + for (int64_t i = 0; i < dim_size - 2; i++) { + batch_size *= dims[i]; + } + auto dito = DeviceIndependenceTensorOperations(ctx); + Tensor input_tensor; + TensorCopy(input, ctx.GetPlace(), &input_tensor); + if (!is_lower) { + input_tensor = dito.Transpose(input); + } + int rows = dims[dims.size() - 2]; + + auto *value_data = + eigen_values->mutable_data(output_value_dim, ctx.GetPlace()); + + if (framework::IsComplexType(input_tensor.type())) { + auto *x_data = input_tensor.data(); + auto *vector_data = eigen_vectors->mutable_data(dims, ctx.GetPlace()); + ComputeComplexEigenvaluesAndVectors( + x_data, value_data, vector_data, batch_size, rows, rows, has_vectors); + } else { + auto *x_data = input_tensor.data(); + auto *vector_data = + eigen_vectors->mutable_data(dims, ctx.GetPlace()); + ComputeFloatEigenvaluesAndVectors( + x_data, value_data, vector_data, batch_size, rows, rows, has_vectors); + } + if (has_vectors) { + *eigen_vectors = dito.Transpose(*eigen_vectors); + } + } +}; + +#ifdef PADDLE_WITH_CUDA + +// Calculates the eigenvalues ​​and eigenvectors of Hermitian or real +// symmetric matrices on GPU, and uses the variable has_vectors +// to control whether to return the eigenvectors. +template +struct MatrixEighFunctor { + public: + void operator()(const framework::ExecutionContext &ctx, const Tensor &input, + Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, + bool has_vectors) { + auto *out_value = eigen_values->mutable_data(ctx.GetPlace()); + auto *out_vector = eigen_vectors->mutable_data(ctx.GetPlace()); + + auto &dims = input.dims(); + int dim_size = dims.size(); + int64_t batch_size = GetBatchSize(dims); + + cublasFillMode_t uplo = + is_lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + cusolverEigMode_t jobz = + has_vectors ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; + + int n = dims[dim_size - 1]; + int lda = std::max(1, n); + auto vector_stride = dims[dim_size - 1] * dims[dim_size - 2]; + auto values_stride = dims[dim_size - 1]; + + auto &dev_ctx = ctx.template device_context(); + auto dito = + math::DeviceIndependenceTensorOperations(ctx); + Tensor output_v_var_trans = dito.Transpose(input); + TensorCopy(output_v_var_trans, ctx.GetPlace(), eigen_vectors); + + int lwork = 0; + auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_size); + auto *info_ptr = reinterpret_cast(info->ptr()); + + // When the input type is float32, and the feature value input dimension is + // greater than or equal to [*,32,32] and less than or equal to + // [*,512,512], Syevj has better performance. 
+ bool use_syevj = + (eigen_vectors->type() == framework::proto::VarType::FP32 && + values_stride >= 32 && values_stride <= 512); + + syevjInfo_t syevj_params; + if (use_syevj) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnCreateSyevjInfo(&syevj_params)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnSsyevj_bufferSize( + dev_ctx.cusolver_dn_handle(), jobz, uplo, n, + reinterpret_cast(out_vector), lda, + reinterpret_cast(out_value), &lwork, + syevj_params)); + } else { + EvdBuffer(dev_ctx.cusolver_dn_handle(), jobz, uplo, n, out_vector, lda, + out_value, &lwork); + } + + auto work = memory::Alloc(dev_ctx, sizeof(T) * lwork); + auto *work_ptr = reinterpret_cast(work->ptr()); + + for (auto i = 0; i < batch_size; i++) { + auto vector_data = out_vector + i * vector_stride; + auto value_data = out_value + i * values_stride; + auto handle = dev_ctx.cusolver_dn_handle(); + if (use_syevj) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSsyevj( + handle, jobz, uplo, n, reinterpret_cast(vector_data), lda, + reinterpret_cast(value_data), + reinterpret_cast(work_ptr), lwork, info_ptr, + syevj_params)); + } else { + Evd(handle, jobz, uplo, n, vector_data, lda, value_data, work_ptr, + lwork, info_ptr); + } + int error_info; + memory::Copy(platform::CPUPlace(), &error_info, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + info_ptr, sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + error_info, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: the [%d] argument had an illegal value", i, + error_info)); + } + + if (use_syevj) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cusolverDnDestroySyevjInfo(syevj_params)); + } + + if (has_vectors) { + *eigen_vectors = dito.Transpose(*eigen_vectors); + } + } + + inline void EvdBuffer(cusolverDnHandle_t handle, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const T *A, int lda, + const ValueType *W, int *lwork) const; + + inline void Evd(cusolverDnHandle_t handle, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, T *A, int lda, ValueType *W, + T *work, int lwork, int *devInfo) const; +}; + +#define FUNC_WITH_TYPES(m) \ + m(float, float, Ssy, float) m(double, double, Dsy, double) \ + m(float, paddle::platform::complex, Che, cuComplex) \ + m(double, paddle::platform::complex, Zhe, cuDoubleComplex) + +#define EVDBUFFER_INSTANCE(ValueType, T, C, CastType) \ + template <> \ + inline void MatrixEighFunctor::EvdBuffer( \ + cusolverDnHandle_t handle, cusolverEigMode_t jobz, \ + cublasFillMode_t uplo, int n, const T *A, int lda, const ValueType *W, \ + int *lwork) const { \ + PADDLE_ENFORCE_CUDA_SUCCESS( \ + platform::dynload::cusolverDn##C##evd_bufferSize( \ + handle, jobz, uplo, n, reinterpret_cast(A), lda, \ + W, lwork)); \ + } + +FUNC_WITH_TYPES(EVDBUFFER_INSTANCE); + +#define EVD_INSTANCE(ValueType, T, C, CastType) \ + template <> \ + inline void MatrixEighFunctor::Evd( \ + cusolverDnHandle_t handle, cusolverEigMode_t jobz, \ + cublasFillMode_t uplo, int n, T *A, int lda, ValueType *W, T *work, \ + int lwork, int *devInfo) const { \ + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDn##C##evd( \ + handle, jobz, uplo, n, reinterpret_cast(A), lda, W, \ + reinterpret_cast(work), lwork, devInfo)); \ + } + +FUNC_WITH_TYPES(EVD_INSTANCE); + +#undef FUNC_WITH_TYPES +#undef EVDBUFFER_INSTANCE +#undef EVD_INSTANCE + +#endif // PADDLE_WITH_CUDA + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/spectral_op.cc 
b/paddle/fluid/operators/spectral_op.cc index fb0be9ba68fcf..becb9108a28f9 100644 --- a/paddle/fluid/operators/spectral_op.cc +++ b/paddle/fluid/operators/spectral_op.cc @@ -698,7 +698,7 @@ struct FFTC2CFunctor { framework::vectorize(framework::stride(input_dim)); const int64_t data_size = sizeof(C); std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(), - [](std::ptrdiff_t s) { return s * data_size; }); + [&](std::ptrdiff_t s) { return s * data_size; }); const auto* in_data = reinterpret_cast(x->data()); auto* out_data = reinterpret_cast(out->data()); @@ -732,7 +732,7 @@ struct FFTR2CFunctor { { const int64_t data_size = sizeof(R); std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(), - [](std::ptrdiff_t s) { return s * data_size; }); + [&](std::ptrdiff_t s) { return s * data_size; }); } const auto& output_dim = out->dims(); @@ -744,7 +744,7 @@ struct FFTR2CFunctor { const int64_t data_size = sizeof(C); std::transform(out_strides.begin(), out_strides.end(), out_strides.begin(), - [](std::ptrdiff_t s) { return s * data_size; }); + [&](std::ptrdiff_t s) { return s * data_size; }); } const auto* in_data = x->data(); @@ -779,7 +779,7 @@ struct FFTC2RFunctor { { const int64_t data_size = sizeof(C); std::transform(in_strides.begin(), in_strides.end(), in_strides.begin(), - [](std::ptrdiff_t s) { return s * data_size; }); + [&](std::ptrdiff_t s) { return s * data_size; }); } const auto& output_dim = out->dims(); @@ -791,7 +791,7 @@ struct FFTC2RFunctor { const int64_t data_size = sizeof(R); std::transform(out_strides.begin(), out_strides.end(), out_strides.begin(), - [](std::ptrdiff_t s) { return s * data_size; }); + [&](std::ptrdiff_t s) { return s * data_size; }); } const auto* in_data = reinterpret_cast(x->data()); diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index b0c361e86a531..71d106c211f71 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -25,6 +25,8 @@ #include "paddle/fluid/operators/eigen/eigen_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/operators/math/functors.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/for_range.h" @@ -36,6 +38,9 @@ using Tensor = framework::Tensor; using InTensors = std::vector; using OutTensors = std::vector; using OpName = std::string; +template +using EigenVector = framework::EigenVector; template void EigenSvd(const T* X, T* U, T* VH, T* S, int rows, int cols, @@ -140,7 +145,42 @@ static std::vector GetBroadcastShape(InTensors ins) { break; \ } -template +template +struct DiagAndFillFunctor { + DiagAndFillFunctor(const int m, const int n, const int num_lower_diags, + const int num_upper_diags, const ValueType* scale, + const T* input, T* output) + : m_(m), + n_(n), + num_lower_diags_(num_lower_diags), + num_upper_diags_(num_upper_diags), + scale_(scale), + input_(input), + output_(output) {} + + HOSTDEVICE void operator()(size_t index) const { + const int col = index % n_; + const int row = (index / n_) % m_; + const int band_start = (num_lower_diags_ < 0 ? 0 : row - num_lower_diags_); + const int band_end = + (num_upper_diags_ < 0 ? 
n_ : row + num_upper_diags_ + 1); + if (col < band_start || col >= band_end) { + output_[index] = input_[index]; + } else if (col == band_end - 1) { + output_[index] = static_cast(scale_[index % m_]); + } else { + output_[index] = input_[index]; + } + } + + private: + const int m_, n_, num_lower_diags_, num_upper_diags_; + const ValueType* scale_; + const T* input_; + T* output_; +}; + +template struct DeviceIndependenceTensorOperations { // 1. Device indenpendence, for kernel reuse. // 2. Input and output is always tensor type. @@ -398,6 +438,60 @@ struct DeviceIndependenceTensorOperations { return ret; } + Tensor Conj(const Tensor& x) { + Tensor out; + auto* out_data = out.mutable_data(x.dims(), context.GetPlace()); + auto* x_data = x.data(); + auto for_range = GetForRange(x.numel()); + math::ConjFunctor functor(x_data, x.numel(), out_data); + for_range(functor); + return out; + } + + Tensor DiagFill(const int m, const int n, const int num_lower_diags, + const int num_upper_diags, const Tensor& scale, + const Tensor& input) { + Tensor out; + auto& dev_ctx = context.template device_context(); + platform::ForRange for_range(dev_ctx, input.numel()); + DiagAndFillFunctor diag_and_copy_functor( + m, n, num_lower_diags, num_upper_diags, scale.data(), + input.data(), out.mutable_data(input.dims(), input.place())); + for_range(diag_and_copy_functor); + return out; + } + + // Support x and y are different data types + Tensor Div_(const Tensor& x, const Tensor& y) { + Tensor out; + out.mutable_data(x.dims(), context.GetPlace()); + auto x_vector = EigenVector::Flatten(x); + auto y_vector = EigenVector::Flatten(y); + auto out_vector = EigenVector::Flatten(out); + auto& place = + *context.template device_context().eigen_device(); + out_vector.device(place) = x_vector / y_vector; + return out; + } + + framework::Tensor Sub_(const framework::Tensor& x, + const framework::Tensor& y) { + framework::Tensor ret; + std::vector out_shape = GetBroadcastShape({&x, &y}); + ret.Resize(framework::make_ddim(out_shape)); + if (x.dims().size() >= y.dims().size()) { + ElementwiseComputeEx, DeviceContext, ValueType>( + context, &x, &y, -1, SubFunctor(), &ret); + } else { + ElementwiseComputeEx, DeviceContext, + ValueType>( + // This is copyed from elementwise_sub, which means we + // need reverse will xrank < yrank + context, &x, &y, -1, InverseSubFunctor(), &ret); + } + return ret; + } + private: const framework::ExecutionContext& context; BlasT GetBlas() { diff --git a/paddle/fluid/platform/dynload/cusolver.h b/paddle/fluid/platform/dynload/cusolver.h index 36ba5dd094815..a8ce1cc9d3a35 100644 --- a/paddle/fluid/platform/dynload/cusolver.h +++ b/paddle/fluid/platform/dynload/cusolver.h @@ -48,7 +48,15 @@ extern void *cusolver_dso_handle; __macro(cusolverDnSpotrf_bufferSize); \ __macro(cusolverDnDpotrf_bufferSize); \ __macro(cusolverDnSpotrf); \ - __macro(cusolverDnDpotrf); + __macro(cusolverDnDpotrf); \ + __macro(cusolverDnSsyevd_bufferSize); \ + __macro(cusolverDnDsyevd_bufferSize); \ + __macro(cusolverDnCheevd_bufferSize); \ + __macro(cusolverDnZheevd_bufferSize); \ + __macro(cusolverDnSsyevd); \ + __macro(cusolverDnDsyevd); \ + __macro(cusolverDnCheevd); \ + __macro(cusolverDnZheevd); CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP); diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 614a8e95aa42c..5475e3a51c172 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -94,6 +94,7 @@ from .tensor.linalg import norm # noqa: F401 from .tensor.linalg import transpose # 
noqa: F401 from .tensor.linalg import dist # noqa: F401 +from .tensor.linalg import cond # noqa: F401 from .tensor.linalg import t # noqa: F401 from .tensor.linalg import cross # noqa: F401 from .tensor.linalg import cholesky # noqa: F401 @@ -102,6 +103,8 @@ from .tensor.linalg import mv # noqa: F401 from .tensor.linalg import multi_dot # noqa: F401 from .tensor.linalg import matrix_power # noqa: F401 +from .tensor.linalg import svd # noqa: F401 +from .tensor.linalg import eigh # noqa: F401 from .tensor.logic import equal # noqa: F401 from .tensor.logic import greater_equal # noqa: F401 from .tensor.logic import greater_than # noqa: F401 @@ -256,9 +259,9 @@ from .framework import NPUPlace # noqa: F401 from .framework import CUDAPinnedPlace # noqa: F401 -from .framework import grad # noqa: F401 -from .framework import no_grad # noqa: F401 -from .framework import set_grad_enabled # noqa: F401 +from .autograd import grad # noqa: F401 +from .autograd import no_grad # noqa: F401 +from .autograd import set_grad_enabled # noqa: F401 from .framework import save # noqa: F401 from .framework import load # noqa: F401 from .framework import DataParallel # noqa: F401 diff --git a/python/paddle/autograd/__init__.py b/python/paddle/autograd/__init__.py index 569619f065a05..89094357b3505 100644 --- a/python/paddle/autograd/__init__.py +++ b/python/paddle/autograd/__init__.py @@ -16,5 +16,7 @@ from . import backward_mode # noqa: F401 from .backward_mode import backward # noqa: F401 from .py_layer import PyLayer, PyLayerContext # noqa: F401 +from ..framework import set_grad_enabled # noqa: F401 +from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401 -__all__ = ['grad', 'backward', 'PyLayer', 'PyLayerContext'] +__all__ = ['backward', 'PyLayer', 'PyLayerContext'] diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py index 3580e85fc89c1..5d28c2d5cebd9 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py @@ -142,32 +142,103 @@ def prune_gradient_clip(self, block, shard, ring_ids): return # TODO (JZ-LIANG) revise this for uniform mixed parallelism - def sync_global_norm(self, block, ring_ids): + def sync_global_norm(self, block, ring_ids, mp_rank): """ prune gradient_clip related ops for params that not belong to cur shard prune: square, reduce_sum, elementwise_mul keep: sum, sqrt, elementwise_max, elementwise_div """ - # FIXME(wangxi): mp should prune duplicated param_grads + is_clip_grad_by_global_norm = False + for idx, op in list(enumerate(block.ops)): + if not self._is_gradient_clip_op(op): + continue + if op.type == 'sum': + is_clip_grad_by_global_norm = True + break + if not is_clip_grad_by_global_norm: + # TODO(Yuang Liu): need some extra handles when clip_grad_norm for mp + return + + removed_op_idx = set() + removed_tmp_var = set() + for idx, op in list(enumerate(block.ops)): + if not self._is_gradient_clip_op(op): + continue + if op.type == 'sum': + break + for input_name in op.input_arg_names: + input_var = block.var(input_name) + # NOTE: when mp_degree > 1, some vars will be split into each mp rank. + # However, there still some vars such as Scale, Bias are not split. + # Those not be split vars should only be counted once during grad clip + # by global norm. 
Those vars either doesn't have is_distributed attr + # or the is_distributed attr has been set as False. + # Therefore, we prune those duplicated vars for grad clip. + if mp_rank >= 1 and (not (hasattr(input_var, 'is_distributed') + and input_var.is_distributed)): + removed_op_idx.add(idx) + for output_name in op.output_arg_names: + removed_tmp_var.add(output_name) + for idx, op in reversed(list(enumerate(block.ops))): if not self._is_gradient_clip_op(op): continue + if idx in removed_op_idx: + block._remove_op(idx, sync=False) - if op.type == "sum": - sum_res = op.desc.output_arg_names()[0] - for ring_id in ring_ids: - if ring_id == -1: continue + for var_name in removed_tmp_var: + block._remove_var(var_name, sync=False) - idx = idx + 1 - block._insert_op_without_sync( - idx, - type='c_allreduce_sum', - inputs={'X': sum_res}, - outputs={'Out': sum_res}, - attrs={ - 'ring_id': ring_id, - 'op_namescope': "/gradient_clip_model_parallelism", - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Optimize, - }) - return + for idx, op in list(enumerate(block.ops)): + if not self._is_gradient_clip_op(op): + continue + if op.type == 'sum': + # If mp_rank == 0, no extra handles, just allreduce + # If mp_rank >= 1, some extra handles is needed + sum_rst_var = block.var(op.output_arg_names[0]) + if mp_rank >= 1: + reserved_vars = [] + for input_name in op.input_arg_names: + if input_name not in removed_tmp_var: + reserved_vars.append(input_name) + + if len(reserved_vars) > 0: + op.desc.set_input("X", reserved_vars) + else: + # If all input of sum op should be removed, then remove the sum op. + # And set the output's value of sum to 0. + namescope = op.attr("op_namescope") + block._remove_op(idx, sync=False) + fill_constant_op = block._insert_op_without_sync( + idx, + type='fill_constant', + inputs={}, + outputs={'Out': sum_rst_var}, + attrs={ + 'shape': sum_rst_var.shape, + 'dtype': sum_rst_var.dtype, + 'value': 0.0, + OP_ROLE_KEY: OpRole.Optimize + }) + fill_constant_op._set_attr('op_namescope', namescope) + self._insert_allreduce(block, ring_ids, idx, sum_rst_var) + break + + @staticmethod + def _insert_allreduce(block, ring_ids, idx, var): + for ring_id in ring_ids: + if ring_id == -1: + continue + + idx = idx + 1 + block._insert_op_without_sync( + idx, + type='c_allreduce_sum', + inputs={'X': var}, + outputs={'Out': var}, + attrs={ + 'ring_id': ring_id, + 'op_namescope': "/gradient_clip_model_parallelism", + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize, + }) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 1f96ab07d60a8..f14f1e0662402 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -435,7 +435,6 @@ def _adapt_amp_clip_without_sharding(self): main_block = self._main_program.global_block() startup_block = self._startup_program.global_block() - # FIXME(wangxi): mp should prune duplicated param_grads when calc # amp inf_var & clip global_norm_var rings = [self.mp_ring_id, self.pp_ring_id] @@ -446,7 +445,7 @@ def _adapt_amp_clip_without_sharding(self): gradientclip_helper = GradientClipHelper(None) gradientclip_helper.sync_global_norm( - main_block, [self.mp_ring_id, self.pp_ring_id]) + main_block, [self.mp_ring_id, self.pp_ring_id], self.mp_rank) def _insert_loss_grad_scale_op(self): main_block = self._main_program.global_block() diff --git 
a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index cc98d378f1489..70c7b01b05ba3 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -43,6 +43,36 @@ __all__ = ['get_worker_info'] +# NOTE: fix `terminate called without an active exception` +# if for loop break and program exit immediately(with no model +# layers processing) after iterate **the first few data** in +# distributed lauch mode, distributed launch will call +# terminate() to kill main process on each devices, but thread +# is still iterating to fullfill blocking queue caches, which +# may cause thread error `terminate called without an active +# exception` for terminate is a strong singal and `__del__` +# of DataLoader may not be called, so we add a global link to +# the last DataLoader instance to call `__del__` to clean up +# resources +# NOTE: cannot simply as `__del__` to CleanupFuncRegistrar, +# for this will remain a link to each DataLoader instance in +# global, and will precludes GC to auto collect DataLoader +# instance and will cause memory leak +_loader = None + + +def _clear_loader(): + global _loader + if _loader is not None: + try: + _loader.__del__() + del _loader + except: + pass + + +CleanupFuncRegistrar.register(_clear_loader) + class _DataLoaderIterBase(object): """ @@ -100,6 +130,16 @@ def __iter__(self): def __len__(self): return len(self._batch_sampler) + def _exit_thread_expectedly(self): + self._thread_done_event.set() + if self._blocking_queue: + self._blocking_queue.close() + + def _exit_thread_unexpectedly(self): + self._thread_done_event.set() + if self._blocking_queue: + self._blocking_queue.kill() + class _DataLoaderIterSingleProcess(_DataLoaderIterBase): """ @@ -125,9 +165,13 @@ def __init__(self, loader): # NOTE: len(self._places) batch data compose as an output # iteration, set blocking_queue can cache 2 iteration datas # at most here - self._blocking_queue_capacity = 2 * len(self._places) + self._blocking_queue_capacity = 1 * len(self._places) self._init_thread() + self._shutdown = False + + global _loader + _loader = self def _init_thread(self): self._var_names = [v.name for v in self._feed_list] @@ -151,22 +195,35 @@ def _init_thread(self): self._thread.start() def _thread_loop(self, legacy_expected_place): - try: - #NOTE(zhiqiu): Set the expected place for new thread as the same as father thread, - # and it will call platform::SetDeviceId() in c++ internally. - # If we do not set cudaDeviceId in new thread, the default cudaDeviceId will be 0, - # Which may cost hundreds of MB of GPU memory on CUDAPlace(0) if calling some cuda - # APIs in this thread. - _set_expected_place(legacy_expected_place) - - for indices in self._sampler_iter: + #NOTE(zhiqiu): Set the expected place for new thread as the same as father thread, + # and it will call platform::SetDeviceId() in c++ internally. + # If we do not set cudaDeviceId in new thread, the default cudaDeviceId will be 0, + # Which may cost hundreds of MB of GPU memory on CUDAPlace(0) if calling some cuda + # APIs in this thread. 
+ _set_expected_place(legacy_expected_place) + + while not self._thread_done_event.is_set(): + try: + indices = next(self._sampler_iter) + # read data from dataset in mini-batch - batch = self._dataset_fetcher.fetch(indices) + # with paddle.fluid.dygraph.guard(place=paddle.CPUPlace()): + # read data from dataset in mini-batch + batch = self._dataset_fetcher.fetch(indices, + self._thread_done_event) + except StopIteration: + self._exit_thread_expectedly() + return + + if batch is None or self._thread_done_event.is_set(): break + + # flat batch and record structure infos + batch, structure = _flatten_batch(batch) + self._structure_infos.append(structure) - # flat batch and record structure infos - batch, structure = _flatten_batch(batch) - self._structure_infos.append(structure) + if self._thread_done_event.is_set(): break + try: # pack as LoDTensorArray array = core.LoDTensorArray() for slot in batch: @@ -179,21 +236,18 @@ def _thread_loop(self, legacy_expected_place): array.append(slot) - if not self._blocking_queue.push(array): - break + if self._thread_done_event.is_set(): break - if self._thread_done_event.is_set(): - break + try: + self._blocking_queue.push(array) + except: + self._exit_thread_expectedly() - self._blocking_queue.close() - self._shutdown_thread() - except StopIteration: - self._blocking_queue.close() - except Exception: - self._blocking_queue.kill() - self._shutdown_thread() - logging.warning("DataLoader reader thread raised an exception.") - six.reraise(*sys.exc_info()) + except: + self._exit_thread_unexpectedly() + six.reraise(*sys.exc_info()) + + self._exit_thread_expectedly() def __next__(self): try: @@ -221,28 +275,46 @@ def __next__(self): return data except StopIteration: self._reader.shutdown() + self._try_shutdown_all() six.reraise(*sys.exc_info()) def _shutdown_thread(self): if self._thread: self._thread_done_event.set() - if self._thread is not threading.current_thread(): - self._thread.join() + # NOTE: we wait for _thread exit for 3 seconds, if + # thread not exit normally, force kill it + for _ in range(3): + if self._thread.is_alive(): + time.sleep(1) + else: + break + else: + if self._thread is not threading.current_thread(): + self._thread.join() + self._thread = None # python2 compatibility def next(self): return self.__next__() + def _try_shutdown_all(self): + if not self._shutdown: + try: + # # _blocking_queue in keep order mode holds sub-threads + # # need to release thread resources on unexpected exit + if self._blocking_queue: + self._blocking_queue.close() + self._blocking_queue = None + # NOTE: blocking queue should be closed firstly for + # blocking queue read may hang and _thread_done_event + # cannot be checked + self._shutdown_thread() + finally: + self._shutdown = True + def __del__(self): - # _blocking_queue in keep order mode holds sub-threads - # need to release thread resources on unexpected exit - if self._blocking_queue: - self._blocking_queue.close() - # NOTE: blocking queue should be closed firstly for - # blocking queue read may hang and _thread_done_event - # cannot be checked - self._shutdown_thread() + self._try_shutdown_all() class _DataLoaderIterMultiProcess(_DataLoaderIterBase): @@ -421,15 +493,6 @@ def _try_shutdown_all(self, timeout=None): core._erase_process_pids(id(self)) self._shutdown = True - def _exit_thread_expectedly(self): - self._thread_done_event.set() - self._blocking_queue.close() - - def _exit_thread_unexpectedly(self): - self._thread_done_event.set() - self._blocking_queue.kill() - logging.error("DataLoader 
reader thread raised an exception!") - def _thread_loop(self, legacy_expected_place): #NOTE(zhiqiu): Set the expected place for new thread as the same as father thread, # and it will call platform::SetDeviceId() in c++ internally. diff --git a/python/paddle/fluid/dataloader/fetcher.py b/python/paddle/fluid/dataloader/fetcher.py index 8ccec81810a0a..ec3240a326b8e 100644 --- a/python/paddle/fluid/dataloader/fetcher.py +++ b/python/paddle/fluid/dataloader/fetcher.py @@ -26,7 +26,16 @@ def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last): self.collate_fn = collate_fn self.drop_last = drop_last - def fetch(self, batch_indices): + # NOTE: fetch function here perform the whole pipeline of dataset + # reading and data trasforms of a batch in each calling, this + # may take a long time inside, if DataLoader is exit outside, + # fetch need to perceive exit situation, so we pass done_event + # here for fetch to check exit status + # NOTE: if DataLoadet exit by `break`, performing GPU tensor operations, + # e.g. to_tensor may cause SIGSEGV in thread, so we pass the + # done_event argument to check DataLoader exit status between + # ecah sample processing in the batch + def fetch(self, batch_indices, done_event=None): raise NotImplementedError("'fetch' not implement for class {}".format( self.__class__.__name__)) @@ -69,15 +78,18 @@ def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last): dataset, auto_collate_batch, collate_fn, drop_last) self.dataset_iter = iter(dataset) - def fetch(self, batch_indices): + def fetch(self, batch_indices, done_event=None): if self.auto_collate_batch: data = [] for _ in batch_indices: - try: - data.append(next(self.dataset_iter)) - except StopIteration: - break + if done_event is None or not done_event.is_set(): + try: + data.append(next(self.dataset_iter)) + except StopIteration: + break + else: + return None if len(data) == 0 or (self.drop_last and len(data) < len(batch_indices)): @@ -101,9 +113,14 @@ def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last): super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch, collate_fn, drop_last) - def fetch(self, batch_indices): + def fetch(self, batch_indices, done_event=None): if self.auto_collate_batch: - data = [self.dataset[idx] for idx in batch_indices] + data = [] + for idx in batch_indices: + if done_event is None or not done_event.is_set(): + data.append(self.dataset[idx]) + else: + return None global _WARNING_TO_LOG if not isinstance(data[0], (Sequence, Mapping)) \ diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py index 7fed27ee45978..a246474e21e20 100644 --- a/python/paddle/fluid/device_worker.py +++ b/python/paddle/fluid/device_worker.py @@ -91,7 +91,10 @@ def _gen_worker_desc(self, trainer_desc): trainer_desc.device_worker_name = "HogwildWorker" if self._infer: # just ignore feed op for inference model - trainer_desc.hogwild_param.skip_ops.extend(["feed"]) + trainer_desc.hogwild_param.skip_ops.extend([ + "feed", "push_sparse", "push_sparse_v2", "push_dense", + "distributed_push_sparse", "send" + ]) dense_table_set = set() program_id = str(id(self._program)) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index ed351dcbefdbc..709b36ed8e32b 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4381,7 +4381,7 @@ def _create_vars(self, block, ori_block): persistable=source_var.persistable) else: dest_var = block._clone_variable(source_var, False) - 
dest_var.stop_gradient = source_var.stop_gradient + self._clone_var_attr(dest_var, source_var) # When use with sharding, allreduce_sum and allreduce_max # used for global gradient clip and amp will be added by sharding. op_idx += 1 @@ -4547,9 +4547,14 @@ def _create_var(self, block, ref_var, name, dtype=None): persistable=ref_var.persistable, is_data=ref_var.is_data, need_check_feed=ref_var.desc.need_check_feed()) - new_var.stop_gradient = ref_var.stop_gradient + self._clone_var_attr(new_var, ref_var) return new_var + def _clone_var_attr(self, dest, src): + dest.stop_gradient = src.stop_gradient + if hasattr(src, 'is_distributed'): + dest.is_distributed = src.is_distributed + def _strip_grad_suffix(self, name): """ Strip the grad suffix from the given variable name @@ -5209,6 +5214,8 @@ def _insert_accumulate_gradients_with_fuse(self, main_block, fp16, persistable=True, stop_gradient=False) real_param = main_block.var(param) + if hasattr(real_param, 'is_distributed'): + merged_grad_var.is_distributed = real_param.is_distributed tmp_size = self._get_var_size(real_grad) # two strategies for splitting the grad # 1. the current segment's size reach the user defined grad_size_in_MB diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_gather_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_gather_op.py index fec15ea7295a0..57c295686f63d 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_gather_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_gather_op.py @@ -23,47 +23,78 @@ from paddle.fluid.core import AnalysisConfig -class TRTGatherTest(InferencePassTest): +class TRTGatherTest1(InferencePassTest): def setUp(self): self.set_params() with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data(name='data', shape=[-1, 512], dtype='float32') - index = fluid.data(name='index', shape=[-1], dtype='int32') - scale_out = self.append_gather(data, index) - out = fluid.layers.batch_norm(scale_out, is_test=True) - - index = np.arange(self.num_gather, dtype='int32') - np.random.shuffle(index) + data = fluid.data(name='data', shape=[-1, 128], dtype='float32') + index = fluid.data(name='index', shape=[-1, 1], dtype='int32') + scale_out = fluid.layers.gather(data, index=index) + out = fluid.layers.softmax(input=scale_out) self.feeds = { - "data": np.random.random([self.bs, 512]).astype("float32"), - "index": index, + "data": np.random.random([self.bs, 128]).astype("float32"), + "index": self.index } self.enable_trt = True - self.trt_parameters = TRTGatherTest.TensorRTParam( + self.trt_parameters = TRTGatherTest1.TensorRTParam( 1 << 30, self.bs, 1, AnalysisConfig.Precision.Float32, False, False) + self.dynamic_shape_params = TRTGatherTest1.DynamicShapeParam({ + 'data': [1, 1], + 'index': [1, 1] + }, {'data': [32, 128], + 'index': [3, 1]}, {'data': [32, 128], + 'index': [3, 1]}, False) self.fetch_list = [out] def set_params(self): - self.num_gather = 16 - self.bs = 32 - - def append_gather(self, data, index): - return fluid.layers.gather(data, index=index) + self.index = np.array([[1], [2], [3]], dtype='int32') + self.bs = 4 def test_check_output(self): if core.is_compiled_with_cuda(): use_gpu = True - self.check_output_with_option(use_gpu, flatten=True) + self.check_output_with_option(use_gpu, flatten=False) self.assertTrue( PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) -class TRTGatherTest1(TRTGatherTest): +class TRTGatherTest2(InferencePassTest): + def setUp(self): + self.set_params() + with 
fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data(name='data', shape=[16, 64], dtype='float32') + index = fluid.data(name='index', shape=[2], dtype='int32') + scale_out = fluid.layers.gather(data, index=index) + out = fluid.layers.softmax(input=scale_out) + + self.feeds = { + "data": np.random.random([self.bs, 64]).astype("float32"), + "index": self.index + } + + self.enable_trt = True + self.trt_parameters = TRTGatherTest2.TensorRTParam( + 1 << 30, self.bs, 1, AnalysisConfig.Precision.Float32, False, False) + self.dynamic_shape_params = TRTGatherTest2.DynamicShapeParam({ + 'data': [2, 4], + 'index': [1] + }, {'data': [256, 256], + 'index': [4]}, {'data': [64, 32], + 'index': [2]}, False) + self.fetch_list = [out] + def set_params(self): - self.num_gather = 32 - self.bs = 32 + self.index = np.array([1, 4], dtype='int32') + self.bs = 16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu, flatten=False) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py b/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py index 8c5c3e9219da5..6957a4ceb26de 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py @@ -68,7 +68,7 @@ def __init__(self, methodName='runTest'): max_batch_size=4, min_subgraph_size=0, precision=paddle_infer.PrecisionType.Float32, - use_static=True, + use_static=False, use_calib_mode=False) self.dynamic_shape = self.DynamicShapeParam({}, {}, {}, False) self.num_percent_cases = float( @@ -109,7 +109,9 @@ def assert_tensors_near(self, for key, arr in tensor.items(): self.assertTrue( baseline[key].shape == arr.shape, - "The output shape of GPU and TensorRT are not equal.") + "The output shape of GPU and TensorRT are not equal, the baseline shape is " + + str(baseline[key].shape) + ', but the trt shape is ' + + str(arr.shape)) self.assertTrue( np.allclose( baseline[key], arr, atol=atol, rtol=rtol), @@ -259,9 +261,9 @@ def run_test(self, quant=False): if not skip_flag: self.assert_op_size(nodes_num[0], nodes_num[1]) # deserialize test - if nodes_num[0] > 0: - self.run_test_config(model, params, prog_config, - pred_config_deserialize, feed_data) + #if nodes_num[0] > 0: + # self.run_test_config(model, params, prog_config, + # pred_config_deserialize, feed_data) except Exception as e: self.fail_log( str(prog_config) + ' vs ' + self.inference_config_str( diff --git a/python/paddle/fluid/tests/unittests/npu/test_index_select_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_index_select_op_npu.py index ff0d57d1d4da1..57293ad5e5633 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_index_select_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_index_select_op_npu.py @@ -35,7 +35,10 @@ def setUp(self): x_np = np.random.random(self.x_shape).astype(self.x_type) index_np = np.random.randint( - low=0, high=self.x_shape[self.dim], size=self.index_size) + low=0, + high=self.x_shape[self.dim], + size=self.index_size, + dtype=self.index_type) # compute real output as baseline. outer_loop = np.prod(self.x_shape[:self.dim]) @@ -56,18 +59,14 @@ def setUp(self): self.attrs = {'dim': self.dim} self.outputs = {'Out': out} - # todo: comment second line when index_select grad npu op is ready. 
def set_npu(self): self.__class__.use_npu = True - self.__class__.no_need_check_grad = True def test_check_output(self): self.check_output_with_place(self.place) - # todo: replace first line with second line when index_select grad npu op is ready. def test_check_grad(self): - pass - #self.check_grad_with_place(self.place, ['X'], 'Out') + self.check_grad_with_place(self.place, ['X'], 'Out') def config(self): self.x_shape = (100, 4, 5) @@ -86,6 +85,24 @@ def config(self): self.index_size = 10 +class TestNPUIndexSelectCase3(TestNPUIndexSelect): + def config(self): + self.dim = 0 + self.x_type = np.float32 + self.index_type = np.int32 + self.x_shape = (10, 10, 4, 10) + self.index_size = 10 + + +class TestNPUIndexSelectCase4(TestNPUIndexSelect): + def config(self): + self.dim = -1 + self.x_type = np.float32 + self.index_type = np.int32 + self.x_shape = (10, 10, 4, 10) + self.index_size = 10 + + class TestNPUIndexSelectAPI(unittest.TestCase): def input_data(self): self.data_x = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], diff --git a/python/paddle/fluid/tests/unittests/test_dataloader_dataset.py b/python/paddle/fluid/tests/unittests/test_dataloader_dataset.py index d2f4eadc9c564..c54a1406e39bf 100644 --- a/python/paddle/fluid/tests/unittests/test_dataloader_dataset.py +++ b/python/paddle/fluid/tests/unittests/test_dataloader_dataset.py @@ -43,14 +43,18 @@ def test_main(self): class TestDatasetWithDiffOutputPlace(unittest.TestCase): def get_dataloader(self, num_workers): dataset = paddle.vision.datasets.MNIST( - mode='test', transform=transforms.ToTensor()) + mode='test', + transform=transforms.Compose([ + transforms.CenterCrop(20), transforms.RandomResizedCrop(14), + transforms.Normalize(), transforms.ToTensor() + ])) loader = paddle.io.DataLoader( dataset, batch_size=32, num_workers=num_workers, shuffle=True) return loader def run_check_on_cpu(self): paddle.set_device('cpu') - loader = self.get_dataloader(0) + loader = self.get_dataloader(1) for image, label in loader: self.assertTrue(image.place.is_cpu_place()) self.assertTrue(label.place.is_cpu_place()) @@ -66,12 +70,7 @@ def test_single_process(self): for image, label in loader: self.assertTrue(image.place.is_gpu_place()) self.assertTrue(label.place.is_cuda_pinned_place()) - # FIXME(dkp): when input tensor is in GPU place and - # iteration break in the median, it seems the GPU - # tensor put into blocking_queue cannot be safely - # released and may cause ABRT/SEGV, this should - # be fixed - # break + break def test_multi_process(self): # DataLoader with multi-process mode is not supported on MacOs and Windows currently diff --git a/python/paddle/fluid/tests/unittests/test_eigh_op.py b/python/paddle/fluid/tests/unittests/test_eigh_op.py new file mode 100644 index 0000000000000..e434364702525 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eigh_op.py @@ -0,0 +1,199 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +import paddle +from op_test import OpTest +from gradient_checker import grad_check + + +class TestEighOp(OpTest): + def setUp(self): + paddle.enable_static() + self.op_type = "eigh" + self.init_input() + self.init_config() + np.random.seed(123) + out_w, out_v = np.linalg.eigh(self.x_np, self.UPLO) + self.inputs = {"X": self.x_np} + self.attrs = {"UPLO": self.UPLO} + self.outputs = {'Eigenvalues': out_w, "Eigenvectors": out_v} + + def init_config(self): + self.UPLO = 'L' + + def init_input(self): + self.x_shape = (10, 10) + self.x_type = np.float64 + self.x_np = np.random.random(self.x_shape).astype(self.x_type) + + def test_check_output(self): + self.check_output(no_check_set=['Eigenvectors']) + + def test_grad(self): + self.check_grad(["X"], ["Eigenvalues"]) + + +class TestEighUPLOCase(TestEighOp): + def init_config(self): + self.UPLO = 'U' + + +class TestEighGPUCase(unittest.TestCase): + def setUp(self): + self.x_shape = [32, 32] + self.dtype = "float32" + np.random.seed(123) + self.x_np = np.random.random(self.x_shape).astype(self.dtype) + self.rtol = 1e-5 + self.atol = 1e-5 + + def test_check_output_gpu(self): + if paddle.is_compiled_with_cuda(): + paddle.disable_static(place=paddle.CUDAPlace(0)) + input_real_data = paddle.to_tensor(self.x_np) + expected_w, expected_v = np.linalg.eigh(self.x_np) + actual_w, actual_v = paddle.linalg.eigh(input_real_data) + np.testing.assert_allclose( + actual_w, expected_w, rtol=self.rtol, atol=self.atol) + np.testing.assert_allclose( + abs(actual_v.numpy()), + abs(expected_v), + rtol=self.rtol, + atol=self.atol) + + +class TestEighAPI(unittest.TestCase): + def setUp(self): + self.init_input_shape() + self.dtype = "float32" + self.UPLO = 'L' + self.rtol = 1e-6 + self.atol = 1e-6 + self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \ + else paddle.CPUPlace() + np.random.seed(123) + self.real_data = np.random.random(self.x_shape).astype(self.dtype) + self.complex_data = np.random.random(self.x_shape).astype( + self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype) + self.trans_dims = list(range(len(self.x_shape) - 2)) + [ + len(self.x_shape) - 1, len(self.x_shape) - 2 + ] + + def init_input_shape(self): + self.x_shape = [5, 5] + + def compare_result(self, actual_w, actual_v, expected_w, expected_v): + np.testing.assert_allclose( + actual_w, expected_w, rtol=self.rtol, atol=self.atol) + np.testing.assert_allclose( + abs(actual_v), abs(expected_v), rtol=self.rtol, atol=self.atol) + + def check_static_float_result(self): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): + input_x = paddle.static.data( + 'input_x', shape=self.x_shape, dtype=self.dtype) + output_w, output_v = paddle.linalg.eigh(input_x) + exe = paddle.static.Executor(self.place) + expected_w, expected_v = exe.run(main_prog, + feed={"input_x": self.real_data}, + fetch_list=[output_w, output_v]) + + actual_w, actual_v = np.linalg.eigh(self.real_data) + self.compare_result(actual_w, actual_v, expected_w, expected_v) + + def check_static_complex_result(self): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): + x_dtype = np.complex64 if self.dtype == "float32" else np.complex128 + input_x = paddle.static.data( + 'input_x', shape=self.x_shape, dtype=x_dtype) + output_w, output_v = paddle.linalg.eigh(input_x) + exe 
= paddle.static.Executor(self.place) + expected_w, expected_v = exe.run( + main_prog, + feed={"input_x": self.complex_data}, + fetch_list=[output_w, output_v]) + actual_w, actual_v = np.linalg.eigh(self.complex_data) + self.compare_result(actual_w, actual_v, expected_w, expected_v) + + def test_in_static_mode(self): + paddle.enable_static() + self.check_static_float_result() + self.check_static_complex_result() + + def test_in_dynamic_mode(self): + paddle.disable_static(self.place) + input_real_data = paddle.to_tensor(self.real_data) + expected_w, expected_v = np.linalg.eigh(self.real_data) + actual_w, actual_v = paddle.linalg.eigh(input_real_data) + self.compare_result(actual_w, actual_v.numpy(), expected_w, expected_v) + + input_complex_data = paddle.to_tensor(self.complex_data) + expected_w, expected_v = np.linalg.eigh(self.complex_data) + actual_w, actual_v = paddle.linalg.eigh(input_complex_data) + self.compare_result(actual_w, actual_v.numpy(), expected_w, expected_v) + + def test_eigh_grad(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.complex_data, stop_gradient=False) + w, v = paddle.linalg.eigh(x) + (w.sum() + paddle.abs(v).sum()).backward() + np.testing.assert_allclose( + abs(x.grad.numpy()), + abs(x.grad.numpy().conj().transpose(self.trans_dims)), + rtol=self.rtol, + atol=self.atol) + + +class TestEighBatchAPI(TestEighAPI): + def init_input_shape(self): + self.x_shape = [2, 5, 5] + + +class TestEighAPIError(unittest.TestCase): + def test_error(self): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): + #input maxtrix must greater than 2 dimensions + input_x = paddle.static.data( + name='x_1', shape=[12], dtype='float32') + self.assertRaises(ValueError, paddle.linalg.eigh, input_x) + + #input matrix must be square matrix + input_x = paddle.static.data( + name='x_2', shape=[12, 32], dtype='float32') + self.assertRaises(ValueError, paddle.linalg.eigh, input_x) + + #uplo must be in 'L' or 'U' + input_x = paddle.static.data( + name='x_3', shape=[4, 4], dtype="float32") + uplo = 'R' + self.assertRaises(ValueError, paddle.linalg.eigh, input_x, uplo) + + #x_data cannot be integer + input_x = paddle.static.data( + name='x_4', shape=[4, 4], dtype="int32") + self.assertRaises(TypeError, paddle.linalg.eigh, input_x) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py index 1dd368f0848c1..6b0a7b79c232c 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py @@ -658,6 +658,33 @@ def test_hybrid_with_mp_pp_amp_gclip(self): 'c_gen_nccl_id', 'c_comm_init' ]) + self.assertEqual(main_prog_op_types, [ + 'partial_recv', 'partial_allgather', 'cast', 'cast', 'mul', 'cast', + 'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast', + 'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast', + 'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast', + 'elementwise_add', 'softmax', 'cast', 'cross_entropy2', 'mean', + 'elementwise_mul', 'fill_constant', 'elementwise_mul_grad', + 'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad', + 'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast', + 'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast', + 'elementwise_add_grad', 
'mul_grad', 'cast', 'tanh_grad', 'cast', + 'elementwise_add_grad', 'mul_grad', 'cast', 'c_sync_calc_stream', + 'partial_send', 'fill_constant', 'cast', 'sum', 'fill_constant', + 'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant', + 'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant', + 'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant', + 'cast', 'sum', 'c_sync_comm_stream', 'check_finite_and_unscale', + 'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast', + 'update_loss_scaling', 'fill_constant', 'c_allreduce_sum', + 'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max', + 'elementwise_div', 'elementwise_mul', 'elementwise_mul', + 'elementwise_mul', 'elementwise_mul', 'elementwise_mul', + 'elementwise_mul', 'elementwise_mul', 'elementwise_mul', 'momentum', + 'momentum', 'momentum', 'momentum', 'momentum', 'momentum', + 'momentum', 'momentum' + ]) + # pp + mp, partial send recv self.assertIn('partial_recv', main_prog_op_types) self.assertIn('partial_allgather', main_prog_op_types) diff --git a/python/paddle/fluid/tests/unittests/test_linalg_cond.py b/python/paddle/fluid/tests/unittests/test_linalg_cond.py new file mode 100644 index 0000000000000..2b42eca38e6fc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_linalg_cond.py @@ -0,0 +1,160 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.static as static + +p_list_n_n = ("fro", "nuc", 1, -1, np.inf, -np.inf) +p_list_m_n = (None, 2, -2) + + +def test_static_assert_true(self, x_list, p_list): + for p in p_list: + for x in x_list: + with static.program_guard(static.Program(), static.Program()): + input_data = static.data("X", shape=x.shape, dtype=x.dtype) + output = paddle.cond(input_data, p) + exe = static.Executor() + result = exe.run(feed={"X": x}, fetch_list=[output]) + expected_output = np.linalg.cond(x, p) + self.assertTrue(np.allclose(result, expected_output)) + + +def test_dygraph_assert_true(self, x_list, p_list): + for p in p_list: + for x in x_list: + input_tensor = paddle.to_tensor(x) + output = paddle.cond(input_tensor, p) + expected_output = np.linalg.cond(x, p) + self.assertTrue(np.allclose(output, expected_output)) + + +def gen_input(): + # generate square matrix or batches of square matrices + input_1 = np.random.rand(5, 5).astype('float32') + input_2 = np.random.rand(3, 6, 6).astype('float64') + input_3 = np.random.rand(2, 4, 3, 3).astype('float32') + + # generate non-square matrix or batches of non-square matrices + input_4 = np.random.rand(9, 7).astype('float64') + input_5 = np.random.rand(4, 2, 10).astype('float32') + input_6 = np.random.rand(3, 5, 4, 1).astype('float32') + + list_n_n = (input_1, input_2, input_3) + list_m_n = (input_4, input_5, input_6) + return list_n_n, list_m_n + + +def gen_empty_input(): + # generate square matrix or batches of square matrices which are empty tensor + input_1 = np.random.rand(0, 7, 7).astype('float32') + input_2 = np.random.rand(0, 9, 9).astype('float32') + input_3 = np.random.rand(0, 4, 5, 5).astype('float64') + + # generate non-square matrix or batches of non-square matrices which are empty tensor + input_4 = np.random.rand(0, 7, 11).astype('float32') + input_5 = np.random.rand(0, 10, 8).astype('float64') + input_6 = np.random.rand(5, 0, 4, 3).astype('float32') + + list_n_n = (input_1, input_2, input_3) + list_m_n = (input_4, input_5, input_6) + return list_n_n, list_m_n + + +class API_TestStaticCond(unittest.TestCase): + def test_out(self): + paddle.enable_static() + # test calling results of 'cond' in static mode + x_list_n_n, x_list_m_n = gen_input() + test_static_assert_true(self, x_list_n_n, p_list_n_n + p_list_m_n) + test_static_assert_true(self, x_list_m_n, p_list_m_n) + + +class API_TestDygraphCond(unittest.TestCase): + def test_out(self): + paddle.disable_static() + # test calling results of 'cond' in dynamic mode + x_list_n_n, x_list_m_n = gen_input() + test_dygraph_assert_true(self, x_list_n_n, p_list_n_n + p_list_m_n) + test_dygraph_assert_true(self, x_list_m_n, p_list_m_n) + + +class TestCondAPIError(unittest.TestCase): + def test_dygraph_api_error(self): + paddle.disable_static() + # test raising errors when 'cond' is called in dygraph mode + p_list_error = ('fro_', '_nuc', -0.7, 0, 1.5, 3) + x_list_n_n, x_list_m_n = gen_input() + for p in p_list_error: + for x in (x_list_n_n + x_list_m_n): + x_tensor = paddle.to_tensor(x) + self.assertRaises(ValueError, paddle.cond, x_tensor, p) + + for p in p_list_n_n: + for x in x_list_m_n: + x_tensor = paddle.to_tensor(x) + self.assertRaises(ValueError, paddle.cond, x_tensor, p) + + def test_static_api_error(self): + paddle.enable_static() + # test raising errors when 'cond' is called in static mode + p_list_error = ('f ro', 'fre', 'NUC', -1.6, 0, 5) + x_list_n_n, x_list_m_n = gen_input() + for p in 
p_list_error: + for x in (x_list_n_n + x_list_m_n): + with static.program_guard(static.Program(), static.Program()): + x_data = static.data("X", shape=x.shape, dtype=x.dtype) + self.assertRaises(ValueError, paddle.cond, x_data, p) + + for p in p_list_n_n: + for x in x_list_m_n: + with static.program_guard(static.Program(), static.Program()): + x_data = static.data("X", shape=x.shape, dtype=x.dtype) + self.assertRaises(ValueError, paddle.cond, x_data, p) + + # it's not supported when input is an empty tensor in static mode + def test_static_empty_input_error(self): + paddle.enable_static() + + x_list_n_n, x_list_m_n = gen_empty_input() + for p in (p_list_n_n + p_list_m_n): + for x in x_list_n_n: + with static.program_guard(static.Program(), static.Program()): + x_data = static.data("X", shape=x.shape, dtype=x.dtype) + self.assertRaises(ValueError, paddle.cond, x_data, p) + + for p in (p_list_n_n + p_list_m_n): + for x in x_list_n_n: + with static.program_guard(static.Program(), static.Program()): + x_data = static.data("X", shape=x.shape, dtype=x.dtype) + self.assertRaises(ValueError, paddle.cond, x_data, p) + + +class TestCondEmptyTensorInput(unittest.TestCase): + def test_dygraph_empty_tensor_input(self): + paddle.disable_static() + # test calling results of 'cond' when input is an empty tensor in dynamic mode + x_list_n_n, x_list_m_n = gen_empty_input() + test_dygraph_assert_true(self, x_list_n_n, p_list_n_n + p_list_m_n) + test_dygraph_assert_true(self, x_list_m_n, p_list_m_n) + + +if __name__ == "__main__": + paddle.enable_static() + # paddle.device.set_device("cpu") + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_segment_ops.py b/python/paddle/fluid/tests/unittests/test_segment_ops.py index b58d66676b055..e2aadbedbd07f 100644 --- a/python/paddle/fluid/tests/unittests/test_segment_ops.py +++ b/python/paddle/fluid/tests/unittests/test_segment_ops.py @@ -15,8 +15,11 @@ from __future__ import print_function import unittest -import numpy as np import sys + +import numpy as np +import paddle + from op_test import OpTest @@ -198,5 +201,62 @@ def prepare(self): self.attrs = {'pooltype': "MEAN"} +class API_SegmentOpsTest(unittest.TestCase): + def test_static(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(name="x", shape=[3, 3], dtype="float32") + y = paddle.static.data(name='y', shape=[3], dtype='int32') + + res_sum = paddle.incubate.segment_sum(x, y) + res_mean = paddle.incubate.segment_mean(x, y) + res_max = paddle.incubate.segment_max(x, y) + res_min = paddle.incubate.segment_min(x, y) + + exe = paddle.static.Executor(paddle.CPUPlace()) + data1 = np.array([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32') + data2 = np.array([0, 0, 1], dtype="int32") + + np_sum = np.array([[4, 4, 4], [4, 5, 6]], dtype="float32") + np_mean = np.array([[2, 2, 2], [4, 5, 6]], dtype="float32") + np_max = np.array([[3, 2, 3], [4, 5, 6]], dtype="float32") + np_min = np.array([[1, 2, 1], [4, 5, 6]], dtype="float32") + + ret = exe.run(feed={'x': data1, + 'y': data2}, + fetch_list=[res_sum, res_mean, res_max, res_min]) + + for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret): + self.assertTrue( + np.allclose( + np_res, ret_res, atol=1e-6), + "two value is\ + {}\n{}, check diff!".format(np_res, ret_res)) + + def test_dygraph(self): + device = paddle.CPUPlace() + with paddle.fluid.dygraph.guard(device): + x = paddle.to_tensor( + [[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32') + y = paddle.to_tensor([0, 0, 1], dtype="int32") + 
res_sum = paddle.incubate.segment_sum(x, y) + res_mean = paddle.incubate.segment_mean(x, y) + res_max = paddle.incubate.segment_max(x, y) + res_min = paddle.incubate.segment_min(x, y) + + np_sum = np.array([[4, 4, 4], [4, 5, 6]], dtype="float32") + np_mean = np.array([[2, 2, 2], [4, 5, 6]], dtype="float32") + np_max = np.array([[3, 2, 3], [4, 5, 6]], dtype="float32") + np_min = np.array([[1, 2, 1], [4, 5, 6]], dtype="float32") + + ret = [res_sum, res_mean, res_max, res_min] + + for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret): + self.assertTrue( + np.allclose( + np_res, ret_res.numpy(), atol=1e-6), + "two value is\ + {}\n{}, check diff!".format(np_res, ret_res)) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index 584c418675726..fd87e7584cea5 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -32,5 +32,6 @@ 'fusion_lstm', 'softmax_with_cross_entropy', 'svd', + 'eigh', 'class_center_sample', ] diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index efaeda272087f..644b934814020 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -18,7 +18,18 @@ from ..fluid.layer_helper import LayerHelper # noqa: F401 from .operators import softmax_mask_fuse_upper_triangle # noqa: F401 from .operators import softmax_mask_fuse # noqa: F401 +from .tensor import segment_sum +from .tensor import segment_mean +from .tensor import segment_max +from .tensor import segment_min -__all__ = [ # noqa - 'LookAhead', 'ModelAverage', 'softmax_mask_fuse_upper_triangle', 'softmax_mask_fuse' +__all__ = [ + 'LookAhead', + 'ModelAverage', + 'softmax_mask_fuse_upper_triangle', + 'softmax_mask_fuse', + 'segment_sum', + 'segment_mean', + 'segment_max', + 'segment_min', ] diff --git a/python/paddle/incubate/tensor/__init__.py b/python/paddle/incubate/tensor/__init__.py new file mode 100644 index 0000000000000..ea1018409ab0f --- /dev/null +++ b/python/paddle/incubate/tensor/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .math import segment_sum +from .math import segment_mean +from .math import segment_max +from .math import segment_min + +__all__ = [ + 'segment_sum', + 'segment_mean', + 'segment_max', + 'segment_min', +] diff --git a/python/paddle/incubate/tensor/math.py b/python/paddle/incubate/tensor/math.py new file mode 100644 index 0000000000000..f3cb8d50514f0 --- /dev/null +++ b/python/paddle/incubate/tensor/math.py @@ -0,0 +1,225 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = [
+    'segment_sum',
+    'segment_mean',
+    'segment_max',
+    'segment_min',
+]
+
+import paddle
+
+from paddle.fluid.layer_helper import LayerHelper, in_dygraph_mode
+from paddle.fluid.data_feeder import check_variable_and_dtype
+from paddle import _C_ops
+
+
+def segment_sum(data, segment_ids, name=None):
+    """
+    Segment Sum Operator.
+
+    This operator sums the elements of input `data` which have
+    the same index in `segment_ids`.
+    It computes a tensor such that $out_i = \\sum_{j} data_{j}$
+    where sum is over j such that `segment_ids[j] == i`.
+
+    Args:
+        data (Tensor): A tensor, available data type float32, float64.
+        segment_ids (Tensor): A 1-D tensor, which has the same size
+                            as the first dimension of input data.
+                            Available data type is int32, int64.
+    Returns:
+        output (Tensor): the reduced result.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
+            segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')
+            out = paddle.incubate.segment_sum(data, segment_ids)
+            #Outputs: [[4., 4., 4.], [4., 5., 6.]]
+
+    """
+    if in_dygraph_mode():
+        out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM")
+        return out
+
+    check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
+    check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
+                             "segment_pool")
+
+    helper = LayerHelper("segment_sum", **locals())
+    out = helper.create_variable_for_type_inference(dtype=data.dtype)
+    summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
+    helper.append_op(
+        type="segment_pool",
+        inputs={"X": data,
+                "SegmentIds": segment_ids},
+        outputs={"Out": out,
+                 "SummedIds": summed_ids},
+        attrs={"pooltype": "SUM"})
+    return out
+
+
+def segment_mean(data, segment_ids, name=None):
+    """
+    Segment Mean Operator.
+
+    This operator calculates the mean value of the elements of input `data`
+    which have the same index in `segment_ids`.
+    It computes a tensor such that $out_i = \\frac{1}{n_i} \\sum_{j} data[j]$
+    where sum is over j such that 'segment_ids[j] == i' and $n_i$ is the number
+    of indices for which 'segment_ids[j] == i'.
+
+    Args:
+        data (Tensor): A tensor, available data type float32, float64.
+        segment_ids (Tensor): A 1-D tensor, which has the same size
+                            as the first dimension of input data.
+                            Available data type is int32, int64.
+
+    Returns:
+        output (Tensor): the reduced result.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
+            segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')
+            out = paddle.incubate.segment_mean(data, segment_ids)
+            #Outputs: [[2., 2., 2.], [4., 5., 6.]]
+
+    """
+    if in_dygraph_mode():
+        out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN")
+        return out
+
+    check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
+    check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
+                             "segment_pool")
+
+    helper = LayerHelper("segment_mean", **locals())
+    out = helper.create_variable_for_type_inference(dtype=data.dtype)
+    summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
+    helper.append_op(
+        type="segment_pool",
+        inputs={"X": data,
+                "SegmentIds": segment_ids},
+        outputs={"Out": out,
+                 "SummedIds": summed_ids},
+        attrs={"pooltype": "MEAN"})
+    return out
+
+
+def segment_min(data, segment_ids, name=None):
+    """
+    Segment Min Operator.
+
+    This operator calculates the minimum of the elements of input `data`
+    which have the same index in `segment_ids`.
+    It computes a tensor such that $out_i = \\min_{j} data_{j}$
+    where min is over j such that `segment_ids[j] == i`.
+
+    Args:
+        data (Tensor): A tensor, available data type float32, float64.
+        segment_ids (Tensor): A 1-D tensor, which has the same size
+                            as the first dimension of input data.
+                            Available data type is int32, int64.
+    Returns:
+        output (Tensor): the reduced result.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32')
+            segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32')
+            out = paddle.incubate.segment_min(data, segment_ids)
+            #Outputs: [[1., 2., 1.], [4., 5., 6.]]
+
+    """
+    if in_dygraph_mode():
+        out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN")
+        return out
+
+    check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
+    check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
+                             "segment_pool")
+
+    helper = LayerHelper("segment_min", **locals())
+    out = helper.create_variable_for_type_inference(dtype=data.dtype)
+    summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype)
+    helper.append_op(
+        type="segment_pool",
+        inputs={"X": data,
+                "SegmentIds": segment_ids},
+        outputs={"Out": out,
+                 "SummedIds": summed_ids},
+        attrs={"pooltype": "MIN"})
+    return out
+
+
+def segment_max(data, segment_ids, name=None):
+    """
+    Segment Max Operator.
+
+    This operator calculates the maximum of the elements of input `data`
+    which have the same index in `segment_ids`.
+    It computes a tensor such that $out_i = \\max_{j} data_{j}$
+    where max is over j such that `segment_ids[j] == i`.
+
+    Args:
+        data (Tensor): A tensor, available data type float32, float64.
+        segment_ids (Tensor): A 1-D tensor, which has the same size
+                            as the first dimension of input data.
+                            Available data type is int32, int64.
+
+    Returns:
+        output (Tensor): the reduced result.
+
+    Examples:
+
+        .. 
code-block:: python + + import paddle + data = paddle.to_tensor([[1, 2, 3], [3, 2, 1], [4, 5, 6]], dtype='float32') + segment_ids = paddle.to_tensor([0, 0, 1], dtype='int32') + out = paddle.incubate.segment_max(data, segment_ids) + #Outputs: [[3., 2., 3.], [4., 5., 6.]] + + """ + if in_dygraph_mode(): + out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX") + return out + + check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") + check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), + "segment_pool") + + helper = LayerHelper("segment_max", **locals()) + out = helper.create_variable_for_type_inference(dtype=data.dtype) + summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) + helper.append_op( + type="segment_pool", + inputs={"X": data, + "SegmentIds": segment_ids}, + outputs={"Out": out, + "SummedIds": summed_ids}, + attrs={"pooltype": "MAX"}) + return out diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py index 27dc2595bfb29..74d015b86b5c9 100644 --- a/python/paddle/linalg.py +++ b/python/paddle/linalg.py @@ -14,18 +14,22 @@ from .tensor.linalg import cholesky # noqa: F401 from .tensor.linalg import norm # noqa: F401 +from .tensor.linalg import cond # noqa: F401 from .tensor.linalg import matrix_power # noqa: F401 from .tensor import inverse as inv # noqa: F401 from .tensor.linalg import multi_dot # noqa: F401 from .tensor.linalg import matrix_rank from .tensor.linalg import svd +from .tensor.linalg import eigh # noqa: F401 __all__ = [ 'cholesky', #noqa 'norm', + 'cond', 'inv', 'multi_dot', 'matrix_rank', 'svd', - 'matrix_power' + 'matrix_power', + 'eigh' ] diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 8b8601191b4d8..a8897c567c36a 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -36,6 +36,7 @@ from .linalg import matmul # noqa: F401 from .linalg import dot # noqa: F401 from .linalg import norm # noqa: F401 +from .linalg import cond # noqa: F401 from .linalg import transpose # noqa: F401 from .linalg import dist # noqa: F401 from .linalg import t # noqa: F401 @@ -47,6 +48,7 @@ from .linalg import matrix_power # noqa: F401 from .linalg import multi_dot # noqa: F401 from .linalg import svd # noqa: F401 +from .linalg import eigh # noqa: F401 from .logic import equal # noqa: F401 from .logic import greater_equal # noqa: F401 from .logic import greater_than # noqa: F401 @@ -220,6 +222,7 @@ 'matmul', 'dot', 'norm', + 'cond', 'transpose', 'dist', 't', diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 96a3610b1894f..fa9da0e579d56 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -543,6 +543,323 @@ def dist(x, y, p=2): return out +def cond(x, p=None, name=None): + """ + + Computes the condition number of a matrix or batches of matrices with respect to a matrix norm ``p``. + + Args: + x (Tensor): The input tensor could be tensor of shape ``(*, m, n)`` where ``*`` is zero or more batch dimensions + for ``p`` in ``(2, -2)``, or of shape ``(*, n, n)`` where every matrix is invertible for any supported ``p``. + And the input data type could be ``float32`` or ``float64``. + p (float|string, optional): Order of the norm. Supported values are `fro`, `nuc`, `1`, `-1`, `2`, `-2`, + `inf`, `-inf`. Default value is `None`, meaning that the order of the norm is `2`. + name (str, optional): The default value is `None`. Normally there is no need for + user to set this property. 
For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: computing results of condition number, its data type is the same as input Tensor ``x``. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + x = paddle.to_tensor([[1., 0, -1], [0, 1, 0], [1, 0, 1]]) + + # compute conditional number when p is None + out = paddle.linalg.cond(x) + # out.numpy() [1.4142135] + + # compute conditional number when order of the norm is 'fro' + out_fro = paddle.linalg.cond(x, p='fro') + # out_fro.numpy() [3.1622777] + + # compute conditional number when order of the norm is 'nuc' + out_nuc = paddle.linalg.cond(x, p='nuc') + # out_nuc.numpy() [9.2426405] + + # compute conditional number when order of the norm is 1 + out_1 = paddle.linalg.cond(x, p=1) + # out_1.numpy() [2.] + + # compute conditional number when order of the norm is -1 + out_minus_1 = paddle.linalg.cond(x, p=-1) + # out_minus_1.numpy() [1.] + + # compute conditional number when order of the norm is 2 + out_2 = paddle.linalg.cond(x, p=2) + # out_2.numpy() [1.4142135] + + # compute conditional number when order of the norm is -1 + out_minus_2 = paddle.linalg.cond(x, p=-2) + # out_minus_2.numpy() [0.70710677] + + # compute conditional number when order of the norm is inf + out_inf = paddle.linalg.cond(x, p=np.inf) + # out_inf.numpy() [2.] + + # compute conditional number when order of the norm is -inf + out_minus_inf = paddle.linalg.cond(x, p=-np.inf) + # out_minus_inf.numpy() [1.] + + a = paddle.to_tensor(np.random.randn(2, 4, 4).astype('float32')) + # a.numpy() + # [[[ 0.14063153 -0.996288 0.7996131 -0.02571543] + # [-0.16303636 1.5534962 -0.49919784 -0.04402903] + # [-1.1341571 -0.6022629 0.5445269 0.29154757] + # [-0.16816919 -0.30972657 1.7521842 -0.5402487 ]] + # [[-0.58081484 0.12402827 0.7229862 -0.55046535] + # [-0.15178485 -1.1604939 0.75810957 0.30971205] + # [-0.9669573 1.0940945 -0.27363303 -0.35416734] + # [-1.216529 2.0018666 -0.7773689 -0.17556527]]] + a_cond_fro = paddle.linalg.cond(a, p='fro') + # a_cond_fro.numpy() [31.572273 28.120834] + + b = paddle.to_tensor(np.random.randn(2, 3, 4).astype('float64')) + # b.numpy() + # [[[ 1.61707487 0.46829144 0.38130416 0.82546736] + # [-1.72710298 0.08866375 -0.62518804 0.16128892] + # [-0.02822879 -1.67764516 0.11141444 0.3220113 ]] + # [[ 0.22524372 0.62474921 -0.85503233 -1.03960523] + # [-0.76620689 0.56673047 0.85064753 -0.45158196] + # [ 1.47595418 2.23646462 1.5701758 0.10497519]]] + b_cond_2 = paddle.linalg.cond(b, p=2) + # b_cond_2.numpy() [3.30064451 2.51976252] + + """ + + def mat_norm(input, porder=1., axis=None): + """ + NOTE: + Calculate the matrix norm of a square matrix or batches of square matrices, + when porder is in (1, -1, inf, -inf) + """ + reduce_all = True if axis is None or axis == [] else False + axis = axis if axis != None and axis != [] else [0] + keepdim = False + + if in_dygraph_mode(): + abs_out = _C_ops.abs(input) + sum_out = _C_ops.reduce_sum(abs_out, 'dim', axis, 'keepdim', + keepdim, 'reduce_all', reduce_all) + if porder == 1 or porder == np.inf: + return _C_ops.reduce_max(sum_out, 'dim', [-1], 'keepdim', + keepdim, 'reduce_all', reduce_all) + if porder == -1 or porder == -np.inf: + return _C_ops.reduce_min(sum_out, 'dim', [-1], 'keepdim', + keepdim, 'reduce_all', reduce_all) + + block = LayerHelper('norm', **locals()) + abs_out = block.create_variable_for_type_inference( + dtype=block.input_dtype()) + sum_out = block.create_variable_for_type_inference( + dtype=block.input_dtype()) + out = 
block.create_variable_for_type_inference( + dtype=block.input_dtype()) + block.append_op( + type='abs', inputs={'X': input}, outputs={'Out': abs_out}) + block.append_op( + type='reduce_sum', + inputs={'X': abs_out}, + outputs={'Out': sum_out}, + attrs={'dim': axis, + 'keep_dim': keepdim, + 'reduce_all': reduce_all}) + if porder == 1 or porder == np.inf: + block.append_op( + type='reduce_max', + inputs={'X': sum_out}, + outputs={'Out': out}, + attrs={ + 'dim': [-1], + 'keep_dim': keepdim, + 'reduce_all': reduce_all + }) + if porder == -1 or porder == -np.inf: + block.append_op( + type='reduce_min', + inputs={'X': sum_out}, + outputs={'Out': out}, + attrs={ + 'dim': [-1], + 'keep_dim': keepdim, + 'reduce_all': reduce_all + }) + return out + + def fro_norm(input, porder=2, axis=[-1]): + """ + NOTE: + Calculate the frobenius norm of a square matrix or batches of square matrices. + """ + reduce_all = True if axis is None or axis == [] else False + keepdim = False + + if in_dygraph_mode(): + pow_out = _C_ops.pow(input, 'factor', porder) + sum_out_1 = _C_ops.reduce_sum(pow_out, 'dim', axis, 'keepdim', + keepdim, 'reduce_all', reduce_all) + sum_out_2 = _C_ops.reduce_sum(sum_out_1, 'dim', axis, 'keepdim', + keepdim, 'reduce_all', reduce_all) + return _C_ops.pow(sum_out_2, 'factor', float(1. / porder)) + + block = LayerHelper('norm', **locals()) + pow_out = block.create_variable_for_type_inference( + dtype=block.input_dtype()) + sum_out_1 = block.create_variable_for_type_inference( + dtype=block.input_dtype()) + sum_out_2 = block.create_variable_for_type_inference( + dtype=block.input_dtype()) + out = block.create_variable_for_type_inference( + dtype=block.input_dtype()) + block.append_op( + type='pow', + inputs={'X': input}, + outputs={'Out': pow_out}, + attrs={'factor': porder}) + block.append_op( + type='reduce_sum', + inputs={'X': pow_out}, + outputs={'Out': sum_out_1}, + attrs={'dim': axis, + 'keep_dim': keepdim, + 'reduce_all': reduce_all}) + block.append_op( + type='reduce_sum', + inputs={'X': sum_out_1}, + outputs={'Out': sum_out_2}, + attrs={'dim': axis, + 'keep_dim': keepdim, + 'reduce_all': reduce_all}) + block.append_op( + type='pow', + inputs={'X': sum_out_2}, + outputs={'Out': out}, + attrs={'factor': float(1. / porder)}) + return out + + def svd_norm(input, porder, axis=[-1]): + """ + NOTE: + Calculate the matrix norm, which is related to singular values, of a matrix + or batches of matrices, including nuclear norm, 2-norm and (-2)-norm. 
+ """ + reduce_all = True if axis is None or axis == [] else False + keepdim = False + + u, s, vh = svd(input, full_matrices=False) + + if in_dygraph_mode(): + if porder == "nuc": + return _C_ops.reduce_sum(s, 'dim', axis, 'keepdim', keepdim, + 'reduce_all', reduce_all) + max_out = _C_ops.reduce_max(s, 'dim', axis, 'keepdim', keepdim, + 'reduce_all', reduce_all) + min_out = _C_ops.reduce_min(s, 'dim', axis, 'keepdim', keepdim, + 'reduce_all', reduce_all) + if porder == 2: + return _C_ops.elementwise_div(max_out, min_out, 'aixs', axis, + 'use_mkldnn', False) + if porder == -2: + return _C_ops.elementwise_div(min_out, max_out, 'aixs', axis, + 'use_mkldnn', False) + + block = LayerHelper('norm', **locals()) + out = block.create_variable_for_type_inference( + dtype=block.input_dtype()) + if porder == "nuc": + block.append_op( + type='reduce_sum', + inputs={'X': s}, + outputs={'Out': out}, + attrs={ + 'dim': axis, + 'keep_dim': keepdim, + 'reduce_all': reduce_all + }) + return out + max_out = block.create_variable_for_type_inference( + dtype=block.input_dtype()) + min_out = block.create_variable_for_type_inference( + dtype=block.input_dtype()) + block.append_op( + type='reduce_max', + inputs={'X': s}, + outputs={'Out': max_out}, + attrs={'dim': axis, + 'keep_dim': keepdim, + 'reduce_all': reduce_all}) + block.append_op( + type='reduce_min', + inputs={'X': s}, + outputs={'Out': min_out}, + attrs={'dim': axis, + 'keep_dim': keepdim, + 'reduce_all': reduce_all}) + if porder == 2: + block.append_op( + type='elementwise_div', + inputs={'X': max_out, + 'Y': min_out}, + outputs={'Out': out}, + attrs={'aixs': axis, + 'use_mkldnn': False}) + return out + if porder == -2: + block.append_op( + type='elementwise_div', + inputs={'X': min_out, + 'Y': max_out}, + outputs={'Out': out}, + attrs={'aixs': axis, + 'use_mkldnn': False}) + return out + + def empty_tensor(input, shape): + if in_dygraph_mode(): + return input.reshape(shape) + raise ValueError("only support x is nonempty tensor in static mode") + + x_shape = list(x.shape) + if not len(x_shape) >= 2: + raise ValueError("input should be a matrix or batches of matrices, " + + "but the dimention of received input is {}".format( + len(x_shape))) + if p == None: + p = 2 + x_size = 0 if (0 in x_shape) else 1 + if p in ("fro", "nuc", 1, -1, np.inf, -np.inf): + if x_shape[len(x_shape) - 1] == x_shape[len(x_shape) - 2]: + if x_size == 0: + return empty_tensor(x, x_shape[:-2]) + x_inv = x.inverse() + if p == "fro": + return fro_norm(x) * fro_norm(x_inv) + if p == "nuc": + return svd_norm(x, p) * svd_norm(x_inv, p) + if p in (1, -1): + return mat_norm( + x, porder=p, axis=[-2]) * mat_norm( + x_inv, porder=p, axis=[-2]) + if p in (np.inf, -np.inf): + return mat_norm( + x, porder=p, axis=[-1]) * mat_norm( + x_inv, porder=p, axis=[-1]) + else: + raise ValueError("only support p is {} when input is a ".format(p) + + "square matrix or batches of square matrices") + elif p in (2, -2): + if x_size == 0: + return empty_tensor(x, x_shape[:-2]) + return svd_norm(x, porder=p) + else: + raise ValueError( + "unsupported {} for p, only supporting ('fro', 'nuc', ".format( + p) + "1, -1, 2, -2, inf, -inf) or none") + + def dot(x, y, name=None): """ This operator calculates inner product for vectors. @@ -1106,7 +1423,7 @@ def svd(x, full_matrices=False, name=None): def matrix_power(x, n, name=None): r""" Computes the n-th power of a square matrix or a batch of square matrices. 
- + Let :math:`X` be a sqaure matrix or a batch of square matrices, :math:`n` be an exponent, the equation should be: @@ -1251,3 +1568,72 @@ def multi_dot(x, name=None): out = helper.create_variable_for_type_inference(dtype) helper.append_op(type='multi_dot', inputs={"X": x}, outputs={"Out": out}) return out + + +def eigh(x, UPLO='L', name=None): + """ + Compute the eigenvalues and eigenvectors of a + complex Hermitian (conjugate symmetric) or a real symmetric matrix. + + Args: + x (Tensor): A tensor with shape :math:`[*, N, N]` , The data type of the input Tensor x + should be one of float32, float64, complex64, complex128. + UPLO(str, optional): (string, default 'L'), 'L' represents the lower triangular matrix, + "'U' represents the upper triangular matrix.". + name(str, optional): The default value is None. Normally there is no need for user to set this + property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + + out_value(Tensor): A Tensor with shape [*, N] and data type of float32 and float64. The eigenvalues of eigh op. + out_vector(Tensor): A Tensor with shape [*, N, N] and data type of float32,float64,complex64 and complex128. The eigenvectors of eigh op. + + Examples: + .. code-block:: python + + import numpy as np + import paddle + + x_data = np.array([[1, -2j], [2j, 5]]) + x = paddle.to_tensor(x_data) + out_value, out_vector = paddle.eigh(x, UPLO='L') + print(out_value) + #[0.17157288, 5.82842712] + print(out_vector) + #[(-0.9238795325112867+0j), (-0.3826834323650898+0j)], + #[ 0.3826834323650898j , -0.9238795325112867j ]] + + """ + if in_dygraph_mode(): + return _C_ops.eigh(x, 'UPLO', UPLO) + + def __check_input(x, UPLO): + x_shape = list(x.shape) + if len(x.shape) < 2: + raise ValueError( + "Input(input) only support >=2 tensor, but received " + "length of Input(input) is %s." % len(x.shape)) + if x_shape[-1] != x_shape[-2]: + raise ValueError( + "The input matrix must be batches of square matrices. But received x's dimention: {}". + format(x_shape)) + if UPLO is not 'L' and UPLO is not 'U': + raise ValueError( + "UPLO must be L or U. But received UPLO is: {}".format(UPLO)) + + __check_input(x, UPLO) + + helper = LayerHelper('eigh', **locals()) + check_variable_and_dtype( + x, 'dtype', ['float32', 'float64', 'complex64', 'complex128'], 'eigh') + + out_value = helper.create_variable_for_type_inference(dtype=x.dtype) + out_vector = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type='eigh', + inputs={'X': x}, + outputs={'Eigenvalues': out_value, + 'Eigenvectors': out_vector}, + attrs={'UPLO': UPLO}) + return out_value, out_vector diff --git a/python/paddle/vision/models/vgg.py b/python/paddle/vision/models/vgg.py index d526de8208329..755f77aa2971a 100644 --- a/python/paddle/vision/models/vgg.py +++ b/python/paddle/vision/models/vgg.py @@ -21,7 +21,9 @@ model_urls = { 'vgg16': ('https://paddle-hapi.bj.bcebos.com/models/vgg16.pdparams', - '89bbffc0f87d260be9b8cdc169c991c4') + '89bbffc0f87d260be9b8cdc169c991c4'), + 'vgg19': ('https://paddle-hapi.bj.bcebos.com/models/vgg19.pdparams', + '23b18bb13d8894f60f54e642be79a0dd') } diff --git a/python/setup.py.in b/python/setup.py.in index 6d3e6201dc772..1b2897f230fbe 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -162,6 +162,7 @@ packages=['paddle', 'paddle.incubate.optimizer', 'paddle.incubate.checkpoint', 'paddle.incubate.operators', + 'paddle.incubate.tensor', 'paddle.distributed.fleet', 'paddle.distributed.fleet.base', 'paddle.distributed.fleet.elastic',