diff --git a/paddle/fluid/operators/share_data_op.cc b/paddle/fluid/operators/share_data_op.cc new file mode 100644 index 0000000000000..6fcc29e900261 --- /dev/null +++ b/paddle/fluid/operators/share_data_op.cc @@ -0,0 +1,72 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/share_data_op.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class ShareDataOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ShareData"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ShareData"); + auto in_type = ctx->GetInputsVarType("X")[0]; + auto out_type = ctx->GetOutputsVarType("Out")[0]; + + PADDLE_ENFORCE_EQ( + in_type == framework::proto::VarType::LOD_TENSOR || + in_type == framework::proto::VarType::SELECTED_ROWS, + true, platform::errors::InvalidArgument( + "Type of Variable[X] must be LoDTensor or SelectedRows!")); + PADDLE_ENFORCE_EQ( + in_type, out_type, + platform::errors::InvalidArgument( + "The type of input (X) and output (Out) are inconsistent.")); + + ctx->ShareDim("X", "Out"); + } +}; + +class ShareDataOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of share_data op"); + AddOutput("Out", "(Tensor), The output tensor of share_data op"); + AddComment(R"DOC( +ShareData Operator. + +Return a tensor $Out$ that shares data with the input tensor $X$ and without tensor copy. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + share_data, ops::ShareDataOp, ops::ShareDataOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(share_data, ops::ShareDataKernel, + ops::ShareDataKernel, ops::ShareDataKernel, + ops::ShareDataKernel, + ops::ShareDataKernel, + ops::ShareDataKernel, + ops::ShareDataKernel, + ops::ShareDataKernel) diff --git a/paddle/fluid/operators/share_data_op.cu b/paddle/fluid/operators/share_data_op.cu new file mode 100644 index 0000000000000..20cdaafa43de7 --- /dev/null +++ b/paddle/fluid/operators/share_data_op.cu @@ -0,0 +1,25 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/share_data_op.h" + +REGISTER_OP_CUDA_KERNEL( + share_data, paddle::operators::ShareDataKernel, + paddle::operators::ShareDataKernel, + paddle::operators::ShareDataKernel, + paddle::operators::ShareDataKernel, + paddle::operators::ShareDataKernel, + paddle::operators::ShareDataKernel, + paddle::operators::ShareDataKernel, + paddle::operators::ShareDataKernel); diff --git a/paddle/fluid/operators/share_data_op.h b/paddle/fluid/operators/share_data_op.h new file mode 100644 index 0000000000000..d876b4fabd5c0 --- /dev/null +++ b/paddle/fluid/operators/share_data_op.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class ShareDataKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *in_var = ctx.InputVar("X"); + auto *out_var = ctx.OutputVar("Out"); + if (in_var->IsType()) { + const auto &origin_tensor = in_var->Get(); + auto *detach_tensor = out_var->GetMutable(); + detach_tensor->ShareDataWith(origin_tensor); + } else { + const auto &origin_selected_rows = in_var->Get(); + auto *detach_selected_rows = + out_var->GetMutable(); + detach_selected_rows->mutable_value()->ShareDataWith( + origin_selected_rows.value()); + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 80c27c585d810..9e06f107a3758 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -947,35 +947,43 @@ def __init__(self, self._stop_gradient = stop_gradient self.is_data = is_data - @fake_interface_only def detach(self): """ - **Notes**: - **This API is ONLY available in Dygraph mode** - Returns a new Variable, detached from the current graph. + It will share data with origin Variable and without tensor copy. + In addition, the detached Variable doesn't provide gradient propagation. Returns: ( :ref:`api_guide_Variable_en` | dtype is same as current Variable): The detached Variable. - Examples: .. code-block:: python - import paddle.fluid as fluid - from paddle.fluid.dygraph.base import to_variable - from paddle.fluid.dygraph import Linear - import numpy as np + import paddle - data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32') - with fluid.dygraph.guard(): - linear = Linear(32, 64) - data = to_variable(data) - x = linear(data) - y = x.detach() + paddle.enable_static() + + # create a static Variable + x = paddle.static.data(name='x', shape=[3, 2, 1]) + # create a detached Variable + y = x.detach() """ - pass + + assert self.type == core.VarDesc.VarType.SELECTED_ROWS or \ + self.type == core.VarDesc.VarType.LOD_TENSOR, \ + "only support a variable with SELECTED_ROWS or LOD_TENSOR to be detached" + + output = self.block.create_var( + name=unique_name.generate_with_ignorable_key("detach_" + self.name), + dtype=self.dtype, + type=self.type, + persistable=self.persistable, + stop_gradient=True) + + self.block.append_op( + type='share_data', inputs={'X': [self]}, outputs={'Out': [output]}) + return output @fake_interface_only def numpy(self): @@ -1810,6 +1818,35 @@ def set_value(self, value, scope=None): t.set(value, place) + def size(self): + """ + Returns the number of elements for current Variable, which is a int64 Variable with shape [1] + + Returns: + Variable: the number of elements for current Variable + + Examples: + .. code-block:: python + + import paddle + + paddle.enable_static() + + # create a static Variable + x = paddle.static.data(name='x', shape=[3, 2, 1]) + + # get the number of elements of the Variable + y = x.size() + """ + + output = self.block.create_var( + name=unique_name.generate_with_ignorable_key(self.name + "_size"), + dtype=core.VarDesc.VarType.INT64) + + self.block.append_op( + type='size', inputs={'Input': [self]}, outputs={'Out': [output]}) + return output + def get_all_op_protos(): """ diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index 9433e0e5ee0e5..feb723d9c8b43 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -45,6 +45,7 @@ "__rpow__": "A **= B", "__floordiv__": "A //B", "__mod__": "A % B", + "__matmul__": "A @ B", "__eq__": "A == B", "__ne__": "A != B", "__lt__": "A < B", @@ -195,6 +196,28 @@ def _scalar_op_(var, scale, bias): def _neg_(var): return _scalar_op_(var, -1.0, 0.0) + @property + def _ndim_(self): + """ + Returns the dimension of current Variable + + Returns: + the dimension + + Examples: + .. code-block:: python + + import paddle + + paddle.enable_static() + + # create a static Variable + x = paddle.static.data(name='x', shape=[3, 2, 1]) + # print the dimension of the Variable + print(x.ndim) + """ + return len(self.shape) + def _scalar_add_(var, value): return _scalar_op_(var, 1.0, value) @@ -228,9 +251,9 @@ def __impl__(self, other_var): other_var = float(other_var) # division is a special case # NOTE(chenweihang): because we cast tensor to float32 instead float64, - # the division result can only guarantee the numerical accuracy of 6 digits - # after the decimal point. The result of numpy calculation is of float64 type, - # so the calculation result here and the calculation result of numpy are + # the division result can only guarantee the numerical accuracy of 6 digits + # after the decimal point. The result of numpy calculation is of float64 type, + # so the calculation result here and the calculation result of numpy are # different after 6 decimal point. If necessary, we can also use float64 here. # torch's behavior here is consistent with ours if op_type == 'elementwise_div' and self.dtype in _supported_int_dtype_: @@ -238,7 +261,7 @@ def __impl__(self, other_var): # here use `scale` replace `elementwise` to get better performance # but only +, -, * can use this method # NOTE(chentianyu03): / can not use `scale` method,because the result of - # `scale` method (self*(1/other_var)) do not exactly equal with the result + # `scale` method (self*(1/other_var)) do not exactly equal with the result # of `elementwise_div` method. if scalar_method is not None: return scalar_method(self, other_var) @@ -321,6 +344,9 @@ def __impl__(self, other_var): # b=-a ('__neg__', _neg_), ('astype', astype), + ('dim', lambda x: len(x.shape)), + ('ndimension', lambda x: len(x.shape)), + ('ndim', _ndim_), ('__add__', _binary_creator_('__add__', 'elementwise_add', False, _scalar_add_)), # a+b == b+a. Do not need to reverse explicitly @@ -347,6 +373,8 @@ def __impl__(self, other_var): 'elementwise_floordiv', False, None)), ('__mod__', _binary_creator_('__mod__', 'elementwise_mod', False, None)), + ('__matmul__', _binary_creator_('__matmul__', "matmul_v2", False, + None)), # for logical compare ('__eq__', _binary_creator_('__eq__', 'equal', False, None)), ('__ne__', _binary_creator_('__ne__', 'not_equal', False, None)), diff --git a/python/paddle/fluid/tests/unittests/test_detach.py b/python/paddle/fluid/tests/unittests/test_detach.py index 5a31418205c32..8d19a1d3f65cd 100644 --- a/python/paddle/fluid/tests/unittests/test_detach.py +++ b/python/paddle/fluid/tests/unittests/test_detach.py @@ -149,12 +149,6 @@ def test_NoDetachSingle_DetachMulti(self): array_detach_multi = self.detach_multi() assert np.array_equal(array_no_detach_single, array_detach_multi) - def test_detach_exception(self): - x = fluid.layers.data(name="a", shape=[3, 4], dtype='float32') - y = fluid.layers.fc(input=x, size=10, bias_attr=True) - with self.assertRaises(AssertionError): - y_detach = y.detach() - class TestInplace(unittest.TestCase): def test_forward_version(self): diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch.py b/python/paddle/fluid/tests/unittests/test_math_op_patch.py index b2afda9ed3f25..cef5adbc5d3e3 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch.py @@ -271,7 +271,6 @@ def test_astype(self): fetch_list=[b]) self.assertTrue(numpy.allclose(a_np.astype('float32'), b_np)) - @prog_scope() def test_bitwise_and(self): x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") @@ -336,6 +335,28 @@ def test_bitwise_not(self): fetch_list=[z]) self.assertTrue(np.array_equal(out[0], out_np)) + @prog_scope() + def test_ndim(self): + a = paddle.static.data(name="a", shape=[10, 1]) + self.assertEqual(a.dim(), 2) + self.assertEqual(a.ndimension(), 2) + self.assertEqual(a.ndim, 2) + + @prog_scope() + def test_matmul(self): + a = paddle.static.data(name='a', shape=[2, 3], dtype='float32') + b = paddle.static.data(name='b', shape=[3, 5], dtype='float32') + c = a @b # __matmul__ + a_np = numpy.random.uniform(-1, 1, size=[2, 3]).astype('float32') + b_np = numpy.random.uniform(-1, 1, size=[3, 5]).astype('float32') + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + c_np = exe.run(paddle.static.default_main_program(), + feed={"a": a_np, + "b": b_np}, + fetch_list=[c]) + self.assertTrue(numpy.allclose(a_np @b_np, c_np)) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_share_data_op.py b/python/paddle/fluid/tests/unittests/test_share_data_op.py new file mode 100644 index 0000000000000..1e6f0ef693c3d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_share_data_op.py @@ -0,0 +1,87 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest +from paddle.fluid import core +from paddle.fluid.op import Operator + + +class TestShareDataOp(OpTest): + def setUp(self): + self.op_type = "share_data" + input = np.random.rand(2, 3, 5).astype("float32") + self.inputs = {'X': input} + self.outputs = {'Out': input} + + def test_check_output(self): + self.check_output() + + +class TestShareDataOpOnDifferentPlaces(unittest.TestCase): + def get_places(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + return places + + def check_with_tensor(self, place): + scope = core.Scope() + np_array = np.random.rand(2, 3, 5).astype("float32") + + # initialize input and output variable + x = scope.var('X').get_tensor() + x.set(np_array, place) + out = scope.var("Out").get_tensor() + + op = Operator("share_data", X="X", Out="Out") + op.run(scope, place) + self.assertTrue(np.allclose(np_array, out)) + + def check_with_selected_rows(self, place): + scope = core.Scope() + x_rows = [0, 1, 5, 4, 19] + x_height = 20 + row_numel = 2 + np_array = np.ones((len(x_rows), row_numel)).astype("float32") + + # initialize input variable + x = scope.var('X').get_selected_rows() + x.set_rows(x_rows) + x.set_height(x_height) + x_tensor = x.get_tensor() + x_tensor.set(np_array, place) + + # initialize the Out variable + out = scope.var("Out").get_selected_rows() + out_tensor = out.get_tensor() + + op = Operator("share_data", X="X", Out="Out") + op.run(scope, place) + + out_height = out.height() + out_rows = out.rows() + self.assertTrue(np.allclose(np_array, out_tensor)) + self.assertEqual(x_height, out_height) + self.assertEqual(x_rows, out_rows) + + def test_check_output(self): + for place in self.get_places(): + self.check_with_selected_rows(place) + self.check_with_tensor(place) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index c1956545f55ad..a998d58fdbc60 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -305,7 +305,6 @@ def test_fake_interface_only_api(self): b = default_main_program().current_block() var = b.create_var(dtype="float64", lod_level=0) with fluid.dygraph.guard(): - self.assertRaises(AssertionError, var.detach) self.assertRaises(AssertionError, var.numpy) self.assertRaises(AssertionError, var.backward) self.assertRaises(AssertionError, var.gradient) @@ -345,6 +344,60 @@ def _test(): self.assertRaises(Exception, _test) + def test_size(self): + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + x = paddle.assign(np.random.rand(2, 3, 4).astype("float32")) + exe = paddle.static.Executor(fluid.CPUPlace()) + exe.run(paddle.static.default_startup_program()) + + output = exe.run(prog, fetch_list=[x.size()]) + self.assertEqual(output[0], [24]) + + def test_detach(self): + b = default_main_program().current_block() + x = b.create_var(shape=[2, 3, 5], dtype="float64", lod_level=0) + detach_x = x.detach() + self.assertEqual(x.persistable, detach_x.persistable) + self.assertEqual(x.shape, detach_x.shape) + self.assertEqual(x.dtype, detach_x.dtype) + self.assertEqual(x.type, detach_x.type) + self.assertTrue(detach_x.stop_gradient) + + xx = b.create_var(name='xx', type=core.VarDesc.VarType.STEP_SCOPES) + self.assertRaises(AssertionError, xx.detach) + + startup = paddle.static.Program() + main = paddle.static.Program() + scope = fluid.core.Scope() + with paddle.static.scope_guard(scope): + with paddle.static.program_guard(main, startup): + x = paddle.static.data( + name='x', shape=[3, 2, 1], dtype='float32') + x.persistable = True + feed_data = np.ones(shape=[3, 2, 1], dtype=np.float32) + detach_x = x.detach() + exe = paddle.static.Executor(paddle.CPUPlace()) + exe.run(startup) + result = exe.run(main, + feed={'x': feed_data}, + fetch_list=[x, detach_x]) + self.assertTrue((result[1] == feed_data).all()) + self.assertTrue((result[0] == result[1]).all()) + + modified_value = np.zeros(shape=[3, 2, 1], dtype=np.float32) + detach_x.set_value(modified_value, scope) + result = exe.run(main, fetch_list=[x, detach_x]) + self.assertTrue((result[1] == modified_value).all()) + self.assertTrue((result[0] == result[1]).all()) + + modified_value = np.random.uniform( + -1, 1, size=[3, 2, 1]).astype('float32') + x.set_value(modified_value, scope) + result = exe.run(main, fetch_list=[x, detach_x]) + self.assertTrue((result[1] == modified_value).all()) + self.assertTrue((result[0] == result[1]).all()) + class TestVariableSlice(unittest.TestCase): def _test_item_none(self, place): diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index ccc6fcefdfb9f..7b38f39976097 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -468,6 +468,7 @@ 'test_sign_op', 'test_similarity_focus_op', 'test_size_op', + 'test_share_data_op', 'test_smooth_l1_loss', 'test_smooth_l1_loss_op', 'test_softmax_with_cross_entropy_op',