5 changes: 5 additions & 0 deletions paddle/phi/ops/yaml/python_api_info.yaml
@@ -197,6 +197,11 @@
args_mapper :
func : ArgSumMapper

- op : tanh
name : [paddle.tanh, paddle.Tensor.tanh, paddle.nn.functional.tanh]
args_alias:
use_default_mapping : True

- op : exp
name : [paddle.exp, paddle.Tensor.exp]
args_alias:
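With `use_default_mapping : True`, the standard alias mapping (e.g. ``input`` as an alias for ``x``) is generated for every API name listed under the op. A minimal sketch of the resulting call-site behavior, assuming the alias machinery this PR hooks into:

```python
import paddle

x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])

# The positional argument and its generated ``input`` alias should be interchangeable.
assert bool((paddle.tanh(x) == paddle.tanh(input=x)).all())
```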
60 changes: 55 additions & 5 deletions python/paddle/_paddle_docs.py
@@ -2314,6 +2314,47 @@ def dot(
""",
)

add_doc_and_signature(
"tanh",
r"""
Tanh Activation Operator.

.. math::
    out = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}

.. note::
    Alias Support:
    1. The parameter name ``input`` can be used as an alias for ``x``.

Args:
    x (Tensor): Input of Tanh operator, an N-D Tensor, with data type bfloat16, float32, float64,
        float16, uint8, int8, int16, int32 or int64. Alias: ``input``.
    name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
    out (Tensor|None, optional): The output tensor. Default: None.

Returns:
    Output of Tanh operator, a Tensor with the same data type and shape as the input
    (integer inputs are automatically cast to float32).

Examples:
    .. code-block:: python

        >>> import paddle
        >>> x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
        >>> out = paddle.tanh(x)
        >>> out
        Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
        [-0.37994900, -0.19737528,  0.09966799,  0.29131261])
""",
"""
def tanh(
x: Tensor, *, out: Tensor | None = None, name: str | None = None,
) -> Tensor
""",
)
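A short usage sketch for the new keyword-only ``out`` parameter, assuming it follows the same out-variant semantics as the other compatibility APIs in this PR (result written into a preallocated tensor):

```python
import paddle

x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
buf = paddle.empty_like(x)
paddle.tanh(x, out=buf)   # tanh(x) is written into ``buf``
print(buf)
```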

add_doc_and_signature(
"exp",
"""
@@ -2331,7 +2372,7 @@
x (Tensor): Input of Exp operator, an N-D Tensor, with data type int32, int64, bfloat16, float16, float32, float64, complex64 or complex128.
Alias: ``input``.
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
-    out (Tensor|None, optional): The output tensor.
+    out (Tensor|None, optional): The output tensor. Default: None.
Returns:
Tensor. Output of Exp operator, a Tensor with the same shape as the input.
@@ -2371,7 +2412,7 @@ def exp(
x (Tensor): Input of Expm1 operator, an N-D Tensor, with data type int32, int64, bfloat16, float16, float32, float64, complex64 or complex128.
Alias: ``input``.
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
-    out (Tensor|None, optional): The output tensor.
+    out (Tensor|None, optional): The output tensor. Default: None.
Returns:
Tensor. Output of Expm1 operator, a Tensor with the same shape as the input.
@@ -2494,10 +2535,16 @@ def diagonal(
out.shape = [4]
out.data = [1., -1., 3., 1.]
.. note::
Alias Support:
1. The parameter name ``input`` can be used as an alias for ``x``.
Args:
x (Tensor): Input of Round operator, an N-D Tensor, with data type bfloat16, int32, int64, float32, float64, float16, complex64 or complex128.
Alias: ``input``.
decimals (int, optional): Number of decimal places to round to. Default: 0.
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
out (Tensor|None, optional): The output tensor. Default: None.
Returns:
Tensor. Output of Round operator, a Tensor with the same shape as the input.
@@ -2529,12 +2576,15 @@ def round(
out = |x|
.. note::
Alias Support:
1. The parameter name ``input`` can be used as an alias for ``x``.
Args:
x (Tensor): The input Tensor with data type int32, int64, float16, float32, float64, complex64 and complex128.
Alias: ``input``.
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Keyword args:
-    out (Tensor|None, optional): The output tensor.
+    out (Tensor|None, optional): The output tensor. Default: None.
Returns:
Tensor. A Tensor with the same data type and shape as :math:`x`.
55 changes: 1 addition & 54 deletions python/paddle/tensor/math.py
@@ -40,6 +40,7 @@
sign,
sin,
sum,
tanh,
)
from paddle.base.libpaddle import DataType
from paddle.common_ops_import import VarDesc, dygraph_utils
@@ -4459,60 +4460,6 @@ def prod(
return out


def tanh(x: Tensor, name: str | None = None) -> Tensor:
r"""
Tanh Activation Operator.
.. math::
out = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}
Args:
x (Tensor): Input of Tanh operator, an N-D Tensor, with data type bfloat16, float32, float64,
float16, uint8, int8, int16, int32, int64.
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Output of Tanh operator, a Tensor with same data type and shape as input
(integer types are autocasted into float32).
Examples:
.. code-block:: python
>>> import paddle
>>> x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
>>> out = paddle.tanh(x)
>>> out
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
[-0.37994900, -0.19737528, 0.09966799, 0.29131261])
"""
if in_dynamic_or_pir_mode():
return _C_ops.tanh(x)
else:
check_variable_and_dtype(
x,
'x',
[
'uint16',
'float16',
'float32',
'float64',
'uint8',
'int8',
'int16',
'int32',
'int64',
],
'tanh',
)
check_type(x, 'x', (Variable), 'tanh')
helper = LayerHelper('tanh', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(type='tanh', inputs={'X': x}, outputs={'Out': out})
return out


@inplace_apis_in_dygraph_only
def tanh_(x: Tensor, name: str | None = None) -> Tensor:
r"""
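With the hand-written wrapper deleted from `python/paddle/tensor/math.py`, `paddle.tanh` is expected to resolve to the generated binding declared in `python_api_info.yaml`. A quick sanity sketch of the three registered names, assuming the generated API preserves the old behavior:

```python
import paddle

x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
a = paddle.tanh(x)                  # free function
b = x.tanh()                        # Tensor method alias
c = paddle.nn.functional.tanh(x)    # functional alias
assert bool((a == b).all()) and bool((a == c).all())
```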
1 change: 1 addition & 0 deletions test/legacy_test/test_activation_op.py
@@ -6214,6 +6214,7 @@ class TestActivationAPI_Compatibility(unittest.TestCase):
("paddle.exp", np.exp, {'min_val': -1.0, 'max_val': 1.0}),
("paddle.expm1", np.expm1, {'min_val': -1.0, 'max_val': 1.0}),
("paddle.round", np.round, {'min_val': -5.0, 'max_val': 5.0}),
("paddle.tanh", np.tanh, {'min_val': -1.0, 'max_val': 1.0}),
]
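Each tuple drives a generic compatibility check against a NumPy reference over the configured input range. A minimal standalone version of what such a check does (a hypothetical helper, not the actual test body):

```python
import numpy as np
import paddle

def check_against_numpy(api, np_ref, min_val, max_val):
    # Sample inputs in the configured range and compare with the NumPy reference.
    np_x = np.random.uniform(min_val, max_val, [4]).astype('float32')
    out = api(paddle.to_tensor(np_x))
    np.testing.assert_allclose(out.numpy(), np_ref(np_x), rtol=1e-6, atol=1e-6)

check_against_numpy(paddle.tanh, np.tanh, -1.0, 1.0)
```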

def setUp(self):
6 changes: 3 additions & 3 deletions test/legacy_test/test_weight_decay.py
@@ -62,9 +62,9 @@ def bow_net(
bow = paddle.static.nn.sequence_lod.sequence_pool(
input=emb, pool_type='sum'
)
- bow_tanh = paddle.tanh(bow)
- fc_1 = paddle.static.nn.fc(x=bow_tanh, size=hid_dim, activation="tanh")
- fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim2, activation="tanh")
+ bow_silu = paddle.nn.functional.silu(bow)
+ fc_1 = paddle.static.nn.fc(x=bow_silu, size=hid_dim, activation="silu")
+ fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim2, activation="silu")
prediction = paddle.static.nn.fc(
x=[fc_2], size=class_dim, activation="softmax"
)
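For context on the swap: SiLU is defined as x·sigmoid(x), so the replacement keeps a smooth, tanh-like nonlinearity. A quick numerical check of that identity:

```python
import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([-1.0, 0.0, 2.0])
ref = x * F.sigmoid(x)              # silu(x) = x * sigmoid(x)
assert bool(paddle.allclose(F.silu(x), ref))
```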
12 changes: 6 additions & 6 deletions test/standalone_executor/test_standalone_custom_event.py
@@ -40,12 +40,12 @@ def build_program():
matmul_out = data @ weight
bias = paddle.ones([1024, 2048], dtype='float32', name='bias')
add_out = paddle.add(matmul_out, bias, name='add_out')
- # add_out -> [sub] -> sub_out -> [tanh] -> tanh_out
+ # add_out -> [sub] -> sub_out -> [silu] -> silu_out
sub_out = paddle.subtract(add_out, data, name='sub_out')
- tanh_out = paddle.tanh(sub_out, name='tanh_out')
+ silu_out = paddle.nn.functional.silu(sub_out, name='silu_out')
bias_1 = paddle.add(bias, sub_out, name='bias_1')
- out_before = paddle.tanh(bias_1, name='out_before')
- out_last = paddle.subtract(tanh_out, data, name='out_last')
+ out_before = paddle.nn.functional.silu(bias_1, name='out_before')
+ out_last = paddle.subtract(silu_out, data, name='out_last')
out_last2 = out_last @ weight

out = paddle.add(out_before, out_last2, name='out')
@@ -64,9 +64,9 @@ class TestManualEvent(unittest.TestCase):
| | | |
| elementwise_sub(s1) |
| | | |
- | tanh(s1) elementwise_add(s1)
+ | silu(s1) elementwise_add(s1)
| | |
- elementwise_sub(s1) tanh(s1)
+ elementwise_sub(s1) silu(s1)
| |
matmul_v2(s1) |
| | ---split prog----
4 changes: 2 additions & 2 deletions test/standalone_executor/test_standalone_custom_stream.py
@@ -35,9 +35,9 @@ class TestCustomStream(unittest.TestCase):
| | | |
| elementwise_sub(cpu) |
| | | |
- | tanh(cpu) elementwise_add(s2)
+ | silu(cpu) elementwise_add(s2)
| | |
- elementwise_sub(s1) tanh(s2)
+ elementwise_sub(s1) silu(s2)
| |
elementwise_add(s2)
|
8 changes: 4 additions & 4 deletions test/standalone_executor/test_standalone_executor.py
@@ -44,15 +44,15 @@ def build_program():
bias = paddle.ones([4, 64], dtype='float32', name='bias')
add_out = paddle.add(matmul_out, bias, name='add_out')

- # add_out -> [memcpy_d2h] -> add_out' -> [sub] -> sub_out -> [tanh] -> tanh_out
+ # add_out -> [memcpy_d2h] -> add_out' -> [sub] -> sub_out -> [silu] -> silu_out
with paddle.static.device_guard('cpu'):
sub_out = paddle.subtract(add_out, data, name='sub_out')
- tanh_out = paddle.tanh(sub_out, name='tanh_out')
+ silu_out = paddle.nn.functional.silu(sub_out, name='silu_out')

with paddle.static.device_guard('gpu'):
bias_1 = paddle.add(bias, sub_out, name='bias_1')
- out_before = paddle.tanh(bias_1, name='out_before')
- out_last = paddle.subtract(tanh_out, data, name='out_last')
+ out_before = paddle.nn.functional.silu(bias_1, name='out_before')
+ out_last = paddle.subtract(silu_out, data, name='out_last')

out = paddle.add(out_before, out_last, name='out')
mean = paddle.mean(out, name='mean_out')