From 02a6d49a9ab5135cf765e53575790fa8984403a4 Mon Sep 17 00:00:00 2001 From: zyfncg <1370305206@qq.com> Date: Tue, 15 Jun 2021 10:43:10 +0800 Subject: [PATCH 01/18] Add digamma_op and unittest (#33278) * Add digamma_op and unittest * add digamma_op api * remove special DigammaCudaKernel and correct some docs * remove unused headers * fix api doc error --- paddle/fluid/operators/digamma_op.cc | 100 +++++++++++++++ paddle/fluid/operators/digamma_op.cu | 26 ++++ paddle/fluid/operators/digamma_op.h | 99 +++++++++++++++ python/paddle/__init__.py | 2 + .../fluid/tests/unittests/test_digamma_op.py | 119 ++++++++++++++++++ python/paddle/tensor/__init__.py | 4 +- python/paddle/tensor/math.py | 36 ++++++ 7 files changed, 385 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/digamma_op.cc create mode 100644 paddle/fluid/operators/digamma_op.cu create mode 100644 paddle/fluid/operators/digamma_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_digamma_op.py diff --git a/paddle/fluid/operators/digamma_op.cc b/paddle/fluid/operators/digamma_op.cc new file mode 100644 index 0000000000000..b1a58817e0604 --- /dev/null +++ b/paddle/fluid/operators/digamma_op.cc @@ -0,0 +1,100 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/digamma_op.h" + +namespace paddle { +namespace operators { + +class DigammaOp : public framework::OperatorWithKernel { + public: + DigammaOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Digamma"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Digamma"); + + auto in_dims = ctx->GetInputDim("X"); + + ctx->SetOutputDim("Out", in_dims); + ctx->ShareLoD("X", "Out"); + } +}; + +class DigammaOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of digamma operator."); + AddOutput("Out", "(Tensor), The output tensor of digamma operator."); + AddComment(R"DOC( +Digamma Operator. + +This operator is used to perform elementwise digamma for input $X$. +$$out = \Psi(x) = \frac{ \Gamma^{'}(x) }{ \Gamma(x) }$$ + +)DOC"); + } +}; + +class DigammaGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@Grad", "DigammaGrad"); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DigammaGrad"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + "X@Grad", "DigammaGrad"); + + auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); + ctx->SetOutputDim(framework::GradVarName("X"), dout_dims); + ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X")); + } +}; + +template +class DigammaGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr retv) const override { + retv->SetType("digamma_grad"); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetInput("X", this->Input("X")); + retv->SetAttrMap(this->Attrs()); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(digamma, ops::DigammaOp, ops::DigammaOpMaker, + ops::DigammaGradOpMaker, + ops::DigammaGradOpMaker); +REGISTER_OPERATOR(digamma_grad, ops::DigammaGradOp); + +REGISTER_OP_CPU_KERNEL( + digamma, ops::DigammaKernel, + ops::DigammaKernel); + +REGISTER_OP_CPU_KERNEL( + digamma_grad, + ops::DigammaGradKernel, + ops::DigammaGradKernel); diff --git a/paddle/fluid/operators/digamma_op.cu b/paddle/fluid/operators/digamma_op.cu new file mode 100644 index 0000000000000..5f2f59ba520d0 --- /dev/null +++ b/paddle/fluid/operators/digamma_op.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/digamma_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + digamma, ops::DigammaKernel, + ops::DigammaKernel); + +REGISTER_OP_CUDA_KERNEL( + digamma_grad, + ops::DigammaGradKernel, + ops::DigammaGradKernel); diff --git a/paddle/fluid/operators/digamma_op.h b/paddle/fluid/operators/digamma_op.h new file mode 100644 index 0000000000000..f82628f020480 --- /dev/null +++ b/paddle/fluid/operators/digamma_op.h @@ -0,0 +1,99 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { + +template +struct DigammaFunctor { + DigammaFunctor(const T* input, T* output, int64_t numel) + : input_(input), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + output_[idx] = Eigen::numext::digamma(input_[idx]); + } + + private: + const T* input_; + T* output_; + int64_t numel_; +}; + +template +struct DigammaGradFunctor { + DigammaGradFunctor(const T* dout, const T* x, T* output, int64_t numel) + : dout_(dout), x_(x), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + output_[idx] = dout_[idx] * Eigen::numext::polygamma(T(1), x_[idx]); + } + + private: + const T* dout_; + const T* x_; + T* output_; + int64_t numel_; +}; + +using Tensor = framework::Tensor; + +template +class DigammaKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* x = context.Input("X"); + Tensor* out = context.Output("Out"); + + auto numel = x->numel(); + auto* x_data = x->data(); + auto* out_data = out->mutable_data(context.GetPlace(), + size_t(x->numel() * sizeof(T))); + + auto& dev_ctx = context.template device_context(); + platform::ForRange for_range(dev_ctx, numel); + DigammaFunctor functor(x_data, out_data, numel); + for_range(functor); + } +}; + +template +class DigammaGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* d_out = context.Input(framework::GradVarName("Out")); + const Tensor* x = context.Input("X"); + auto* d_x = context.Output(framework::GradVarName("X")); + + auto numel = d_out->numel(); + auto* dout_data = d_out->data(); + auto* x_data = x->data(); + auto* dx_data = d_x->mutable_data( + context.GetPlace(), static_cast(numel * sizeof(T))); + + auto& dev_ctx = context.template device_context(); + platform::ForRange for_range(dev_ctx, numel); + DigammaGradFunctor functor(dout_data, x_data, dx_data, numel); + for_range(functor); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 3c16f327df4c2..738de4e393d77 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -205,6 +205,7 @@ from .tensor.math import prod # noqa: F401 from .tensor.math import broadcast_shape # noqa: F401 from .tensor.math import conj # noqa: F401 +from .tensor.math import digamma # noqa: F401 from .tensor.math import neg # noqa: F401 from .tensor.math import lgamma # noqa: F401 @@ -489,5 +490,6 @@ 'log10', 'concat', 'check_shape', + 'digamma', 'standard_normal' ] diff --git a/python/paddle/fluid/tests/unittests/test_digamma_op.py b/python/paddle/fluid/tests/unittests/test_digamma_op.py new file mode 100644 index 0000000000000..86f59af19346c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_digamma_op.py @@ -0,0 +1,119 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import math +import numpy as np +from scipy.special import psi +import paddle +import paddle.fluid as fluid +import paddle.static as static +from op_test import OpTest + + +class TestDigammaOp(OpTest): + def setUp(self): + # switch to static + paddle.enable_static() + + self.op_type = 'digamma' + self.init_dtype_type() + shape = (5, 32) + data = np.random.random(shape).astype(self.dtype) + 1 + self.inputs = {'X': data} + result = np.ones(shape).astype(self.dtype) + result = psi(data) + self.outputs = {'Out': result} + + def init_dtype_type(self): + self.dtype = np.float64 + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X'], 'Out') + + +class TestDigammaOpFp32(TestDigammaOp): + def init_dtype_type(self): + self.dtype = np.float32 + + def test_check_grad_normal(self): + self.check_grad(['X'], 'Out') + + +class TestDigammaAPI(unittest.TestCase): + def setUp(self): + # switch to static + paddle.enable_static() + # prepare test attrs + self.dtypes = ["float32", "float64"] + self.places = [paddle.CPUPlace()] + if paddle.is_compiled_with_cuda(): + self.places.append(paddle.CUDAPlace(0)) + self._shape = [8, 3, 32, 32] + + def test_in_static_mode(self): + def init_input_output(dtype): + input = np.random.random(self._shape).astype(dtype) + return {'x': input}, psi(input) + + for dtype in self.dtypes: + input_dict, sc_res = init_input_output(dtype) + for place in self.places: + with static.program_guard(static.Program()): + x = static.data(name="x", shape=self._shape, dtype=dtype) + out = paddle.digamma(x) + + exe = static.Executor(place) + out_value = exe.run(feed=input_dict, fetch_list=[out.name]) + self.assertEqual( + np.allclose( + out_value[0], sc_res, rtol=1e-5), True) + + def test_in_dynamic_mode(self): + for dtype in self.dtypes: + input = np.random.random(self._shape).astype(dtype) + sc_res = psi(input) + for place in self.places: + # it is more convenient to use `guard` than `enable/disable_**` here + with fluid.dygraph.guard(place): + input_t = paddle.to_tensor(input) + res = paddle.digamma(input_t).numpy() + self.assertEqual(np.allclose(res, sc_res, rtol=1e-05), True) + + def test_name_argument(self): + with static.program_guard(static.Program()): + x = static.data(name="x", shape=self._shape, dtype=self.dtypes[0]) + out = paddle.digamma(x, name="digamma_res") + self.assertTrue("digamma_res" in out.name) + + def test_dtype_error(self): + # in static mode + with self.assertRaises(TypeError): + with static.program_guard(static.Program()): + x = static.data(name="x", shape=self._shape, dtype="int32") + out = paddle.digamma(x, name="digamma_res") + + # in dynamic mode + with self.assertRaises(RuntimeError): + with fluid.dygraph.guard(): + input = np.random.random(self._shape).astype("int32") + input_t = paddle.to_tensor(input) + res = paddle.digamma(input_t) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 0b8d2be24f3cb..8c83b1786b01e 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -162,6 +162,7 @@ from .math import any # noqa: F401 from .math import broadcast_shape # noqa: F401 from .math import conj # noqa: F401 +from .math import digamma # noqa: F401 from .math import neg # noqa: F401 from .math import lgamma # noqa: F401 @@ -347,5 +348,6 @@ 'rank', 'shape', 'real', - 'imag' + 'imag', + 'digamma' ] diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 15d0cd0146ab0..a9e24949aae2b 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2283,6 +2283,42 @@ def conj(x, name=None): helper.append_op(type='conj', inputs={'X': x}, outputs={'Out': [out]}) return out +def digamma(x, name=None): + r""" + Calculates the digamma of the given input tensor, element-wise. + + .. math:: + Out = \Psi(x) = \frac{ \Gamma^{'}(x) }{ \Gamma(x) } + + Args: + x (Tensor): Input Tensor. Must be one of the following types: float32, float64. + name(str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name` + Returns: + Tensor, the digamma of the input Tensor, the shape and data type is the same with input. + + Examples: + .. code-block:: python + + import paddle + + data = paddle.to_tensor([[1, 1.5], [0, -2.2]], dtype='float32') + res = paddle.digamma(data) + print(res) + # Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [[-0.57721591, 0.03648996], + # [ nan , 5.32286835]]) + """ + + if in_dygraph_mode(): + return core.ops.digamma(x) + + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'digamma') + helper = LayerHelper('digamma', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op(type='digamma', inputs={'X': x}, outputs={'Out': out}) + return out + def neg(x, name=None): """ This function computes the negative of the Tensor elementwisely. From 606939de76af62afc1d4170b6b2e53e4ba743a74 Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Tue, 15 Jun 2021 10:56:55 +0800 Subject: [PATCH 02/18] Support reduce_sum_op float16 (#32966) * add reduce_sum_op by add self-kernel * set all ReduceKernel MPType for accuracy * add float16 test script which input is integer number * solve reduce sum float16 check_grad problem * solve conflict and change test script for CI * change kernel register for CI * remove all useless template --- paddle/fluid/operators/kron_op.h | 14 +- paddle/fluid/operators/matmul_v2_op.h | 12 +- paddle/fluid/operators/pool_op.h | 6 +- .../fluid/operators/reduce_ops/cub_reduce.h | 167 +++++++++++++----- .../operators/reduce_ops/reduce_sum_op.cc | 3 + .../operators/reduce_ops/reduce_sum_op.cu | 14 +- .../reduce_ops/reduce_sum_op.part.cu | 1 + paddle/fluid/operators/trace_op.cu | 10 +- python/paddle/fluid/layers/nn.py | 3 +- .../fluid/tests/unittests/test_reduce_op.py | 50 ++++++ 10 files changed, 208 insertions(+), 72 deletions(-) diff --git a/paddle/fluid/operators/kron_op.h b/paddle/fluid/operators/kron_op.h index 6c3bad4e1bdcd..ea2050fe8e61e 100644 --- a/paddle/fluid/operators/kron_op.h +++ b/paddle/fluid/operators/kron_op.h @@ -237,11 +237,13 @@ struct KronGradElemFunctor> { const int ndims_; }; -template struct IdentityFunctor { HOSTDEVICE explicit inline IdentityFunctor() {} - HOSTDEVICE inline T operator()(const T& x) const { return x; } + template + HOSTDEVICE inline U operator()(const U& x) const { + return x; + } }; template @@ -312,13 +314,13 @@ struct KronGradOpFunctor { #if defined(__NVCC__) || defined(__HIPCC__) auto stream = dev_ctx.stream(); // it is a cuda device_context if (dx) { - TensorReduce>( - dout_x, dx, {1}, static_cast(0), cub::Sum(), IdentityFunctor(), + TensorReduce( + dout_x, dx, {1}, static_cast(0), cub::Sum(), IdentityFunctor(), stream); } if (dy) { - TensorReduce>( - dout_y, dy, {1}, static_cast(0), cub::Sum(), IdentityFunctor(), + TensorReduce( + dout_y, dy, {1}, static_cast(0), cub::Sum(), IdentityFunctor(), stream); } #else diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index 6061679b28893..5b114f381996e 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -34,11 +34,13 @@ namespace operators { using framework::Tensor; -template struct IdentityFunctor { HOSTDEVICE explicit inline IdentityFunctor() {} - HOSTDEVICE inline T operator()(const T& x) const { return x; } + template + HOSTDEVICE inline U operator()(const U& x) const { + return x; + } }; template @@ -47,9 +49,9 @@ void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output, const paddle::framework::ExecutionContext& ctx) { #if defined(__NVCC__) || defined(__HIPCC__) auto stream = ctx.cuda_device_context().stream(); - TensorReduce>( - *input, output, reduce_dims, static_cast(0), cub::Sum(), - IdentityFunctor(), stream); + TensorReduce(*input, output, reduce_dims, + static_cast(0), cub::Sum(), + IdentityFunctor(), stream); #else ReduceKernelFunctor( input, output, reduce_dims, true, false, ctx) diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h index 9117b1b95ed26..e84c92d9a1624 100644 --- a/paddle/fluid/operators/pool_op.h +++ b/paddle/fluid/operators/pool_op.h @@ -31,7 +31,11 @@ namespace operators { template struct DivideFunctor { HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {} - HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; } + + template + HOSTDEVICE inline U operator()(const U& x) const { + return x * static_cast(n_inv); + } private: T n_inv; diff --git a/paddle/fluid/operators/reduce_ops/cub_reduce.h b/paddle/fluid/operators/reduce_ops/cub_reduce.h index 9e1aed5dde4b6..0aab680e13dc1 100644 --- a/paddle/fluid/operators/reduce_ops/cub_reduce.h +++ b/paddle/fluid/operators/reduce_ops/cub_reduce.h @@ -31,6 +31,7 @@ namespace cub = hipcub; #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/amp/fp16_type_traits.h" namespace paddle { namespace operators { @@ -66,39 +67,66 @@ struct Array { T data_[ElementCount]; }; +// reduce the 1d array to one element +template +__global__ void ReduceKernel1D(const Tx* x, Ty* y, ReduceOp reducer, + TransformOp transformer, MPType init, + int reduce_num) { + int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + MPType local_data = init; + for (int i = thread_id; i < reduce_num; i += gridDim.x * blockDim.x) { + local_data = static_cast( + reducer(local_data, static_cast(transformer(x[i])))); + } + __syncthreads(); + + local_data = BlockReduce(temp_storage).Reduce(local_data, reducer); + + if (threadIdx.x == 0) { + y[blockIdx.x] = static_cast(local_data); + } +} + // reduce the last axis of 2d array -template +template __global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer, - TransformOp transformer, Ty init, + TransformOp transformer, MPType init, int reduce_num) { - __shared__ typename cub::BlockReduce::TempStorage temp_storage; + __shared__ + typename cub::BlockReduce::TempStorage temp_storage; int idx_x = blockIdx.x * reduce_num; int idx_y = threadIdx.x; - Ty reduce_var = init; + MPType reduce_var = init; for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim) reduce_var = - reducer(reduce_var, static_cast(transformer(x[idx_x + idx_y]))); + reducer(reduce_var, static_cast(transformer(x[idx_x + idx_y]))); __syncthreads(); - reduce_var = - cub::BlockReduce(temp_storage).Reduce(reduce_var, reducer); + reduce_var = cub::BlockReduce(temp_storage) + .Reduce(reduce_var, reducer); if (threadIdx.x == 0) { - y[blockIdx.x] = reduce_var; + y[blockIdx.x] = static_cast(reduce_var); } } -template +template __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer, - TransformOp transformer, Ty init, int reduce_num, - Array x_strides, + TransformOp transformer, MPType init, + int reduce_num, Array x_strides, Array reduce_dim, Array reduce_strides, Array left_dim, Array left_strides) { - __shared__ typename cub::BlockReduce::TempStorage temp_storage; + __shared__ + typename cub::BlockReduce::TempStorage temp_storage; Array sub_index; int left_idx = blockIdx.x; for (int i = 0; i < Rank - ReduceRank; ++i) { @@ -114,7 +142,7 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer, int idx_x = 0; for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]); - Ty reduce_var = static_cast(transformer(x[idx_x])); + MPType reduce_var = static_cast(transformer(x[idx_x])); for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) { int reduce_idx = i; @@ -125,16 +153,16 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer, int idx_x = 0; for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]); - reduce_var = static_cast( - reducer(reduce_var, static_cast(transformer(x[idx_x])))); + reduce_var = static_cast( + reducer(reduce_var, static_cast(transformer(x[idx_x])))); } __syncthreads(); - reduce_var = - cub::BlockReduce(temp_storage).Reduce(reduce_var, reducer); + reduce_var = cub::BlockReduce(temp_storage) + .Reduce(reduce_var, reducer); if (threadIdx.x == 0) { - y[blockIdx.x] = reduce_var; + y[blockIdx.x] = static_cast(reduce_var); } } @@ -192,6 +220,53 @@ static inline void CheckReduceRankIsValid(int reduce_rank, int rank) { } } +template +typename std::enable_if::value, + void>::type +LaunchCubReduceKernel(const Tx* x_data, Ty* y_data, + const platform::Place& place, const ReduceOp& reducer, + const TransformOp& transformer, const MPType& init, + int reduce_num, gpuStream_t stream) { + cub::TransformInputIterator trans_x(x_data, + transformer); + size_t temp_storage_bytes = 0; + cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data, + reduce_num, reducer, init, stream); + framework::Tensor tmp; + auto* temp_storage = tmp.mutable_data( + framework::make_ddim({static_cast(temp_storage_bytes)}), place); + cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data, + reduce_num, reducer, init, stream); +} + +template +typename std::enable_if::value, + void>::type +LaunchCubReduceKernel(const Tx* x_data, Ty* y_data, + const platform::Place& place, const ReduceOp& reducer, + const TransformOp& transformer, const MPType& init, + int reduce_num, gpuStream_t stream) { + int element_per_block = BlockDim * 10; + int block_per_grid = (reduce_num + element_per_block - 1) / element_per_block; + + framework::Tensor tmp; + auto* temp_storage = tmp.mutable_data( + framework::make_ddim( + {static_cast(block_per_grid * sizeof(MPType))}), + place); + + // each block reduce number to interim result + ReduceKernel1D<<>>( + x_data, temp_storage, reducer, transformer, init, reduce_num); + // reduce all number to final result + ReduceKernel1D<<<1, BlockDim, 0, stream>>>( + temp_storage, y_data, reducer, transformer, init, block_per_grid); +} + template static void TensorReduceImpl( @@ -201,45 +276,40 @@ static void TensorReduceImpl( const std::vector& reduce_dim, const std::vector& reduce_strides, const std::vector& left_dim, const std::vector& left_strides, gpuStream_t stream) { + using MPType = typename details::MPTypeTrait::Type; + MPType init_mp = static_cast(init); + #define CUB_RANK_CASE(i, ...) \ case i: { \ constexpr auto kRank = i; \ switch (reduce_rank) { __VA_ARGS__; } \ } break -#define CUB_REDUCE_RANK_CASE(i, ...) \ - case i: { \ - constexpr auto kReduceRank = i; \ - ReduceKernel<<>>( \ - x_data, y_data, reducer, transformer, init, reduce_num, \ - Array::From(x_strides), \ - Array::From(reduce_dim), \ - Array::From(reduce_strides), \ - Array::From(left_dim), \ - Array::From(left_strides)); \ +#define CUB_REDUCE_RANK_CASE(i, ...) \ + case i: { \ + constexpr auto kReduceRank = i; \ + ReduceKernel<<>>( \ + x_data, y_data, reducer, transformer, init_mp, reduce_num, \ + Array::From(x_strides), \ + Array::From(reduce_dim), \ + Array::From(reduce_strides), \ + Array::From(left_dim), \ + Array::From(left_strides)); \ } break int rank = x_strides.size(); int reduce_rank = reduce_strides.size(); if (rank == reduce_rank) { - cub::TransformInputIterator trans_x( - x_data, transformer); - size_t temp_storage_bytes = 0; - cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data, - reduce_num, reducer, init, stream); - framework::Tensor tmp; - auto* temp_storage = tmp.mutable_data( - framework::make_ddim({static_cast(temp_storage_bytes)}), - place); - cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data, - reduce_num, reducer, init, stream); + LaunchCubReduceKernel( + x_data, y_data, place, reducer, transformer, init_mp, reduce_num, + stream); return; } if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) { - ReduceKernel2D<<>>( - x_data, y_data, reducer, transformer, init, reduce_num); + x_data, y_data, reducer, transformer, init_mp, reduce_num); return; } /* @@ -366,8 +436,7 @@ void TensorReduce(const framework::Tensor& x, framework::Tensor* y, #undef CUB_BLOCK_DIM_CASE } -template class TransformOp> +template class TransformOp> struct TensorReduceFunctor { const framework::Tensor& x; framework::Tensor* y; @@ -389,9 +458,9 @@ struct TensorReduceFunctor { void apply() const { const Ty& init_cast = static_cast(init); - TensorReduce>( - x, y, origin_reduce_dims, init_cast, reducer, TransformOp(), - stream); + TensorReduce>(x, y, origin_reduce_dims, + init_cast, reducer, + TransformOp(), stream); } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc index 74e7db649d5ab..9e4cc8e213c61 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc @@ -115,6 +115,8 @@ REGISTER_OP_CPU_KERNEL( ops::SumFunctor>, ops::ReduceKernel, + ops::ReduceKernel, ops::ReduceKernel, ops::ReduceKernel, @@ -133,6 +135,7 @@ using CPUReduceSumGradKernel = REGISTER_OP_CPU_KERNEL( reduce_sum_grad, CPUReduceSumGradKernel, CPUReduceSumGradKernel, CPUReduceSumGradKernel, + CPUReduceSumGradKernel, CPUReduceSumGradKernel, CPUReduceSumGradKernel, CPUReduceSumGradKernel>, CPUReduceSumGradKernel>); diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu index dd16ca4e393e7..efbafe4aa8c3e 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu @@ -18,12 +18,13 @@ namespace paddle { namespace operators { -template +template struct IdentityFunctor { HOSTDEVICE explicit inline IdentityFunctor() {} - HOSTDEVICE inline Ty operator()(const Tx& x) const { - return static_cast(x); + template + HOSTDEVICE inline Tout operator()(const U& x) const { + return static_cast(x); } }; @@ -62,9 +63,9 @@ class ReduceSumKernel : public framework::OpKernel { *input, output, reduce_dims, static_cast(0.0), cub::Sum(), stream)); } else { - TensorReduce>( + TensorReduce>( *input, output, reduce_dims, static_cast(0), cub::Sum(), - IdentityFunctor(), stream); + IdentityFunctor(), stream); } } }; @@ -74,7 +75,8 @@ class ReduceSumKernel : public framework::OpKernel { REGISTER_OP_CUDA_KERNEL( reduce_sum, ops::ReduceSumKernel, ops::ReduceSumKernel, - ops::ReduceSumKernel, ops::ReduceSumKernel, + ops::ReduceSumKernel, + ops::ReduceSumKernel, ops::ReduceSumKernel, ops::ReduceSumKernel, ops::ReduceSumKernel>, ops::ReduceSumKernel>); diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu index 230bae0cdd4b1..419b8ce276526 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu @@ -23,6 +23,7 @@ using CUDAReduceSumGradKernel = REGISTER_OP_CUDA_KERNEL( reduce_sum_grad, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, + CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel>, CUDAReduceSumGradKernel>); diff --git a/paddle/fluid/operators/trace_op.cu b/paddle/fluid/operators/trace_op.cu index 6798521c8f747..336c1c40832b9 100644 --- a/paddle/fluid/operators/trace_op.cu +++ b/paddle/fluid/operators/trace_op.cu @@ -20,11 +20,13 @@ namespace paddle { namespace operators { -template struct IdentityFunctor { HOSTDEVICE explicit inline IdentityFunctor() {} - HOSTDEVICE inline T operator()(const T& x) const { return x; } + template + HOSTDEVICE inline U operator()(const U& x) const { + return x; + } }; template @@ -45,9 +47,9 @@ class TraceCUDAKernel : public framework::OpKernel { auto stream = context.cuda_device_context().stream(); std::vector reduce_dims; reduce_dims.push_back(out->dims().size()); - TensorReduce>( + TensorReduce( diag, out, reduce_dims, static_cast(0), cub::Sum(), - IdentityFunctor(), stream); + IdentityFunctor(), stream); } } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index e02edb72ce1f7..7e50646c0c42d 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4424,7 +4424,8 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None): if dim == None or dim == [] or len(dim) == len(input.shape) else False } check_variable_and_dtype( - input, 'input', ['float32', 'float64', 'int32', 'int64'], 'reduce_sum') + input, 'input', ['float16', 'float32', 'float64', 'int32', 'int64'], + 'reduce_sum') helper = LayerHelper('reduce_sum', **locals()) out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) helper.append_op( diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 912df563fcdbf..2dd5bcb811364 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -37,6 +37,56 @@ def test_check_grad(self): self.check_grad(['X'], 'Out') +class TestSumOp_fp16(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = { + 'X': np.random.uniform(0, 0.1, (5, 6, 10)).astype("float16") + } + self.attrs = {'dim': [0, 1, 2]} + self.outputs = { + 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) + } + self.gradient = self.calc_gradient() + + def test_check_output(self): + self.check_output() + + def calc_gradient(self): + x = self.inputs["X"] + grad = np.ones(x.shape, dtype=x.dtype) + return grad, + + def test_check_grad(self): + self.check_grad(['X'], 'Out', user_defined_grads=self.gradient) + + +class TestSumOp_fp16_withInt(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = { + # ref to https://en.wikipedia.org/wiki/Half-precision_floating-point_format + # Precision limitations on integer values between 0 and 2048 can be exactly represented + 'X': np.random.randint(0, 30, (10, 10)).astype("float16") + } + self.attrs = {'dim': [0, 1]} + self.outputs = { + 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) + } + self.gradient = self.calc_gradient() + + def test_check_output(self): + self.check_output() + + def calc_gradient(self): + x = self.inputs["X"] + grad = np.ones(x.shape, dtype=x.dtype) + return grad, + + def test_check_grad(self): + self.check_grad(['X'], 'Out', user_defined_grads=self.gradient) + + class TestSumOp5D(OpTest): def setUp(self): self.op_type = "reduce_sum" From 1f8de08067e905377a6a637be18b60612b55bf53 Mon Sep 17 00:00:00 2001 From: wawltor Date: Tue, 15 Jun 2021 11:11:40 +0800 Subject: [PATCH 03/18] add the support for the bool in compare ops add the support for the bool in compare ops --- .../operators/controlflow/compare_all_op.cc | 20 ++--- .../operators/controlflow/compare_all_op.cu | 21 +++--- .../fluid/operators/controlflow/compare_op.cu | 1 + .../fluid/operators/controlflow/compare_op.h | 3 + .../fluid/tests/unittests/test_compare_op.py | 32 ++++++++ .../tests/unittests/test_compare_reduce_op.py | 29 +++++++- python/paddle/tensor/logic.py | 74 ++++++++++--------- 7 files changed, 127 insertions(+), 53 deletions(-) diff --git a/paddle/fluid/operators/controlflow/compare_all_op.cc b/paddle/fluid/operators/controlflow/compare_all_op.cc index adacf70f5e145..9442c7583d98f 100644 --- a/paddle/fluid/operators/controlflow/compare_all_op.cc +++ b/paddle/fluid/operators/controlflow/compare_all_op.cc @@ -135,15 +135,17 @@ class CompareReduceOp : public framework::OperatorWithKernel { ::paddle::framework::EmptyGradOpMaker, \ ::paddle::framework::EmptyGradOpMaker); -#define REGISTER_COMPARE_REDUCE_CPU_KERNEL(op_type, functor) \ - REGISTER_OP_CPU_KERNEL( \ - op_type, ::paddle::operators::CompareReduceOpKernel< \ - ::paddle::platform::CPUDeviceContext, functor>, \ - ::paddle::operators::CompareReduceOpKernel< \ - ::paddle::platform::CPUDeviceContext, functor>, \ - ::paddle::operators::CompareReduceOpKernel< \ - ::paddle::platform::CPUDeviceContext, functor>, \ - ::paddle::operators::CompareReduceOpKernel< \ +#define REGISTER_COMPARE_REDUCE_CPU_KERNEL(op_type, functor) \ + REGISTER_OP_CPU_KERNEL( \ + op_type, ::paddle::operators::CompareReduceOpKernel< \ + ::paddle::platform::CPUDeviceContext, functor>, \ + ::paddle::operators::CompareReduceOpKernel< \ + ::paddle::platform::CPUDeviceContext, functor>, \ + ::paddle::operators::CompareReduceOpKernel< \ + ::paddle::platform::CPUDeviceContext, functor>, \ + ::paddle::operators::CompareReduceOpKernel< \ + ::paddle::platform::CPUDeviceContext, functor>, \ + ::paddle::operators::CompareReduceOpKernel< \ ::paddle::platform::CPUDeviceContext, functor>); REGISTER_COMPARE_REDUCE_OP(equal_all, "X == Y"); diff --git a/paddle/fluid/operators/controlflow/compare_all_op.cu b/paddle/fluid/operators/controlflow/compare_all_op.cu index e3c920f78c45b..3753ed6b15f1e 100644 --- a/paddle/fluid/operators/controlflow/compare_all_op.cu +++ b/paddle/fluid/operators/controlflow/compare_all_op.cu @@ -85,15 +85,18 @@ class CompareReduceOpKernel } // namespace operators } // namespace paddle -#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \ - REGISTER_OP_CUDA_KERNEL( \ - op_type, paddle::operators::CompareReduceOpKernel< \ - paddle::platform::CUDADeviceContext, functor>, \ - paddle::operators::CompareReduceOpKernel< \ - paddle::platform::CUDADeviceContext, functor>, \ - paddle::operators::CompareReduceOpKernel< \ - paddle::platform::CUDADeviceContext, functor>, \ - paddle::operators::CompareReduceOpKernel< \ +#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \ + REGISTER_OP_CUDA_KERNEL( \ + op_type, paddle::operators::CompareReduceOpKernel< \ + paddle::platform::CUDADeviceContext, functor>, \ + paddle::operators::CompareReduceOpKernel< \ + paddle::platform::CUDADeviceContext, functor>, \ + paddle::operators::CompareReduceOpKernel< \ + paddle::platform::CUDADeviceContext, functor>, \ + paddle::operators::CompareReduceOpKernel< \ + paddle::platform::CUDADeviceContext, functor>, \ + paddle::operators::CompareReduceOpKernel< \ paddle::platform::CUDADeviceContext, functor>); + REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, paddle::operators::EqualReduceFunctor); diff --git a/paddle/fluid/operators/controlflow/compare_op.cu b/paddle/fluid/operators/controlflow/compare_op.cu index cc0c46adb119a..6f3a615edb44b 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cu +++ b/paddle/fluid/operators/controlflow/compare_op.cu @@ -82,6 +82,7 @@ class CompareOpKernel #define REGISTER_CUDA_COMPARE_KERNEL(op_type, func) \ REGISTER_OP_CUDA_KERNEL( \ op_type, \ + ops::CompareOpKernel, void>, \ ops::CompareOpKernel, void>, \ ops::CompareOpKernel, void>, \ ops::CompareOpKernel, void>, \ diff --git a/paddle/fluid/operators/controlflow/compare_op.h b/paddle/fluid/operators/controlflow/compare_op.h index ff929ee7dfce7..36185322a96b8 100644 --- a/paddle/fluid/operators/controlflow/compare_op.h +++ b/paddle/fluid/operators/controlflow/compare_op.h @@ -98,6 +98,9 @@ class CompareOpKernel #define REGISTER_COMPARE_KERNEL(op_type, dev, functor, inverse_functor) \ REGISTER_OP_##dev##_KERNEL(op_type, \ + ::paddle::operators::CompareOpKernel< \ + ::paddle::platform::dev##DeviceContext, \ + functor, inverse_functor>, \ ::paddle::operators::CompareOpKernel< \ ::paddle::platform::dev##DeviceContext, \ functor, inverse_functor>, \ diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py index a2dd7e49ac4cc..7a14267588022 100644 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -155,6 +155,38 @@ def test_broadcast_api_3(self): fetch_list=[out]) self.assertEqual((res == real_result).all(), True) + def test_bool_api_4(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.static.data(name='x', shape=[3, 1], dtype='bool') + y = paddle.static.data(name='y', shape=[3, 1], dtype='bool') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + input_x = np.array([True, False, True]).astype(np.bool) + input_y = np.array([True, True, False]).astype(np.bool) + real_result = callback(input_x, input_y) + res, = exe.run(feed={"x": input_x, + "y": input_y}, + fetch_list=[out]) + self.assertEqual((res == real_result).all(), True) + + def test_bool_broadcast_api_4(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.static.data(name='x', shape=[3, 1], dtype='bool') + y = paddle.static.data(name='y', shape=[1], dtype='bool') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + input_x = np.array([True, False, True]).astype(np.bool) + input_y = np.array([True]).astype(np.bool) + real_result = callback(input_x, input_y) + res, = exe.run(feed={"x": input_x, + "y": input_y}, + fetch_list=[out]) + self.assertEqual((res == real_result).all(), True) + def test_attr_name(self): paddle.enable_static() with program_guard(Program(), Program()): diff --git a/python/paddle/fluid/tests/unittests/test_compare_reduce_op.py b/python/paddle/fluid/tests/unittests/test_compare_reduce_op.py index 67fe5c81ddc29..056d1687bbf84 100644 --- a/python/paddle/fluid/tests/unittests/test_compare_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_reduce_op.py @@ -92,9 +92,28 @@ def test_output(self): globals()[cls_name] = Cls +def create_test_dim1_class(op_type, typename, callback): + class Cls(op_test.OpTest): + def setUp(self): + x = y = np.random.random(size=(1)).astype(typename) + x = np.array([True, False, True]).astype(typename) + x = np.array([False, False, True]).astype(typename) + z = callback(x, y) + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': z} + self.op_type = op_type + + def test_output(self): + self.check_output() + + cls_name = "{0}_{1}_{2}".format(op_type, typename, 'equal_all') + Cls.__name__ = cls_name + globals()[cls_name] = Cls + + np_equal = lambda _x, _y: np.array(np.array_equal(_x, _y)) -for _type_name in {'float32', 'float64', 'int32', 'int64'}: +for _type_name in {'float32', 'float64', 'int32', 'int64', 'bool'}: create_test_not_equal_class('equal_all', _type_name, np_equal) create_test_equal_class('equal_all', _type_name, np_equal) create_test_dim1_class('equal_all', _type_name, np_equal) @@ -107,6 +126,14 @@ def test_name(self): out = paddle.equal_all(x, y, name='equal_res') assert 'equal_res' in out.name + def test_dynamic_api(self): + paddle.disable_static() + x = paddle.ones(shape=[10, 10], dtype="int32") + y = paddle.ones(shape=[10, 10], dtype="int32") + out = paddle.equal_all(x, y) + assert out.numpy()[0] == True + paddle.enable_static() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index bdf2c477d8658..f948eeb9a48eb 100644 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -38,8 +38,8 @@ def equal_all(x, y, name=None): **NOTICE**: The output of this OP has no gradient. Args: - x(Tensor): Tensor, data type is float32, float64, int32, int64. - y(Tensor): Tensor, data type is float32, float64, int32, int64. + x(Tensor): Tensor, data type is bool, float32, float64, int32, int64. + y(Tensor): Tensor, data type is bool, float32, float64, int32, int64. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -59,6 +59,8 @@ def equal_all(x, y, name=None): result2 = paddle.equal_all(x, z) print(result2) # result2 = [False ] """ + if in_dygraph_mode(): + return core.ops.equal_all(x, y) helper = LayerHelper("equal_all", **locals()) out = helper.create_variable_for_type_inference(dtype='bool') @@ -152,8 +154,8 @@ def equal(x, y, name=None): **NOTICE**: The output of this OP has no gradient. Args: - x(Tensor): Tensor, data type is float32, float64, int32, int64. - y(Tensor): Tensor, data type is float32, float64, int32, int64. + x(Tensor): Tensor, data type is bool, float32, float64, int32, int64. + y(Tensor): Tensor, data type is bool, float32, float64, int32, int64. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -174,10 +176,10 @@ def equal(x, y, name=None): if in_dygraph_mode(): return core.ops.equal(x, y) - check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"], - "equal") - check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"], - "equal") + check_variable_and_dtype( + x, "x", ["bool", "float32", "float64", "int32", "int64"], "equal") + check_variable_and_dtype( + y, "y", ["bool", "float32", "float64", "int32", "int64"], "equal") helper = LayerHelper("equal", **locals()) out = helper.create_variable_for_type_inference(dtype='bool') out.stop_gradient = True @@ -196,8 +198,8 @@ def greater_equal(x, y, name=None): **NOTICE**: The output of this OP has no gradient. Args: - x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. - y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. + x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64. + y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -216,9 +218,11 @@ def greater_equal(x, y, name=None): if in_dygraph_mode(): return core.ops.greater_equal(x, y) - check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"], + check_variable_and_dtype(x, "x", + ["bool", "float32", "float64", "int32", "int64"], "greater_equal") - check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"], + check_variable_and_dtype(y, "y", + ["bool", "float32", "float64", "int32", "int64"], "greater_equal") helper = LayerHelper("greater_equal", **locals()) out = helper.create_variable_for_type_inference(dtype='bool') @@ -240,8 +244,8 @@ def greater_than(x, y, name=None): **NOTICE**: The output of this OP has no gradient. Args: - x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. - y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. + x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64. + y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -260,9 +264,11 @@ def greater_than(x, y, name=None): if in_dygraph_mode(): return core.ops.greater_than(x, y) - check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"], + check_variable_and_dtype(x, "x", + ["bool", "float32", "float64", "int32", "int64"], "greater_than") - check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"], + check_variable_and_dtype(y, "y", + ["bool", "float32", "float64", "int32", "int64"], "greater_than") helper = LayerHelper("greater_than", **locals()) out = helper.create_variable_for_type_inference(dtype='bool') @@ -284,8 +290,8 @@ def less_equal(x, y, name=None): **NOTICE**: The output of this OP has no gradient. Args: - x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. - y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. + x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64. + y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -305,10 +311,10 @@ def less_equal(x, y, name=None): if in_dygraph_mode(): return core.ops.less_equal(x, y) - check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"], - "less_equal") - check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"], - "less_equal") + check_variable_and_dtype( + x, "x", ["bool", "float32", "float64", "int32", "int64"], "less_equal") + check_variable_and_dtype( + y, "y", ["bool", "float32", "float64", "int32", "int64"], "less_equal") helper = LayerHelper("less_equal", **locals()) out = helper.create_variable_for_type_inference(dtype='bool') out.stop_gradient = True @@ -327,8 +333,8 @@ def less_than(x, y, name=None): **NOTICE**: The output of this OP has no gradient. Args: - x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. - y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. + x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64. + y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -348,10 +354,10 @@ def less_than(x, y, name=None): if in_dygraph_mode(): return core.ops.less_than(x, y) - check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"], - "less_than") - check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"], - "less_than") + check_variable_and_dtype( + x, "x", ["bool", "float32", "float64", "int32", "int64"], "less_than") + check_variable_and_dtype( + y, "y", ["bool", "float32", "float64", "int32", "int64"], "less_than") helper = LayerHelper("less_than", **locals()) out = helper.create_variable_for_type_inference(dtype='bool') out.stop_gradient = True @@ -370,8 +376,8 @@ def not_equal(x, y, name=None): **NOTICE**: The output of this OP has no gradient. Args: - x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. - y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. + x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64. + y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -391,10 +397,10 @@ def not_equal(x, y, name=None): if in_dygraph_mode(): return core.ops.not_equal(x, y) - check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"], - "not_equal") - check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"], - "not_equal") + check_variable_and_dtype( + x, "x", ["bool", "float32", "float64", "int32", "int64"], "not_equal") + check_variable_and_dtype( + y, "y", ["bool", "float32", "float64", "int32", "int64"], "not_equal") helper = LayerHelper("not_equal", **locals()) out = helper.create_variable_for_type_inference(dtype='bool') out.stop_gradient = True From c5a6ae4c3f6368053594f49d9bed6956a1fca38c Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Tue, 15 Jun 2021 11:16:49 +0800 Subject: [PATCH 04/18] 1, remove layernorm dynamic fp16; 2, let reshape out in dynamic shape (#33535) * 1, remove layernorm dynamic fp16; 2, let reshape out in dynamic shape * remove useless code --- paddle/fluid/inference/tensorrt/op_teller.cc | 2 +- .../tensorrt/plugin/layer_norm_op_plugin.cu | 62 +------------ .../tensorrt/plugin/layer_norm_op_plugin.h | 31 +------ paddle/fluid/operators/layer_norm_op.cu | 92 ------------------- 4 files changed, 4 insertions(+), 183 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 1bbfba7e419fb..59b196e3d92be 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -694,7 +694,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, return false; // Paddle-TRT does not support the input tensors: Shape and ShapeTensor } else if (desc.Input("Shape").size() >= 1 || - desc.Input("ShapeTensor").size() >= 1 || with_dynamic_shape) { + desc.Input("ShapeTensor").size() >= 1) { return false; } else { std::vector shape = diff --git a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu index d67820a6f0af4..f9341613a0f55 100644 --- a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu @@ -182,69 +182,9 @@ int LayerNormPluginDynamic::enqueue( paddle::operators::LayerNormDirectCUDAFunctor layer_norm; layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d, variance_d, begin_norm_axis, eps); - } else if (input_type == nvinfer1::DataType::kHALF) { -#ifdef TRT_PLUGIN_FP16_AVALIABLE - VLOG(1) << "TRT Plugin DataType selected. LayerNorm-->fp16"; - const half *input = reinterpret_cast(inputs[0]); - half *output = static_cast(outputs[0]); - size_t mean_shape_product = 1; - for (auto s : mean_shape_) { - mean_shape_product *= s; - } - size_t variance_shape_product = 1; - for (auto s : variance_shape_) { - variance_shape_product *= s; - } - if (!scale_gpu_half_d_) { - cudaMalloc(&scale_gpu_half_d_, feature_size * sizeof(half)); - } - if (!bias_gpu_half_d_) { - cudaMalloc(&bias_gpu_half_d_, feature_size * sizeof(half)); - } - if (!mean_gpu_half_d_) { - cudaMalloc(&mean_gpu_half_d_, mean_shape_product * sizeof(half)); - } - if (!variance_gpu_half_d_) { - cudaMalloc(&variance_gpu_half_d_, variance_shape_product * sizeof(half)); - } - - half *scale_cpu_half = - static_cast(malloc(feature_size * sizeof(half))); - half *bias_cpu_half = - static_cast(malloc(feature_size * sizeof(half))); - PADDLE_ENFORCE_EQ( - scale_cpu_half && bias_cpu_half, true, - platform::errors::Unavailable("Out of memory, malloc size %d.", - feature_size * sizeof(half))); - - for (int i = 0; i < feature_size; i++) { - scale_cpu_half[i] = static_cast(scale_[i]); - bias_cpu_half[i] = static_cast(bias_[i]); - } - cudaMemcpyAsync(scale_gpu_half_d_, scale_cpu_half, - sizeof(half) * feature_size, cudaMemcpyHostToDevice, - stream); - cudaMemcpyAsync(bias_gpu_half_d_, bias_cpu_half, - sizeof(half) * feature_size, cudaMemcpyHostToDevice, - stream); - free(scale_cpu_half); - free(bias_cpu_half); - - paddle::operators::LayerNormDirectCUDAFunctor layer_norm; - layer_norm(stream, input, input_shape, bias_gpu_half_d_, scale_gpu_half_d_, - output, mean_gpu_half_d_, variance_gpu_half_d_, begin_norm_axis, - eps); -#else - PADDLE_THROW(platform::errors::Fatal( - "The layer_norm tensorRT plugin should be " - "complied with CUDA version >= 10.0 when running with fp16. " - "Please recomplie it or try to use fp32 by set " - "config.SetTRTDynamicShapeInfo(min_input_shape, " - "max_input_shape, opt_input_shape, true")); -#endif } else { PADDLE_THROW(platform::errors::Fatal( - "The LayerNorm TRT Plugin's input type should be float or half.")); + "The LayerNorm TRT Plugin's input type should be float.")); } return cudaGetLastError() != cudaSuccess; } diff --git a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h index 1a6125b0e16ff..9c4c31b61e128 100644 --- a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h @@ -114,22 +114,14 @@ class LayerNormPluginDynamic : public DynamicPluginTensorRT { : begin_norm_axis_(begin_norm_axis), eps_(eps), mean_shape_(mean_shape), - variance_shape_(variance_shape), - scale_gpu_half_d_(nullptr), - bias_gpu_half_d_(nullptr), - mean_gpu_half_d_(nullptr), - variance_gpu_half_d_(nullptr) { + variance_shape_(variance_shape) { bias_.resize(bias_num); scale_.resize(scale_num); std::copy(bias, bias + bias_num, bias_.data()); std::copy(scale, scale + scale_num, scale_.data()); } - LayerNormPluginDynamic(void const* serialData, size_t serialLength) - : scale_gpu_half_d_(nullptr), - bias_gpu_half_d_(nullptr), - mean_gpu_half_d_(nullptr), - variance_gpu_half_d_(nullptr) { + LayerNormPluginDynamic(void const* serialData, size_t serialLength) { DeserializeValue(&serialData, &serialLength, &bias_); DeserializeValue(&serialData, &serialLength, &scale_); DeserializeValue(&serialData, &serialLength, &begin_norm_axis_); @@ -190,21 +182,6 @@ class LayerNormPluginDynamic : public DynamicPluginTensorRT { const nvinfer1::DataType* inputTypes, int nbInputs) const override; - ~LayerNormPluginDynamic() { - if (scale_gpu_half_d_) { - cudaFree(scale_gpu_half_d_); - } - if (bias_gpu_half_d_) { - cudaFree(bias_gpu_half_d_); - } - if (mean_gpu_half_d_) { - cudaFree(mean_gpu_half_d_); - } - if (variance_gpu_half_d_) { - cudaFree(variance_gpu_half_d_); - } - } - void destroy() override { delete this; } private: @@ -218,10 +195,6 @@ class LayerNormPluginDynamic : public DynamicPluginTensorRT { float eps_; std::vector mean_shape_; std::vector variance_shape_; - half* scale_gpu_half_d_; - half* bias_gpu_half_d_; - half* mean_gpu_half_d_; - half* variance_gpu_half_d_; }; class LayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator { diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu index b65ae01ddf919..f955011675cf5 100755 --- a/paddle/fluid/operators/layer_norm_op.cu +++ b/paddle/fluid/operators/layer_norm_op.cu @@ -243,73 +243,6 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias, } } -template -__global__ void LayerNormForwardFP16(const T *x, const U *scale, const U *bias, - T *y, U *mean, U *var, float epsilon, - int feature_size) { -#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) - using BlockReduce = cub::BlockReduce, BlockDim>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ U mean_share; - __shared__ U var_share; - - int beg_idx = blockIdx.x * feature_size + threadIdx.x; - int end_idx = (blockIdx.x + 1) * feature_size; - - // Step 1: Reduce to calculate mean and var - U mean_val = 0; - U var_val = 0; - for (int i = beg_idx; i < end_idx; i += BlockDim) { - U tmp = static_cast(x[i]); - mean_val += tmp; - var_val += (tmp * tmp); - } - auto pair = BlockReduce(temp_storage) - .Reduce(PairForLayerNorm(mean_val, var_val), - PairForLayerNormAddFunctor()); - if (threadIdx.x == 0) { - auto tmp = pair.first_ / static_cast(feature_size); - mean[blockIdx.x] = mean_share = static_cast(tmp); - var[blockIdx.x] = var_share = - static_cast(pair.second_ / static_cast(feature_size) - tmp * tmp); - } - __syncthreads(); - - mean_val = mean_share; - U invvar = rsqrt_(var_share + static_cast(epsilon)); - - // Step 2: Calculate y - if (scale != nullptr) { - if (bias != nullptr) { - for (int i = beg_idx, j = threadIdx.x; i < end_idx; - i += BlockDim, j += BlockDim) { - y[i] = static_cast( - scale[j] * (static_cast(x[i]) - mean_val) * invvar + bias[j]); - } - } else { - for (int i = beg_idx, j = threadIdx.x; i < end_idx; - i += BlockDim, j += BlockDim) { - y[i] = static_cast(scale[j] * (static_cast(x[i]) - mean_val) * - invvar); - } - } - } else { // scale == nullptr - if (bias != nullptr) { - for (int i = beg_idx, j = threadIdx.x; i < end_idx; - i += BlockDim, j += BlockDim) { - y[i] = static_cast((static_cast(x[i]) - mean_val) * invvar + - bias[j]); - } - } else { - for (int i = beg_idx, j = threadIdx.x; i < end_idx; - i += BlockDim, j += BlockDim) { - y[i] = static_cast((static_cast(x[i]) - mean_val) * invvar); - } - } - } -#endif -} - template __inline__ __device__ void cuLoadAddStridedInputs( const int64_t i1_block, const int thr_load_row_off, @@ -965,28 +898,6 @@ void LayerNormDirectCUDAFunctor::operator()(gpuStream_t stream, } } -template <> -void LayerNormDirectCUDAFunctor::operator()( - gpuStream_t stream, const half *input, std::vector input_shape, - const half *bias, const half *scale, half *output, half *mean, - half *variance, int begin_norm_axis, float eps) { - const auto x_dims = framework::make_ddim(input_shape); - auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); - int batch_size = static_cast(matrix_dim[0]); - int feature_size = static_cast(matrix_dim[1]); - switch (GetDesiredBlockDim(feature_size)) { - FIXED_BLOCK_DIM_CASE( - LayerNormForwardFP16<<>>( - input, scale, bias, output, mean, variance, eps, feature_size)); - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "Product from begin_norm_axis to end in layer_norm must be larger " - "than 1")); - break; - } -} - template class LayerNormKernel : public framework::OpKernel { @@ -1076,9 +987,6 @@ class LayerNormGradKernel }; template class LayerNormDirectCUDAFunctor; -#ifdef TRT_PLUGIN_FP16_AVALIABLE -template class LayerNormDirectCUDAFunctor; -#endif #undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE #undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE From 3a2230de8f615348966125ba94d8be9e6e647adb Mon Sep 17 00:00:00 2001 From: Jiawei Wang Date: Tue, 15 Jun 2021 12:41:40 +0800 Subject: [PATCH 05/18] add conv3d prototxt (#33501) * add conv3d prototxt * Update conv3d.pbtxt * Update conv3d.pbtxt --- paddle/fluid/operators/compat/conv3d.pbtxt | 102 +++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 paddle/fluid/operators/compat/conv3d.pbtxt diff --git a/paddle/fluid/operators/compat/conv3d.pbtxt b/paddle/fluid/operators/compat/conv3d.pbtxt new file mode 100644 index 0000000000000..51d4c0d8e3bbb --- /dev/null +++ b/paddle/fluid/operators/compat/conv3d.pbtxt @@ -0,0 +1,102 @@ +type: "conv3d" +def { + inputs { + name: "Input" + } + inputs { + name: "Filter" + } + outputs { + name: "Output" + } + attrs { + name: "strides" + type: INTS + } + attrs { + name: "paddings" + type: INTS + } + attrs { + name: "padding_algorithm" + type: STRING + } + attrs { + name: "groups" + type: INT + } + attrs { + name: "dilations" + type: INTS + } + attrs { + name: "data_format" + type: STRING + } +} +extra { + inputs { + name: "ResidualData" + } + attrs { + name: "is_test" + type: BOOLEAN + } + attrs { + name: "use_cudnn" + type: BOOLEAN + } + attrs { + name: "fuse_relu_before_depthwise_conv" + type: BOOLEAN + } + attrs { + name: "use_mkldnn" + type: BOOLEAN + } + attrs { + name: "use_quantizer" + type: BOOLEAN + } + attrs { + name: "mkldnn_data_type" + type: STRING + } + attrs { + name: "fuse_relu" + type: BOOLEAN + } + attrs { + name: "fuse_activation" + type: STRING + } + attrs { + name: "fuse_alpha" + type: FLOAT + } + attrs { + name: "fuse_beta" + type: FLOAT + } + attrs { + name: "use_addto" + type: BOOLEAN + } + attrs { + name: "fuse_residual_connection" + type: BOOLEAN + } + attrs { + name: "force_fp32_output" + type: BOOLEAN + } + attrs { + name: "workspace_size_MB" + type: INT + } + attrs { + name: "exhaustive_search" + type: BOOLEAN + } +} + From 009a163cf69d267ed4029b4f77f6a74af0ca4593 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Tue, 15 Jun 2021 13:00:35 +0800 Subject: [PATCH 06/18] fix the op attrs error in conv2d pbtxt,test=develop (#33532) --- paddle/fluid/operators/compat/conv2d.pbtxt | 8 ++++---- paddle/fluid/operators/compat/conv2d_transpose.pbtxt | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/compat/conv2d.pbtxt b/paddle/fluid/operators/compat/conv2d.pbtxt index 94073800f7246..24f15098a8b9e 100644 --- a/paddle/fluid/operators/compat/conv2d.pbtxt +++ b/paddle/fluid/operators/compat/conv2d.pbtxt @@ -32,6 +32,10 @@ def { name: "dilations" type: INTS } + attrs { + name: "data_format" + type: STRING + } } extra { inputs { @@ -113,10 +117,6 @@ extra { name: "force_fp32_output" type: BOOLEAN } - attrs { - name: "data_format" - type: STRING - } attrs { name: "workspace_size_MB" type: INT diff --git a/paddle/fluid/operators/compat/conv2d_transpose.pbtxt b/paddle/fluid/operators/compat/conv2d_transpose.pbtxt index 7e3ecb22152b5..474043718e4f9 100644 --- a/paddle/fluid/operators/compat/conv2d_transpose.pbtxt +++ b/paddle/fluid/operators/compat/conv2d_transpose.pbtxt @@ -1,4 +1,4 @@ -type: "reduce_mean" +type: "conv2d_transpose" def { inputs { name: "Input" @@ -40,6 +40,10 @@ def { name: "padding_algorithm" type: STRING } + attrs { + name: "data_format" + type: STRING + } } extra { attrs { @@ -78,10 +82,6 @@ extra { name: "fuse_beta" type: FLOAT } - attrs { - name: "data_format" - type: STRING - } attrs { name: "workspace_size_MB" type: INT From 28521e0f710916d0f572d688b6c408a83a40e590 Mon Sep 17 00:00:00 2001 From: WeiXin Date: Tue, 15 Jun 2021 14:31:29 +0800 Subject: [PATCH 07/18] Save all the information of 'ParamBase' in 'Layer'. (#33500) * Save all the information of 'ParamBase' in 'Layer'. * edit unittest --- python/paddle/fluid/framework.py | 12 ++++++ .../tests/unittests/test_paddle_save_load.py | 12 ++---- python/paddle/framework/io.py | 43 ++++++++++++++++--- 3 files changed, 53 insertions(+), 14 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 695c91fea819f..22f31a340364f 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -5540,6 +5540,18 @@ def _copy_to(self, device, blocking): core.varbase_copy(self, new_param, device, blocking) return new_param + def __reduce__(self): + value = self.numpy() + state = (self.name, self.persistable, self.stop_gradient) + return ParamBase, (self.shape, self.dtype), (self.__dict__, value, + state) + + def __setstate__(self, state): + self.__dict__.update(state[0]) + t = self.value().get_tensor() + t.set(state[1], _current_expected_place()) + self.name, self.persistable, self.stop_gradient = state[2] + __repr__ = __str__ diff --git a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py index 594d0db035c6a..fe8692a38814e 100644 --- a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py @@ -935,21 +935,17 @@ def test_save_load_layer(self): layer2 = LinearNet() layer1.eval() layer2.eval() + origin_layer = (layer1, layer2) origin = (layer1(inps), layer2(inps)) path = "test_save_load_layer_/layer.pdmodel" - paddle.save((layer1, layer2), path) - - # static - paddle.enable_static() - with self.assertRaises(ValueError): - paddle.load(path) - # dygraph - paddle.disable_static() + paddle.save(origin_layer, path) loaded_layer = paddle.load(path) loaded_result = [l(inps) for l in loaded_layer] for i in range(len(origin)): self.assertTrue((origin[i] - loaded_result[i]).abs().max() < 1e-10) + for k, v in origin_layer[i]._linear.weight.__dict__.items(): + self.assertTrue(v == loaded_layer[i]._linear.weight.__dict__[k]) if __name__ == '__main__': diff --git a/python/paddle/framework/io.py b/python/paddle/framework/io.py index 5f1ffa81eab17..d02d078d547de 100644 --- a/python/paddle/framework/io.py +++ b/python/paddle/framework/io.py @@ -233,9 +233,13 @@ def _pickle_save(obj, f, protocol): raise ValueError("Expected 1<'protocol'<5, but received protocol={}". format(protocol)) - def reudce_varbase(self): + list_params = set() + + def reduce_varbase(self): data = self.numpy() name = self.name + if name in list_params: + return self.__reduce__() return (tuple, ((name, data), )) @@ -244,16 +248,43 @@ def reduce_LoDTensor(self): return (eval, ('data', {'data': data})) + def reduce_Layer(self): + is_param_or_layer = lambda v: isinstance(v, ParamBase) or isinstance(v, core.Layer) + + def collect_params(param_or_layer): + if isinstance(param_or_layer, ParamBase): + list_params.add(param_or_layer.name) + else: + # param_or_layer is layer + _parse_every_object(param_or_layer.__dict__, is_param_or_layer, + collect_params) + return param_or_layer + + _parse_every_object(self.__dict__, is_param_or_layer, collect_params) + return self.__reduce_ex__(protocol) + + dispatch_table_layer = dict() + + def create_layer_dispatch_table(layer): + dispatch_table_layer[layer.__class__] = reduce_Layer + return layer + + _parse_every_object(obj, lambda v: isinstance(v, core.Layer), + create_layer_dispatch_table) + def add_dispatch_table(): # This is not a good method, because the pickle module has been modified. - pickle.dispatch_table[core.VarBase] = reudce_varbase - pickle.dispatch_table[ParamBase] = reudce_varbase + pickle.dispatch_table[core.VarBase] = reduce_varbase + pickle.dispatch_table[ParamBase] = reduce_varbase pickle.dispatch_table[core.LoDTensor] = reduce_LoDTensor + pickle.dispatch_table.update(dispatch_table_layer) def pop_dispatch_table(): pickle.dispatch_table.pop(core.VarBase) pickle.dispatch_table.pop(core.LoDTensor) pickle.dispatch_table.pop(ParamBase) + for k in dispatch_table_layer: + pickle.dispatch_table.pop(k) # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3' if sys.platform == 'darwin' and sys.version_info.major == 3: @@ -273,10 +304,10 @@ def pop_dispatch_table(): pickler = pickle.Pickler(f, protocol) pickler.dispatch_table = copyreg.dispatch_table.copy() - pickler.dispatch_table[core.VarBase] = reudce_varbase + pickler.dispatch_table[core.VarBase] = reduce_varbase pickler.dispatch_table[core.LoDTensor] = reduce_LoDTensor - pickler.dispatch_table[ParamBase] = reudce_varbase - + pickler.dispatch_table[ParamBase] = reduce_varbase + pickler.dispatch_table.update(dispatch_table_layer) pickler.dump(obj) From ff8252387e5039f4bbc1201da38a7a956a562669 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Tue, 15 Jun 2021 16:55:27 +0800 Subject: [PATCH 08/18] [NPU] use SparseSoftmaxCrossEntropyWithLogits in npu kernel of softmax_with_cross_entropy (#32858) * use SparseSoftmaxCrossEntropyWithLogits * fix * test_slice * revert test_slice * add backprob for npu kernel * fix typo * fix ut * fix ut * refine comments * return softmax --- .../softmax_with_cross_entropy_op.cc | 41 ++++- .../softmax_with_cross_entropy_op_npu.cc | 168 ++++++------------ python/paddle/fluid/layers/loss.py | 23 ++- .../test_softmax_with_cross_entropy_op_npu.py | 19 +- .../white_list/no_check_set_white_list.py | 1 + 5 files changed, 122 insertions(+), 130 deletions(-) diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc index fbaf76d4e7cd8..0c2d39e7519ef 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc @@ -44,6 +44,19 @@ class SoftmaxWithCrossEntropyOpMaker "The outputs value of softmax activation by given the input batch, " "which will be used in backward calculation.") .AsIntermediate(); +#ifdef PADDLE_WITH_ASCEND_CL + AddOutput( + "Backprop", + "(Tensor, default: Tensor), A tensor in same shape with " + "Input(Logits). " + "The intermediate value used for backward calculation. The calculation " + "is :" + "exp(logits -max_logits) / sum(exp(logits - max_logits)) - labels, " + "where labels is ont-hot." + "Currently, the tensor is generated and used in npu kernel only. ") + .AsIntermediate() + .AsDispensable(); +#endif AddOutput("Loss", "(Tensor, default: Tensor), A tensor in same shape with " "Input(Logits) " @@ -181,7 +194,10 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { } ctx->SetOutputDim("Softmax", logits_dims); - +#ifdef PADDLE_WITH_ASCEND_CL + ctx->SetOutputDim("Backprop", logits_dims); + ctx->ShareLoD("Logits", /*->*/ "Backprop"); +#endif logits_dims[axis] = 1; ctx->SetOutputDim("Loss", logits_dims); @@ -285,6 +301,9 @@ class SoftmaxGradMaker : public framework::SingleGradOpMaker { grad_op->SetType("softmax_with_cross_entropy_grad"); grad_op->SetInput("Label", this->Input("Label")); grad_op->SetInput("Softmax", this->Output("Softmax")); +#ifdef PADDLE_WITH_ASCEND_CL + grad_op->SetInput("Backprop", this->Output("Backprop")); +#endif grad_op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss")); grad_op->SetOutput(framework::GradVarName("Logits"), this->InputGrad("Logits")); @@ -317,9 +336,29 @@ REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy, REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad, ops::SoftmaxWithCrossEntropyGradKernel, ops::SoftmaxWithCrossEntropyGradKernel); + REGISTER_OP_VERSION(softmax_with_cross_entropy) +#ifdef PADDLE_WITH_ASCEND_CL + .AddCheckpoint( + R"ROC( + Add a new attribute [use_softmax] )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "use_softmax", "A flag to indicate whether to do softmax", true)) + .AddCheckpoint( + R"ROC( + Add a new dispensable/intermediate output [backprop] )ROC", + paddle::framework::compatible::OpVersionDesc().NewOutput( + "Backprop", + "The intermediate value used for backward calculation. The " + "calculation is :" + "exp(logits -max_logits) / sum(exp(logits - max_logits)) - labels, " + "where labels is ont-hot." + "Currently, the tensor is generated and used in npu kernel " + "only. ")); +#else .AddCheckpoint( R"ROC( Add a new attribute [use_softmax] )ROC", paddle::framework::compatible::OpVersionDesc().NewAttr( "use_softmax", "A flag to indicate whether to do softmax", true)); +#endif diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc index 9921248d1ca1d..639fc6fcc2e79 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc @@ -32,81 +32,53 @@ class SoftmaxWithCrossEntropyNPUKernel : public framework::OpKernel { auto* labels = ctx.Input("Label"); auto* softmax = ctx.Output("Softmax"); auto* loss = ctx.Output("Loss"); + auto* backprop = ctx.Output("Backprop"); + auto soft_label = ctx.Attr("soft_label"); + PADDLE_ENFORCE_EQ(soft_label, false, + platform::errors::Unimplemented( + "soft_label=True is not supported in " + "the npu kernel of softmax_with_cross_entropy.")); - int cls_num = logits->dims()[1]; const int rank = logits->dims().size(); const int axis = CanonicalAxis(ctx.Attr("axis"), rank); - std::vector axes; - for (auto i = axis; i < logits->dims().size(); ++i) { - axes.push_back(i); - } + const int n = SizeToAxis(axis, logits->dims()); + const int d = SizeFromAxis(axis, logits->dims()); + + PADDLE_ENFORCE_EQ( + labels->numel(), n, + platform::errors::Unimplemented( + "The size of labels should be equal to SizeToAxis of logits," + "but got size of labels is %d and SizeToAxis is %d.", + labels->numel(), n)); + + loss->mutable_data(ctx.GetPlace()); + backprop->mutable_data(ctx.GetPlace()); + softmax->mutable_data(ctx.GetPlace()); + + Tensor logits_2d, labels_1d, loss_1d, backprop_2d, softmax_2d; + logits_2d.ShareDataWith(*logits).Resize({n, d}); + labels_1d.ShareDataWith(*labels).Resize({n}); + loss_1d.ShareDataWith(*loss).Resize({n}); + backprop_2d.ShareDataWith(*backprop).Resize({n, d}); + softmax_2d.ShareDataWith(*softmax).Resize({n, d}); auto stream = ctx.template device_context() .stream(); - // softmax - softmax->mutable_data(ctx.GetPlace()); + std::vector axes; + for (auto i = axis; i < logits->dims().size(); ++i) { + axes.push_back(i); + } const auto& runner_softmax = NpuOpRunner("SoftmaxV2", {*logits}, {*softmax}, {{"axes", axes}}); runner_softmax.Run(stream); - // cast label from int64/int32 to int32 - Tensor tmp_labels(framework::proto::VarType::INT32); - if (labels->type() != framework::proto::VarType::INT32) { - tmp_labels.Resize(labels->dims()); - tmp_labels.mutable_data(ctx.GetPlace(), framework::proto::VarType::INT32); - auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32); - const auto& runner_cast_label = - NpuOpRunner("Cast", {*labels}, {tmp_labels}, - {{"dst_type", static_cast(dst_dtype)}}); - runner_cast_label.Run(stream); - labels = &tmp_labels; - } - - // on and off - Tensor on_tensor(framework::proto::VarType::INT32); - on_tensor.mutable_data({1}, ctx.GetPlace()); - FillNpuTensorWithConstant(&on_tensor, static_cast(1)); - Tensor off_tensor(framework::proto::VarType::INT32); - off_tensor.mutable_data({1}, ctx.GetPlace()); - FillNpuTensorWithConstant(&off_tensor, static_cast(0)); - - // one_hot - Tensor tmp_onehot(on_tensor.type()); - tmp_onehot.Resize(logits->dims()); - tmp_onehot.mutable_data(ctx.GetPlace()); - - const auto& runner_onehot = - NpuOpRunner("OneHotD", {*labels, on_tensor, off_tensor}, {tmp_onehot}, - {{"axis", -1}, {"depth", cls_num}}); - runner_onehot.Run(stream); - - // cast one_hot from int32 to T - Tensor cast_onehot(logits->type()); - cast_onehot.Resize(tmp_onehot.dims()); - cast_onehot.mutable_data(ctx.GetPlace()); - auto dst_dtype = ConvertToNpuDtype(logits->type()); - const auto& runner_cast_onehot = - NpuOpRunner("Cast", {tmp_onehot}, {cast_onehot}, - {{"dst_type", static_cast(dst_dtype)}}); - runner_cast_onehot.Run(stream); - - // SoftmaxCrossEntropyWithLogits - Tensor backprop(logits->type()); - backprop.Resize(logits->dims()); - backprop.mutable_data(ctx.GetPlace()); - - loss->mutable_data(ctx.GetPlace()); - - // SoftmaxCrossEntropyWithLogits requires loss to be of shape [batch_size] - auto loss_dims = loss->dims(); - loss->Resize({loss_dims[0]}); + // SparseSoftmaxCrossEntropyWithLogits const auto& runner_s = - NpuOpRunner("SoftmaxCrossEntropyWithLogits", {*logits, cast_onehot}, - {*loss, backprop}, {}); + NpuOpRunner("SparseSoftmaxCrossEntropyWithLogits", + {logits_2d, labels_1d}, {loss_1d, backprop_2d}, {}); runner_s.Run(stream); - loss->Resize(loss_dims); } }; @@ -114,70 +86,32 @@ template class SoftmaxWithCrossEntropyGradNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* labels = ctx.Input("Label"); - auto* softmax = ctx.Input("Softmax"); + auto* backprop = ctx.Input("Backprop"); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); auto* logits_grad = ctx.Output(framework::GradVarName("Logits")); - int cls_num = softmax->dims()[1]; + PADDLE_ENFORCE_NOT_NULL(backprop, + platform::errors::PreconditionNotMet( + "backprop should not be null in NPU kernel of " + "softmax_with_cross_entropy_grad.")); + logits_grad->mutable_data(ctx.GetPlace()); + + const int rank = logits_grad->dims().size(); + const int axis = CanonicalAxis(ctx.Attr("axis"), rank); + const int n = SizeToAxis(axis, logits_grad->dims()); + const int d = SizeFromAxis(axis, logits_grad->dims()); + + Tensor logits_grad_2d, loss_grad_1d, backprop_2d; + + logits_grad_2d.ShareDataWith(*logits_grad).Resize({n, d}); + loss_grad_1d.ShareDataWith(*loss_grad).Resize({n}); + backprop_2d.ShareDataWith(*backprop).Resize({n, d}); auto stream = ctx.template device_context() .stream(); - - // cast label from int64/int32 to int32 - Tensor tmp_labels(framework::proto::VarType::INT32); - if (labels->type() != framework::proto::VarType::INT32) { - tmp_labels.Resize(labels->dims()); - tmp_labels.mutable_data(ctx.GetPlace(), framework::proto::VarType::INT32); - auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32); - const auto& runner_cast_label = - NpuOpRunner("Cast", {*labels}, {tmp_labels}, - {{"dst_type", static_cast(dst_dtype)}}); - runner_cast_label.Run(stream); - labels = &tmp_labels; - } - - // on and off - Tensor on_tensor(framework::proto::VarType::INT32); - on_tensor.mutable_data({1}, ctx.GetPlace()); - FillNpuTensorWithConstant(&on_tensor, static_cast(1)); - Tensor off_tensor(framework::proto::VarType::INT32); - off_tensor.mutable_data({1}, ctx.GetPlace()); - FillNpuTensorWithConstant(&off_tensor, static_cast(0)); - - // one_hot - Tensor tmp_onehot(on_tensor.type()); - tmp_onehot.Resize(softmax->dims()); - tmp_onehot.mutable_data(ctx.GetPlace()); - - const auto& runner_onehot = - NpuOpRunner("OneHotD", {*labels, on_tensor, off_tensor}, {tmp_onehot}, - {{"axis", -1}, {"depth", cls_num}}); - runner_onehot.Run(stream); - - // cast one_hot from int32 to T - Tensor cast_onehot(softmax->type()); - cast_onehot.Resize(tmp_onehot.dims()); - cast_onehot.mutable_data(ctx.GetPlace()); - auto dst_dtype = ConvertToNpuDtype(softmax->type()); - const auto& runner_cast_onehot = - NpuOpRunner("Cast", {tmp_onehot}, {cast_onehot}, - {{"dst_type", static_cast(dst_dtype)}}); - runner_cast_onehot.Run(stream); - - // sub - Tensor tmp_sub(softmax->type()); - tmp_sub.Resize(softmax->dims()); - tmp_sub.mutable_data(ctx.GetPlace()); - const auto& runner_sub = - NpuOpRunner("Sub", {*softmax, cast_onehot}, {tmp_sub}, {}); - - runner_sub.Run(stream); - // mul - logits_grad->mutable_data(ctx.GetPlace()); const auto& runner_mul = - NpuOpRunner("Mul", {*loss_grad, tmp_sub}, {*logits_grad}, {}); + NpuOpRunner("Mul", {*loss_grad, *backprop}, {*logits_grad}, {}); runner_mul.Run(stream); } }; diff --git a/python/paddle/fluid/layers/loss.py b/python/paddle/fluid/layers/loss.py index c3f25dc53c12c..d150cc7a9aee9 100644 --- a/python/paddle/fluid/layers/loss.py +++ b/python/paddle/fluid/layers/loss.py @@ -26,6 +26,7 @@ from ..param_attr import ParamAttr from ..initializer import NumpyArrayInitializer, Constant from .. import core +import warnings __all__ = [ 'center_loss', @@ -1258,10 +1259,16 @@ def softmax_with_cross_entropy(logits, print(out) """ if in_dygraph_mode(): - softmax, loss = core.ops.softmax_with_cross_entropy( - logits, label, 'soft_label', soft_label, 'ignore_index', - ignore_index, 'numeric_stable_mode', numeric_stable_mode, 'axis', - axis) + if core.is_compiled_with_npu(): + softmax, backprop, loss = core.ops.softmax_with_cross_entropy( + logits, label, 'soft_label', soft_label, 'ignore_index', + ignore_index, 'numeric_stable_mode', numeric_stable_mode, + 'axis', axis) + else: + softmax, loss = core.ops.softmax_with_cross_entropy( + logits, label, 'soft_label', soft_label, 'ignore_index', + ignore_index, 'numeric_stable_mode', numeric_stable_mode, + 'axis', axis) if not return_softmax: return loss else: @@ -1276,12 +1283,16 @@ def softmax_with_cross_entropy(logits, helper = LayerHelper('softmax_with_cross_entropy', **locals()) softmax = helper.create_variable_for_type_inference(dtype=logits.dtype) loss = helper.create_variable_for_type_inference(dtype=logits.dtype) + + outputs = {'Softmax': softmax, 'Loss': loss} + if core.is_compiled_with_npu(): + backprop = helper.create_variable_for_type_inference(dtype=logits.dtype) + outputs['Backprop'] = backprop helper.append_op( type='softmax_with_cross_entropy', inputs={'Logits': logits, 'Label': label}, - outputs={'Softmax': softmax, - 'Loss': loss}, + outputs=outputs, attrs=attrs) if return_softmax: diff --git a/python/paddle/fluid/tests/unittests/npu/test_softmax_with_cross_entropy_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_softmax_with_cross_entropy_op_npu.py index 1b48268b0e77e..2ee089360e6dd 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_softmax_with_cross_entropy_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_softmax_with_cross_entropy_op_npu.py @@ -68,8 +68,11 @@ def setUp(self): loss = cross_entropy(softmax, labels, self.soft_label, self.axis, self.ignore_index) + one_hot_label = np.eye(axis_dim)[labels.reshape(-1)] + self.inputs = {"Logits": logits, "Label": labels} self.outputs = { + "Backprop": (softmax - one_hot_label).astype(self.dtype), "Softmax": softmax.astype(self.dtype), "Loss": loss.astype(self.dtype) } @@ -85,12 +88,16 @@ def setUp(self): def test_check_output(self): self.check_output_with_place(self.place, check_dygraph=False) - # TODO(ascendrc): Add grad test - # def test_check_grad(self): - # if self.dtype == np.float16: - # return - # self.check_grad(['X'], 'Out') - # + def test_check_grad(self): + if self.dtype == np.float16: + return + # fp32 has low precision, cpu and npu both need to relax the max_relative_error if using fp32 + self.check_grad_with_place( + self.place, ['Logits'], + 'Loss', + check_dygraph=False, + numeric_grad_delta=0.001, + max_relative_error=0.5) @unittest.skipIf(not paddle.is_compiled_with_npu(), diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index f81011717040a..32ac4f412a8f5 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -30,4 +30,5 @@ 'cudnn_lstm', 'rnn', 'fusion_lstm', + 'softmax_with_cross_entropy', ] From e47c3f040f3ff6a2866fe3675b19e657ef3c3115 Mon Sep 17 00:00:00 2001 From: Wilber Date: Tue, 15 Jun 2021 21:07:47 +0800 Subject: [PATCH 09/18] [XPU] Update cmake options for xpu. (#33450) --- cmake/external/lite.cmake | 30 ++++++++++++++++++++++-------- cmake/external/xpu.cmake | 11 +++++------ python/setup.py.in | 12 ------------ 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/cmake/external/lite.cmake b/cmake/external/lite.cmake index 6e2157e308716..c48f2cb0467cf 100644 --- a/cmake/external/lite.cmake +++ b/cmake/external/lite.cmake @@ -18,13 +18,21 @@ if(NOT LINUX) return() endif() -if(XPU_SDK_ROOT) - set(LITE_WITH_XPU ON) - include_directories("${XPU_SDK_ROOT}/XTDK/include") - include_directories("${XPU_SDK_ROOT}/XTCL/include") +if (LITE_WITH_XPU) add_definitions(-DLITE_SUBGRAPH_WITH_XPU) - LINK_DIRECTORIES("${XPU_SDK_ROOT}/XTDK/shlib/") - LINK_DIRECTORIES("${XPU_SDK_ROOT}/XTDK/runtime/shlib/") + IF(WITH_AARCH64) + SET(XPU_SDK_ENV "kylin_aarch64") + ELSEIF(WITH_SUNWAY) + SET(XPU_SDK_ENV "deepin_sw6_64") + ELSEIF(WITH_BDCENTOS) + SET(XPU_SDK_ENV "bdcentos_x86_64") + ELSEIF(WITH_UBUNTU) + SET(XPU_SDK_ENV "ubuntu_x86_64") + ELSEIF(WITH_CENTOS) + SET(XPU_SDK_ENV "centos7_x86_64") + ELSE () + SET(XPU_SDK_ENV "ubuntu_x86_64") + ENDIF() endif() if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR) @@ -57,7 +65,8 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR) -DWITH_TESTING=OFF -DLITE_BUILD_EXTRA=ON -DLITE_WITH_XPU=${LITE_WITH_XPU} - -DXPU_SDK_ROOT=${XPU_SDK_ROOT} + -DXPU_SDK_URL=${XPU_BASE_URL} + -DXPU_SDK_ENV=${XPU_SDK_ENV} -DLITE_WITH_CODE_META_INFO=OFF -DLITE_WITH_ARM=ON) ExternalProject_Add( @@ -99,7 +108,8 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR) -DLITE_WITH_STATIC_CUDA=OFF -DCUDA_ARCH_NAME=${CUDA_ARCH_NAME} -DLITE_WITH_XPU=${LITE_WITH_XPU} - -DXPU_SDK_ROOT=${XPU_SDK_ROOT} + -DXPU_SDK_URL=${XPU_SDK_URL} + -DXPU_SDK_ENV=${XPU_SDK_ENV} -DLITE_WITH_CODE_META_INFO=OFF -DLITE_WITH_ARM=OFF) @@ -147,6 +157,10 @@ message(STATUS "Paddle-lite BINARY_DIR: ${LITE_BINARY_DIR}") message(STATUS "Paddle-lite SOURCE_DIR: ${LITE_SOURCE_DIR}") include_directories(${LITE_SOURCE_DIR}) include_directories(${LITE_BINARY_DIR}) +if(LITE_WITH_XPU) + include_directories(${LITE_BINARY_DIR}/third_party/install/xpu/xdnn/include/) + include_directories(${LITE_BINARY_DIR}/third_party/install/xpu/xre/include/) +endif() function(external_lite_libs alias path) add_library(${alias} SHARED IMPORTED GLOBAL) diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 5d1f1776f885c..a8c33618a6135 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -33,7 +33,10 @@ ELSE () SET(XPU_XCCL_DIR_NAME "xccl-bdcentos_x86_64") ENDIF() -SET(XPU_BASE_URL "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev/20210527") +IF(NOT XPU_BASE_URL) + SET(XPU_BASE_URL "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev/20210527") +ENDIF() + SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_XCCL_URL "${XPU_BASE_URL}/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) @@ -93,11 +96,7 @@ ELSE(WITH_XPU_BKCL) TARGET_LINK_LIBRARIES(xpulib ${XPU_API_LIB} ${XPU_RT_LIB}) ENDIF(WITH_XPU_BKCL) -if(NOT XPU_SDK_ROOT) - ADD_DEPENDENCIES(xpulib ${XPU_PROJECT}) -else() - ADD_CUSTOM_TARGET(extern_xpu DEPENDS xpulib) -endif() +ADD_DEPENDENCIES(xpulib ${XPU_PROJECT}) # Ensure that xpu/api.h can be included without dependency errors. file(GENERATE OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/.xpu_headers_dummy.cc CONTENT "") diff --git a/python/setup.py.in b/python/setup.py.in index 866c2b400d5ca..787317acb6d44 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -347,18 +347,6 @@ if '${WITH_XPU_BKCL}' == 'ON': shutil.copy('${XPU_BKCL_LIB}', libs_path) package_data['paddle.libs']+=['${XPU_BKCL_LIB_NAME}'] -# Only for lite xpu inference. -if '${WITH_XPU}' == 'OFF' and '${XPU_SDK_ROOT}' != '': - xpu_api_lib = os.path.join('${XPU_SDK_ROOT}', 'XTDK/shlib/', 'libxpuapi.so') - xpu_rt_lib = os.path.join('${XPU_SDK_ROOT}', 'XTDK/runtime/shlib/', 'libxpurt.so') - if os.path.exists(xpu_api_lib): - shutil.copy(xpu_api_lib, libs_path) - package_data['paddle.libs']+=['libxpuapi.so'] - if os.path.exists(xpu_rt_lib): - shutil.copy(xpu_rt_lib, libs_path) - package_data['paddle.libs']+=['libxpurt.so'] - - # remove unused paddle/libs/__init__.py if os.path.isfile(libs_path+'/__init__.py'): os.remove(libs_path+'/__init__.py') From b7a54fc1514b976da3213b8db5c8f3ebcae5371d Mon Sep 17 00:00:00 2001 From: Zhou Wei <52485244+zhouwei25@users.noreply.github.com> Date: Tue, 15 Jun 2021 23:40:37 +0800 Subject: [PATCH 10/18] support convert core.Tensor to paddle.Tensor (#33430) --- paddle/fluid/pybind/imperative.cc | 2 +- .../fluid/tests/unittests/test_var_base.py | 19 ++++++++++++++++++- python/paddle/tensor/creation.py | 7 ++++--- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 5b9b492e64905..816281ce8a00d 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -245,7 +245,7 @@ static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self, } static void InitVarBaseFromTensorWithArgDefault( - imperative::VarBase *self, const framework::LoDTensor &tensor) { + imperative::VarBase *self, const framework::Tensor &tensor) { VLOG(4) << "Init VarBase"; auto place = imperative::GetCurrentTracer()->ExpectedPlace(); new (self) imperative::VarBase( diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index b8d29d482fefa..be7b751115581 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -176,7 +176,6 @@ def _test_place(place): x = paddle.to_tensor(1, dtype='uint8') self.assertEqual(x.item(), 1) - print(type(x.item())) self.assertTrue(isinstance(x.item(), int)) x = paddle.to_tensor(1, dtype='int8') @@ -203,6 +202,24 @@ def _test_place(place): self.assertEqual(x.item(), 1 + 1j) self.assertTrue(isinstance(x.item(), complex)) + numpy_array = np.random.randn(3, 4) + # covert core.LoDTensor to paddle.Tensor + lod_tensor = paddle.fluid.core.LoDTensor() + place = paddle.fluid.framework._current_expected_place() + lod_tensor.set(numpy_array, place) + x = paddle.to_tensor(lod_tensor) + self.assertTrue(np.array_equal(x.numpy(), numpy_array)) + self.assertEqual(x.type, core.VarDesc.VarType.LOD_TENSOR) + self.assertEqual(str(x.place), str(place)) + + # covert core.Tensor to paddle.Tensor + x = paddle.to_tensor(numpy_array) + dlpack = x.value().get_tensor()._to_dlpack() + tensor_from_dlpack = paddle.fluid.core.from_dlpack(dlpack) + x = paddle.to_tensor(tensor_from_dlpack) + self.assertTrue(np.array_equal(x.numpy(), numpy_array)) + self.assertEqual(x.type, core.VarDesc.VarType.LOD_TENSOR) + with self.assertRaises(ValueError): paddle.randn([3, 2, 2]).item() with self.assertRaises(ValueError): diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index b446a5921b067..dba4cc1dd8ce9 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -136,9 +136,10 @@ def _handle_dtype(data, dtype): data = data._copy_to(place, False) ata = _handle_dtype(data, dtype) data.stop_gradient = stop_gradient - elif isinstance(data, core.LoDTensor): - # convert LoDTensor to VarBase first - # Currenly, LoDTensor does no copy when places are same + elif isinstance(data, (core.LoDTensor, core.Tensor)): + # Note(zhouwei25): should't expose it to users, just for internal use. + # convert core.Tensor/core.LoDTensor to VarBase first + # Currenly, there is no copy when places are same data = paddle.Tensor(data) if not data.place._equals(place): data = data._copy_to(place, False) From ec6d5efe5943529431b7901bbcdcda601e662f54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Wed, 16 Jun 2021 10:24:33 +0800 Subject: [PATCH 11/18] enhance the attribute constraint for pass,test=develop (#33568) --- paddle/fluid/framework/ir/op_compat_sensible_pass.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc index e422a9bae3118..496d06cc331ca 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc @@ -75,9 +75,6 @@ AttrCompat& AttrCompat::IsLeftDefault() { } bool AttrCompat::operator()(const OpDesc& op_desc) { - if (conditions_.empty()) { - return true; - } if (!op_desc.HasAttr(attr_name_)) { if (!optional_) { LOG(WARNING) << "The non-optional Attr(" << attr_name_ << ") of Op (" From 969ad85f42102984775064cfa259ee6c13454934 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Wed, 16 Jun 2021 10:24:46 +0800 Subject: [PATCH 12/18] fix the error in batch_norm.pbtxt, test=develop (#33572) --- .../fluid/operators/compat/batch_norm.pbtxt | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/compat/batch_norm.pbtxt b/paddle/fluid/operators/compat/batch_norm.pbtxt index c18b4dc19dc2e..772d66f00fcc9 100644 --- a/paddle/fluid/operators/compat/batch_norm.pbtxt +++ b/paddle/fluid/operators/compat/batch_norm.pbtxt @@ -18,6 +18,21 @@ def { outputs { name: "Y" } + outputs { + name: "MeanOut" + } + outputs { + name: "VarianceOut" + } + outputs { + name: "SavedMean" + } + outputs { + name: "SavedVariance" + } + outputs { + name: "ReserveSpace" + } attrs { name: "epsilon" type: FLOAT @@ -55,21 +70,6 @@ extra { name: "trainable_statistics" type: BOOLEAN } - outputs { - name: "MeanOut" - } - outputs { - name: "VarianceOut" - } - outputs { - name: "SavedMean" - } - outputs { - name: "SavedVariance" - } - outputs { - name: "ReserveSpace" - } attrs { name: "op_role" type: INT From 07197fb9b33487cc1633623b9681ff8778fde03a Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Wed, 16 Jun 2021 11:17:25 +0800 Subject: [PATCH 13/18] add_op_extra: elementwise_add, mul (#33491) --- paddle/fluid/operators/compat/elementwise_add.pbtxt | 8 ++++++++ paddle/fluid/operators/compat/mul.pbtxt | 12 ++++++++++++ 2 files changed, 20 insertions(+) diff --git a/paddle/fluid/operators/compat/elementwise_add.pbtxt b/paddle/fluid/operators/compat/elementwise_add.pbtxt index 3e96147ef88eb..6a3d0a9b3a131 100644 --- a/paddle/fluid/operators/compat/elementwise_add.pbtxt +++ b/paddle/fluid/operators/compat/elementwise_add.pbtxt @@ -15,6 +15,14 @@ def { } } extra { + attrs { + name: "out_threshold" + type: FLOAT + } + attrs { + name: "Out0_threshold" + type: FLOAT + } attrs { name: "x_data_format" type: STRING diff --git a/paddle/fluid/operators/compat/mul.pbtxt b/paddle/fluid/operators/compat/mul.pbtxt index b40c05ad2e033..617775eaaae9e 100644 --- a/paddle/fluid/operators/compat/mul.pbtxt +++ b/paddle/fluid/operators/compat/mul.pbtxt @@ -19,6 +19,18 @@ def { } } extra { + attrs { + name: "Out0_threshold" + type: FLOAT + } + attrs { + name: "bit_length" + type: INT + } + attrs { + name: "quantization_type" + type: STRING + } attrs { name: "skip_quant" type: BOOLEAN From 294dfd23e44a7da086f3766c42e3e5278d0e9649 Mon Sep 17 00:00:00 2001 From: ShenLiang Date: Wed, 16 Jun 2021 11:51:13 +0800 Subject: [PATCH 14/18] [HybridParallel]Add SharedLayerDesc for PipelineParallel (#33578) * add pplayer * add sharedlayerdesc --- python/paddle/distributed/collective.py | 4 +- .../paddle/distributed/fleet/base/topology.py | 12 +- .../fleet/meta_parallel/__init__.py | 1 + .../meta_parallel/parallel_layers/__init__.py | 1 + .../parallel_layers/pp_layers.py | 105 ++++++++ .../fleet/meta_parallel/pipeline_parallel.py | 2 + .../unittests/hybrid_parallel_mp_layers.py | 4 +- .../hybrid_parallel_shared_weight.py | 233 ++++++++++++++++++ ...test_parallel_dygraph_pipeline_parallel.py | 3 + 9 files changed, 358 insertions(+), 7 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/hybrid_parallel_shared_weight.py diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 0ffb1d9f881ba..1a09cf5394fba 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -267,7 +267,9 @@ def new_group(ranks=None, backend=None): # TODO(shenliang03): This is a temporary solution to solve the problem of # hang caused by cross-creation of new_group - tmp = fill_constant([0], dtype="int32", value="1") + tmp = paddle.to_tensor( + [1], dtype="int32") if in_dygraph_mode() else fill_constant( + [0], dtype="int32", value="1") paddle.distributed.all_reduce(tmp, use_calc_stream=True) paddle.distributed.wait(tmp) return gp diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index 04d8417fdcbf3..850f358142170 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -107,6 +107,11 @@ def get_comm_list(self, axis_name): return all_result + def get_rank_from_stage(self, global_rank, **kwargs): + coord = self.get_coord(global_rank) + tf = coord._replace(**kwargs)._asdict() + return self.get_rank(**tf) + class HybridCommunicateGroup(object): def __init__(self, topology): @@ -254,7 +259,6 @@ def get_pipe_parallel_group(self): def get_check_parallel_group(self): return self._check_comm_group - def get_rank_from_stage(self, stage_id): - coord = self._topo.get_coord(self.global_rank) - tf = coord._replace(pipe=stage_id)._asdict() - return self._topo.get_rank(**tf) + def get_rank_from_stage(self, stage_id, **kwargs): + return self._topo.get_rank_from_stage( + self.global_rank, pipe=stage_id, **kwargs) diff --git a/python/paddle/distributed/fleet/meta_parallel/__init__.py b/python/paddle/distributed/fleet/meta_parallel/__init__.py index 0750c2c250e2b..4e32ff5723c41 100644 --- a/python/paddle/distributed/fleet/meta_parallel/__init__.py +++ b/python/paddle/distributed/fleet/meta_parallel/__init__.py @@ -17,6 +17,7 @@ from .parallel_layers import RowParallelLinear # noqa: F401 from .parallel_layers import ParallelCrossEntropy # noqa: F401 from .parallel_layers import LayerDesc # noqa: F401 +from .parallel_layers import SharedLayerDesc # noqa: F401 from .parallel_layers import PipelineLayer # noqa: F401 from .parallel_layers import RNGStatesTracker # noqa: F401 from .parallel_layers import model_parallel_random_seed # noqa: F401 diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/__init__.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/__init__.py index 72da962b8914e..fd97785749073 100644 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/__init__.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/__init__.py @@ -17,6 +17,7 @@ from .mp_layers import RowParallelLinear # noqa: F401 from .mp_layers import ParallelCrossEntropy # noqa: F401 from .pp_layers import LayerDesc # noqa: F401 +from .pp_layers import SharedLayerDesc # noqa: F401 from .pp_layers import PipelineLayer # noqa: F401 from .random import RNGStatesTracker # noqa: F401 from .random import model_parallel_random_seed # noqa: F401 diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index 77be62ae6cf4b..b31b2939695b3 100644 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -15,6 +15,7 @@ import paddle from paddle.fluid.dygraph.layers import Layer from ...utils.log_util import logger, layer_to_str +from functools import partial __all__ = [] @@ -58,6 +59,20 @@ def __repr__(self): **self.kwargs) +class SharedLayerDesc(LayerDesc): + def __init__(self, + key, + layer_func, + forward_func=None, + shared_weight_attr='weight', + *inputs, + **kwargs): + super(SharedLayerDesc, self).__init__(layer_func, *inputs, **kwargs) + self.layer_name = key + self.forward_func = forward_func + self.shared_weight_attr = shared_weight_attr + + class PipelineLayer(Layer): def __init__(self, layers, @@ -104,11 +119,86 @@ def __init__(self, self._start_pos = 0 self._end_pos = self._num_layers - 1 self._segment_network(seg_method) + self.shared_layers = paddle.nn.LayerDict() + self.shared_weight_attrs = {} # construct layer self.run_function = [] self._build_layer() + self.shared_comm = self._construct_shared_comm() + self._synchronize_shared_weights() + + def get_stage_from_index(self, layer_idx): + assert 0 <= layer_idx < self._num_layers, "layer_idx is out of bound" + for stage in range(self._topo.get_dim('pipe')): + if self.segment_parts[stage] <= layer_idx < self.segment_parts[stage + + 1]: + return stage + + def _construct_shared_comm(self): + shared_comm = {} + if self._topo.get_dim("pipe") == 1: + return + + layers_desc = self._layers_desc + shared_layer_names = set( + s.layer_name for s in layers_desc if isinstance(s, SharedLayerDesc)) + for key in shared_layer_names: + shared_layers = [] + for idx, layer in enumerate(layers_desc): + if isinstance(layer, + SharedLayerDesc) and layer.layer_name == key: + shared_layers.append(idx) + + shared_stages = set( + self.get_stage_from_index(idx) for idx in shared_layers) + self._dp_degree = self._topo.get_dim('data') + self._mp_degree = self._topo.get_dim('model') + + shared_ranks = [] + for dp in range(self._dp_degree): + for mp in range(self._mp_degree): + shared_ranks = [] + for s in sorted(shared_stages): + shared_ranks.append( + self._topo.get_rank_from_stage( + self.global_rank, pipe=s, data=dp, model=mp)) + + group = paddle.distributed.new_group(ranks=shared_ranks) + if self.global_rank in shared_ranks: + assert key in self.shared_layers + if key in self.shared_layers: + shared_comm[key] = { + 'ranks': shared_ranks, + 'group': group, + 'weight_attr': self.shared_weight_attrs[key], + 'layer': self.shared_layers[key], + } + return shared_comm + + def _synchronize_shared_weights(self): + for key, comm in self.shared_comm.items(): + with paddle.framework.no_grad(): + paddle.distributed.broadcast( + getattr(comm['layer'], comm['weight_attr']), + src=min(comm['ranks']), + group=comm['group']) + + def allreduce_shared_weight_gradients(self): + for key, comm in self.shared_comm.items(): + param = getattr(self.shared_layers[key], comm['weight_attr']) + # need use trace_op to allreduce weight + with paddle.framework.no_grad(): + paddle.fluid.framework._dygraph_tracer().trace_op( + type="c_allreduce_sum", + inputs={'X': param._grad_ivar()}, + outputs={'Out': param._grad_ivar()}, + attrs={ + 'ring_id': comm['group'].id, + 'use_calc_stream': True + }) + def _segment_network(self, seg_method): logger.info("start segment network..") seg = SegmentLayers( @@ -142,6 +232,21 @@ def _build_layer(self): if isinstance(layer, Layer): self.run_function.append(layer) self.add_sublayer(str(layer_index), layer) + elif isinstance(layer, SharedLayerDesc): + if layer.layer_name not in self.shared_layers: + self.shared_layers[layer.layer_name] = layer.build_layer() + self.shared_weight_attrs[ + layer.layer_name] = layer.shared_weight_attr + + if layer.forward_func is None: + self.run_function.append(self.shared_layers[ + layer.layer_name]) + + else: + self.run_function.append( + partial(layer.forward_func, self.shared_layers[ + layer.layer_name])) + elif isinstance(layer, LayerDesc): model = layer.build_layer() self.run_function.append(model) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 54324b389336d..0bb6315290ed7 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -138,6 +138,8 @@ def train_batch(self, data, optimizer, lr_scheduler=None): self._backward(cache_id=backward_steps) backward_steps += 1 + self._layers.allreduce_shared_weight_gradients() + # optimizer self._step() self.train_loss = self._reduce_final_loss() diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py index 23dae31738691..317eb14ad069e 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_mp_layers.py @@ -270,8 +270,8 @@ def test_parallel_embedding(self): np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy()) def test_parallel_cross_entropy(self): - batch_size = 2 - seq_length = 1 + batch_size = 8 + seq_length = 16 class_size_per_card = 2 vocab_size = class_size_per_card * self.model_parallel_size seed = 1025 diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_shared_weight.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_shared_weight.py new file mode 100644 index 0000000000000..9253f737bf942 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_shared_weight.py @@ -0,0 +1,233 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import unittest +import paddle +import numpy as np +import random +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet as fleet +from paddle.fluid.dygraph.container import Sequential +from paddle.distributed.fleet.meta_parallel import PipelineLayer +from paddle.fluid.dygraph.layers import Layer +import paddle.nn as nn +import paddle.fluid as fluid +from paddle.distributed.fleet.meta_parallel import LayerDesc, SharedLayerDesc + + +def print_hook_fn(grad): + print(grad) + + +def set_random_seed(seed, dp_id, rank_id): + """Set random seed for reproducability.""" + random.seed(seed) + np.random.seed(seed + dp_id) + paddle.seed(seed + dp_id) + + +batch_size = 8 +micro_batch_size = 2 +vocab_size = 128 +hidden_size = 16 + + +class SimpleNet(Layer): + def __init__(self): + super(SimpleNet, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + + self.softmax_weight = self.create_parameter( + shape=[hidden_size, vocab_size]) + self.softmax_bias = self.create_parameter( + shape=[vocab_size], is_bias=False) + + def forward(self, x1, x2, y1): + x_emb = self.word_embeddings(x1) + fc = fluid.layers.matmul(x_emb, self.softmax_weight) + fc = fluid.layers.elementwise_add(fc, self.softmax_bias) + projection = fluid.layers.reshape(fc, shape=[-1, vocab_size]) + + projection = paddle.matmul(projection, self.word_embeddings.weight) + + loss = fluid.layers.softmax_with_cross_entropy( + logits=projection, label=y1, soft_label=False) + return loss.mean() + + +class EmbeddingPipe(Layer): + def __init__(self): + super(EmbeddingPipe, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + + @property + def embedding_weight(self): + return self.word_embeddings.weight + + def forward(self, args): + x1, x2 = args + x_emb = self.word_embeddings(x1) + return x_emb, x2 + + +class MatmulNet(Layer): + def __init__(self): + super(MatmulNet, self).__init__() + self.softmax_weight = self.create_parameter( + shape=[hidden_size, vocab_size]) + + def forward(self, args): + x1, x2 = args + fc = fluid.layers.matmul(x1, self.softmax_weight) + + return fc, x2 + + +class BiasNet(Layer): + def __init__(self): + super(BiasNet, self).__init__() + self.softmax_bias = self.create_parameter(shape=[vocab_size]) + + def forward(self, args): + fc, x2 = args + fc = fluid.layers.elementwise_add(fc, self.softmax_bias) + projection = fluid.layers.reshape(fc, shape=[-1, vocab_size]) + return projection, x2 + + +class LossNet(Layer): + def __init__(self): + super(LossNet, self).__init__() + + def forward(self, args, y1): + projection = args + loss = fluid.layers.softmax_with_cross_entropy( + logits=projection, label=y1[0], soft_label=False) + return loss.mean() + + +class SimpleNetPipe(PipelineLayer): + def __init__(self, **kwargs): + self.descs = [] + self.descs.append( + SharedLayerDesc( + 'embed', EmbeddingPipe, shared_weight_attr='embedding_weight')) + self.descs.append(LayerDesc(MatmulNet)) + + self.descs.append(LayerDesc(BiasNet)) + + def _logits_helper(embedding, output): + return paddle.matmul(output[0], embedding.embedding_weight) + + self.descs.append( + SharedLayerDesc( + 'embed', + EmbeddingPipe, + forward_func=_logits_helper, + shared_weight_attr='embedding_weight')) + + super(SimpleNetPipe, self).__init__( + layers=self.descs, loss_fn=LossNet(), **kwargs) + + +class TestDistEmbeddingTraning(unittest.TestCase): + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 1 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + strategy.pipeline_configs = { + "accumulate_steps": batch_size // micro_batch_size, + "micro_batch_size": micro_batch_size + } + fleet.init(is_collective=True, strategy=strategy) + + def test_pp_model(self): + hcg = fleet.get_hybrid_communicate_group() + word_size = hcg.get_model_parallel_world_size() + dp_id = hcg.get_data_parallel_rank() + pp_id = hcg.get_stage_id() + rank_id = dist.get_rank() + set_random_seed(1024, dp_id, rank_id) + + #construct model a + model_a = SimpleNet() + scheduler_a = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04], verbose=True) + optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a, + parameters=model_a.parameters()) + + model_b = SimpleNetPipe(topology=hcg.topology()) + + scheduler_b = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04], verbose=True) + optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b, + parameters=model_b.parameters()) + model_b = fleet.distributed_model(model_b) + optimizer_b = fleet.distributed_optimizer(optimizer_b) + + param_len = len(model_a.parameters()) + + parameters = [] + for param in model_a.parameters(): + parameters.append(param.numpy()) + + model_b_params = model_b.parameters() + + if pp_id == 0: + model_b_params[0].set_value(parameters[2]) + model_b_params[1].set_value(parameters[0]) + + else: + model_b_params[0].set_value(parameters[2]) + model_b_params[1].set_value(parameters[1]) + + for step in range(5): + x1_data = np.random.randint(0, vocab_size, size=[batch_size, 1]) + x2_data = np.random.randint(0, vocab_size, size=[batch_size, 1]) + y1_data = np.random.randint(0, hidden_size, size=[batch_size, 1]) + + x1 = paddle.to_tensor(x1_data) + x2 = paddle.to_tensor(x2_data) + y1 = paddle.to_tensor(y1_data) + + x1.stop_gradient = True + x2.stop_gradient = True + y1.stop_gradient = True + + loss_a = model_a(x1, x2, y1) + loss_a.backward() + + optimizer_a.step() + optimizer_a.clear_grad() + scheduler_a.step() + + loss_b = model_b.train_batch([(x1, x2), (y1, )], optimizer_b, + scheduler_b) + + print("loss", loss_a.numpy(), loss_b.numpy()) + np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy()) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py index 1d06e168208b2..ef8ee2e4ad445 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py @@ -27,6 +27,9 @@ def test_hybrid_parallel_pp_layer(self): def test_hybrid_parallel_pp_tuple_inputs(self): self.run_mnist_2gpu('hybrid_parallel_pp_embedding.py') + def test_hybrid_parallel_pp_tuple_inputs(self): + self.run_mnist_2gpu('hybrid_parallel_shared_weight.py') + if __name__ == "__main__": unittest.main() From 67a4de68bbd15e8e464f13fe5062e88cedd50d1a Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 16 Jun 2021 11:53:19 +0800 Subject: [PATCH 15/18] Add return when input is tensor (#33570) * add return when input is tensor * fix typo --- python/paddle/tensor/creation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index dba4cc1dd8ce9..b7c55ea424c71 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -134,8 +134,9 @@ def _handle_dtype(data, dtype): ) elif isinstance(data, paddle.Tensor): data = data._copy_to(place, False) - ata = _handle_dtype(data, dtype) + data = _handle_dtype(data, dtype) data.stop_gradient = stop_gradient + return data elif isinstance(data, (core.LoDTensor, core.Tensor)): # Note(zhouwei25): should't expose it to users, just for internal use. # convert core.Tensor/core.LoDTensor to VarBase first @@ -145,6 +146,7 @@ def _handle_dtype(data, dtype): data = data._copy_to(place, False) data = _handle_dtype(data, dtype) data.stop_gradient = stop_gradient + return data else: raise TypeError( "Can't constructs a 'paddle.Tensor' with data type {}, data type must be scalar|list|tuple|numpy.ndarray|paddle.Tensor". From e6c5282e5d6f76e7e0d6c24115f4a6639fbd7637 Mon Sep 17 00:00:00 2001 From: Jiangxinz Date: Wed, 16 Jun 2021 13:58:03 +0800 Subject: [PATCH 16/18] fix used before assign (#33519) --- .../tests/unittests/mkldnn/test_fusion_gru_bf16_mkldnn_op.py | 3 ++- python/paddle/fluid/tests/unittests/test_hdfs1.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_bf16_mkldnn_op.py index fa9a93452dffd..ef26a27d05e1b 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_bf16_mkldnn_op.py @@ -99,13 +99,14 @@ def setUp(self): if self.with_bias: self.inputs['Bias'] = bias + h0_bf16 = convert_float_to_uint16(h0_fp32) + if self.with_h0: if self.weights_dtype == 'bf16': self.inputs['H0'] = h0_bf16 elif self.weights_dtype == 'fp32': self.inputs['H0'] = h0_fp32 - h0_bf16 = convert_float_to_uint16(h0_fp32) self.outputs = {'Hidden': (hidden, self.lod)} self.attrs = { diff --git a/python/paddle/fluid/tests/unittests/test_hdfs1.py b/python/paddle/fluid/tests/unittests/test_hdfs1.py index 1aac1236156ca..65d12c31e39ab 100644 --- a/python/paddle/fluid/tests/unittests/test_hdfs1.py +++ b/python/paddle/fluid/tests/unittests/test_hdfs1.py @@ -39,6 +39,7 @@ def test_timeout(self): fs.mkdirs(dst) fs.mkdirs(dst + "/" + src) output = "" + cmd = "{} -mv {} {}".format(fs._base_cmd, src, dst) try: fs.mv(src, dst, test_exists=False) self.assertFalse(1, "can't execute cmd:{} output:{}".format(cmd, @@ -46,7 +47,6 @@ def test_timeout(self): except FSTimeOut as e: print("execute mv {} to {} timeout".format(src, dst)) - cmd = "{} -mv {} {}".format(fs._base_cmd, src, dst) ret, output = fluid.core.shell_execute_cmd(cmd, 6 * 1000, 2 * 1000) self.assertNotEqual(ret, 0) print("second mv ret:{} output:{}".format(ret, output)) From 78260ff32d7f00be6ec1ecd846a84e4eacd0b596 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Wed, 16 Jun 2021 14:00:34 +0800 Subject: [PATCH 17/18] fix output_padding in conv (#33585) * fix output padding conv * add repr unittest for conv --- .../tests/unittests/test_conv2d_transpose_op.py | 13 +++++++++++++ python/paddle/nn/layer/conv.py | 12 ++++++------ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py index 4e582d74c24a2..b106f7aa9c1c8 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py @@ -18,6 +18,7 @@ import numpy as np import paddle +import paddle.nn as nn paddle.enable_static() import paddle.fluid.core as core import paddle.fluid as fluid @@ -898,5 +899,17 @@ def attr_padding_with_data_format(): self.assertRaises(ValueError, attr_padding_with_data_format) +class TestConv2DTransposeRepr(unittest.TestCase): + def test_case(self): + paddle.disable_static() + x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.) + conv = nn.Conv2DTranspose(4, 6, (3, 3), output_padding=1, stride=2) + print(conv) + y_var = conv(x_var) + y_np = y_var.numpy() + self.assertIsNotNone(y_np) + paddle.enable_static() + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/layer/conv.py b/python/paddle/nn/layer/conv.py index fc98157273447..76011aeff5b4f 100644 --- a/python/paddle/nn/layer/conv.py +++ b/python/paddle/nn/layer/conv.py @@ -98,7 +98,7 @@ def __init__(self, 'kernel_size') self._padding = padding self._padding_mode = padding_mode - self._output_padding = output_padding + self.output_padding = output_padding if dims != 1: self._updated_padding, self._padding_algorithm = _update_padding_nd( padding, channel_last, dims) @@ -163,8 +163,8 @@ def extra_repr(self): main_str += ', padding={_padding}' if self._padding_mode is not 'zeros': main_str += ', padding_mode={_padding_mode}' - if self._output_padding != 0: - main_str += ', output_padding={_output_padding}' + if self.output_padding != 0: + main_str += ', output_padding={output_padding}' if self._dilation != [1] * len(self._dilation): main_str += ', dilation={_dilation}' if self._groups != 1: @@ -508,7 +508,7 @@ def forward(self, x, output_size=None): self.weight, bias=self.bias, output_size=output_size, - output_padding=self._output_padding, + output_padding=self.output_padding, padding=self._padding, stride=self._stride, dilation=self._dilation, @@ -824,7 +824,7 @@ def __init__(self, def forward(self, x, output_size=None): if output_size is None: - output_padding = self._output_padding + output_padding = self.output_padding else: output_padding = 0 @@ -1161,7 +1161,7 @@ def __init__(self, def forward(self, x, output_size=None): if output_size is None: - output_padding = self._output_padding + output_padding = self.output_padding else: output_padding = 0 From 32e3353fc37aaa9a8e020c63e1b43b1b0d236c35 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 16 Jun 2021 14:27:19 +0800 Subject: [PATCH 18/18] [Dy2Stat] Fix always copy by paddle.to_tensor from PR #33335(#33590) --- .../dygraph_to_static/partial_program.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py index 7910e7a38558c..84bac98013ade 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py @@ -247,28 +247,27 @@ def _prepare(self, inputs): flatten_inputs = flatten(inputs) # Convert variable into VarBase and feed in training data. input_vars = [] + expected_place = framework._current_expected_place() for i, value in enumerate(flatten_inputs): if isinstance(value, np.ndarray): var = core.VarBase( value=value, name=self._inputs[i].desc.name(), persistable=False, - place=framework._current_expected_place(), + place=expected_place, zero_copy=True) elif isinstance(value, core.VarBase): - value.name = self._inputs[i].desc.name() - if value.stop_gradient: - # NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times - # into CUDAPlace when it's as input of multi Ops. so we move it in advance - # to avoid this problem. - var = paddle.to_tensor( - value, - dtype=value.dtype, - place=framework._current_expected_place(), - stop_gradient=True) + # NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times + # into CUDAPlace when it's as input of multi Ops. so we move it in advance + # to avoid this problem. + if value.stop_gradient and not value.place._equals( + expected_place): + var = value._copy_to(expected_place, False) + var.stop_gradient = True var.name = value.name else: var = value + var.name = self._inputs[i].desc.name() else: continue input_vars.append(var)