From 29b2698b65cbeee815311510853011a1ba9c615a Mon Sep 17 00:00:00 2001 From: esythan Date: Tue, 14 Dec 2021 10:08:03 +0000 Subject: [PATCH 1/6] add nansum api --- python/paddle/__init__.py | 1 + .../fluid/tests/unittests/test_nansum_api.py | 101 ++++++++++++++++++ python/paddle/tensor/__init__.py | 1 + python/paddle/tensor/math.py | 63 +++++++++++ 4 files changed, 166 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_nansum_api.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 661cd495b53e8..3eac2a2bb8e40 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -189,6 +189,7 @@ from .tensor.math import square # noqa: F401 from .tensor.math import stanh # noqa: F401 from .tensor.math import sum # noqa: F401 +from .tensor.math import nansum # noqa: F401 from .tensor.math import tanh # noqa: F401 from .tensor.math import tanh_ # noqa: F401 from .tensor.math import add_n # noqa: F401 diff --git a/python/paddle/fluid/tests/unittests/test_nansum_api.py b/python/paddle/fluid/tests/unittests/test_nansum_api.py new file mode 100644 index 0000000000000..a9fc285d2d9d0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_nansum_api.py @@ -0,0 +1,101 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid import Program, program_guard + + +class API_Test_Nansum(unittest.TestCase): + def test_static_graph(self): + paddle.enable_static() + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + input = fluid.data(name='input', dtype='float32', shape=[2, 4]) + out1 = paddle.nansum(input) + out2 = paddle.nansum(input, axis=0) + out3 = paddle.nansum(input, axis=-1) + out4 = paddle.nansum(input, axis=1, keepdim=True) + place = fluid.CPUPlace() + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(startup_program) + + x = np.array([[float('nan'), 3, 5, 9], + [1, 2, float('-nan'), 7]]).astype(np.float32) + res = exe.run(train_program, + feed={'input': x}, + fetch_list=[out1, out2, out3, out4]) + + out1_np = np.array(res[0]) + out2_np = np.array(res[1]) + out3_np = np.array(res[2]) + out4_np = np.array(res[3]) + out1_ref = np.array([27]).astype(np.float32) + out2_ref = np.array([1, 5, 5, 16]).astype(np.float32) + out3_ref = np.array([17, 10]).astype(np.float32) + out4_ref = np.array([[17], [10]]).astype(np.float32) + + self.assertTrue( + (out1_np == out1_ref).all(), + msg='nansum output is wrong, out =' + str(out1_np)) + self.assertTrue( + (out2_np == out2_ref).all(), + msg='nansum output is wrong, out =' + str(out2_np)) + self.assertTrue( + (out3_np == out3_ref).all(), + msg='nansum output is wrong, out =' + str(out3_np)) + self.assertTrue( + (out4_np == out4_ref).all(), + msg='nansum output is wrong, out =' + str(out4_np)) + + def test_error_api(self): + paddle.enable_static() + + ## input dtype error + def run1(): + input = fluid.data(name='input', dtype='float16', shape=[2, 3]) + output = paddle.nansum(input) + + self.assertRaises(TypeError, run1) + + ## axis type error + def run2(): + input = fluid.data(name='input', dtype='float16', shape=[2, 3]) + output = paddle.nansum(input, axis=1.2) + + self.assertRaises(TypeError, run2) + + def test_dygraph(self): + x = np.array([[float('nan'), 3, 5, 9], + [1, 2, float('-nan'), 7]]).astype(np.float32) + with fluid.dygraph.guard(): + inputs = fluid.dygraph.to_variable(x) + out = paddle.nansum(inputs) + out_ref = np.array([27]).astype(np.float32) + + self.assertTrue( + (out.numpy() == out_ref).all(), + msg='nansum output is wrong, out =' + str(out.numpy())) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 793fdb89d0677..1d75daa07cf9b 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -147,6 +147,7 @@ from .math import square # noqa: F401 from .math import stanh # noqa: F401 from .math import sum # noqa: F401 +from .math import nansum # noqa: F401 from .math import tanh # noqa: F401 from .math import tanh_ # noqa: F401 from .math import add_n # noqa: F401 diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 36d61fa08546b..5aa89e1549143 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -780,6 +780,69 @@ def get_dtype(x, dtype): return out +def nansum(x, axis=None, dtype=None, keepdim=False, name=None): + """ + Computes the sum of tensor elements over the given dimension, treating Not a Numbers (NaNs) as zero. + + Args: + x (Tensor): An N-D Tensor, the data type is float32, float64, int32 or int64. + axis (int|list|tuple, optional): The dimensions along which the nansum is performed. If + :attr:`None`, nansum all elements of :attr:`x` and return a + Tensor with a single element, otherwise must be in the + range :math:`[-rank(x), rank(x))`. If :math:`axis[i] < 0`, + the dimension to reduce is :math:`rank + axis[i]`. + dtype (str, optional): The dtype of output Tensor. The default value is None, the dtype + of output is the same as input Tensor `x`. + keepdim (bool, optional): Whether to reserve the reduced dimension in the + output Tensor. The result Tensor will have one fewer dimension + than the :attr:`x` unless :attr:`keepdim` is true, default + value is False. + name (str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name` + + Returns: + Tensor: Results of summation operation on the specified axis of input Tensor `x`, + if `x.dtype='int32'`, it's data type is `'int64'`, + otherwise it's data type is the same as `x`. + + Raises: + TypeError: The type of :attr:`axis` must be int, list or tuple. + + Examples: + .. code-block:: python + + import paddle + + # x is a Tensor with following elements: + # [[nan, 0.3, 0.5, 0.9] + # [0.1, 0.2, -nan, 0.7]] + # Each example is followed by the corresponding output tensor. + x = paddle.to_tensor([[nan, 0.3, 0.5, 0.9], + [0.1, 0.2, -nan, 0.7]]) + out1 = paddle.nansum(x) # [2.7] + out2 = paddle.nansum(x, axis=0) # [0.1, 0.5, 0.5, 1.6] + out3 = paddle.nansum(x, axis=-1) # [1.7, 1.0] + out4 = paddle.nansum(x, axis=1, keepdim=True) # [[1.7], [1.0]] + + # y is a Tensor with shape [2, 2, 2] and elements as below: + # [[[1, nan], [3, 4]], + # [[5, 6], [-nan, 8]]] + # Each example is followed by the corresponding output tensor. + y = paddle.to_tensor([[[1, nan], [3, 4]], + [[5, 6], [-nan, 8]]]) + out5 = paddle.nansum(y, axis=[1, 2]) # [8, 19] + out6 = paddle.nansum(y, axis=[0, 1]) # [9, 18] + """ + check_variable_and_dtype( + x, 'x', ['float32', 'float64', 'int32', 'int64'], 'nansum') + check_type(axis, 'axis', (int, list, tuple, type(None)), 'nansum') + + helper = LayerHelper('nansum', **locals()) + zero_tensor = paddle.zeros_like(x) + tmp_tensor = paddle.where(isnan(x), zero_tensor, x) + return sum(tmp_tensor, axis, dtype, keepdim, name) + + @templatedoc(op_type="sum") def add_n(inputs, name=None): """ From 111d41855d6db5ba6b1f0d10640097e460d8829c Mon Sep 17 00:00:00 2001 From: esythan Date: Thu, 16 Dec 2021 07:15:06 +0000 Subject: [PATCH 2/6] delete layerhelper --- python/paddle/tensor/math.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 638a5a248005c..5c2033543814c 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -837,7 +837,6 @@ def nansum(x, axis=None, dtype=None, keepdim=False, name=None): x, 'x', ['float32', 'float64', 'int32', 'int64'], 'nansum') check_type(axis, 'axis', (int, list, tuple, type(None)), 'nansum') - helper = LayerHelper('nansum', **locals()) zero_tensor = paddle.zeros_like(x) tmp_tensor = paddle.where(isnan(x), zero_tensor, x) return sum(tmp_tensor, axis, dtype, keepdim, name) From 281769e346c80cb16c9d890a08e3c101994a9261 Mon Sep 17 00:00:00 2001 From: esythan Date: Mon, 20 Dec 2021 12:48:43 +0000 Subject: [PATCH 3/6] add nansum to all and tensor_method_func --- python/paddle/__init__.py | 1 + python/paddle/tensor/__init__.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index c1036eda1522a..84f6efbd0ab76 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -500,6 +500,7 @@ 'ones', 'not_equal', 'sum', + 'nansum', 'tile', 'greater_equal', 'isfinite', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index ab6fd2a088140..9a4adac2901ee 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -299,6 +299,7 @@ 'square', 'stanh', 'sum', + 'nansum', 'tanh', 'tanh_', 'add_n', From 3d074e5d3e7a96ea7186e814cfd0a59869e3cad4 Mon Sep 17 00:00:00 2001 From: esythan Date: Tue, 21 Dec 2021 03:23:48 +0000 Subject: [PATCH 4/6] update doc --- python/paddle/tensor/math.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 5c2033543814c..7d89ae9d176a0 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -812,13 +812,15 @@ def nansum(x, axis=None, dtype=None, keepdim=False, name=None): .. code-block:: python import paddle + import numpy as np # x is a Tensor with following elements: # [[nan, 0.3, 0.5, 0.9] # [0.1, 0.2, -nan, 0.7]] # Each example is followed by the corresponding output tensor. - x = paddle.to_tensor([[nan, 0.3, 0.5, 0.9], - [0.1, 0.2, -nan, 0.7]]) + x = np.array([[float('nan'), 0.3, 0.5, 0.9], + [0.1, 0.2, float('-nan'), 0.7]]).astype(np.float32) + x = paddle.to_tensor(x) out1 = paddle.nansum(x) # [2.7] out2 = paddle.nansum(x, axis=0) # [0.1, 0.5, 0.5, 1.6] out3 = paddle.nansum(x, axis=-1) # [1.7, 1.0] @@ -828,8 +830,9 @@ def nansum(x, axis=None, dtype=None, keepdim=False, name=None): # [[[1, nan], [3, 4]], # [[5, 6], [-nan, 8]]] # Each example is followed by the corresponding output tensor. - y = paddle.to_tensor([[[1, nan], [3, 4]], - [[5, 6], [-nan, 8]]]) + y = np.array([[[1, float('nan')], [3, 4]], + [[5, 6], [float('-nan'), 8]]]) + y = paddle.to_tensor(y) out5 = paddle.nansum(y, axis=[1, 2]) # [8, 19] out6 = paddle.nansum(y, axis=[0, 1]) # [9, 18] """ From 8702dc08e349a482a3f4383f74b1aa66e42184ee Mon Sep 17 00:00:00 2001 From: esythan Date: Wed, 22 Dec 2021 03:04:50 +0000 Subject: [PATCH 5/6] update doc --- python/paddle/tensor/math.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 7d89ae9d176a0..522bb067b9593 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -782,7 +782,7 @@ def get_dtype(x, dtype): def nansum(x, axis=None, dtype=None, keepdim=False, name=None): """ - Computes the sum of tensor elements over the given dimension, treating Not a Numbers (NaNs) as zero. + Computes the sum of tensor elements over the given axis, treating Not a Numbers (NaNs) as zero. Args: x (Tensor): An N-D Tensor, the data type is float32, float64, int32 or int64. @@ -802,8 +802,6 @@ def nansum(x, axis=None, dtype=None, keepdim=False, name=None): Returns: Tensor: Results of summation operation on the specified axis of input Tensor `x`, - if `x.dtype='int32'`, it's data type is `'int64'`, - otherwise it's data type is the same as `x`. Raises: TypeError: The type of :attr:`axis` must be int, list or tuple. From e50f7243f0c256f105cb9a928a1d020f426ec7f6 Mon Sep 17 00:00:00 2001 From: esythan Date: Wed, 22 Dec 2021 08:38:47 +0000 Subject: [PATCH 6/6] update doc --- python/paddle/tensor/math.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 522bb067b9593..626a29def1379 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -803,9 +803,6 @@ def nansum(x, axis=None, dtype=None, keepdim=False, name=None): Returns: Tensor: Results of summation operation on the specified axis of input Tensor `x`, - Raises: - TypeError: The type of :attr:`axis` must be int, list or tuple. - Examples: .. code-block:: python