Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions python/tvm/relax/frontend/nn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,100 @@ def softmax(x: Tensor, axis: int = -1, name: str = "softmax") -> Tensor:
return wrap_nested(_op.nn.softmax(x._expr, axis), name)


def tanh(x: Tensor, name: str = "tanh") -> Tensor:
r"""Applies the hyperbolic tangent function.

.. math::
\text{Tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}

Parameters
----------
x : Tensor
The input data to the operator.

name : str
Name hint.

Returns
-------
result : Tensor
The computed result.

Note
----
The input tensor is required to have float dtype
"""
return wrap_nested(_op.tanh(x._expr), name)


def exp(x: Tensor, name: str = "exp") -> Tensor:
r"""Applies the exponential function.

.. math::
\text{Exp}(x) = e^x

Parameters
----------
x : Tensor
The input data to the operator.

name : str
Name hint.

Returns
-------
result : Tensor
The computed result.

Note
----
The input tensor is required to have float dtype
"""
return wrap_nested(_op.exp(x._expr), name)


def permute(x: Tensor, axes: Optional[List[int]], name: str = "permute") -> Tensor:
"""Permutes the dimensions of the input tensor.

Parameters
----------
x : Tensor
The input data to the operator.

axes : Optional[List[int]]
The target axes order.

name : str
Name hint.

Returns
-------
result : Tensor
The transposed result.
"""

return wrap_nested(_op.permute_dims(x._expr, axes=axes), name)


def negative(x: Tensor, name: str = "neg") -> Tensor:
"""Numerical negative of the input tensor.

Parameters
----------
x : Tensor
The input data to the operator.

name : str
Name hint.

Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.negative(x._expr), name)


def layer_norm(
x: Tensor,
normalized_shape: Union[int, List[int]],
Expand Down
6 changes: 6 additions & 0 deletions tests/python/relax/test_frontend_nn_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,9 @@ def test(self, x: Tensor, weight: Tensor, bias: Tensor):
silu_out = op.silu(x)
gelu_out = op.gelu(x)
sigmoid_out = op.sigmoid(x)
tanh_out = op.tanh(x)
exp_out = op.exp(x)
negative_out = op.negative(x)
softmax_out = op.softmax(x, axis=2)
rms_norm_out = op.rms_norm(x, weight, axes=[-2, -1])
rms_norm_with_bias_out = op.rms_norm(x, weight, axes=[-2, -1])
Expand All @@ -357,6 +360,9 @@ def test(
silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x)
gelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.gelu(x)
sigmoid: R.Tensor((2, 3, 4, 5), dtype="float32") = R.sigmoid(x)
tanh: R.Tensor((2, 3, 4, 5), dtype="float32") = R.tanh(x)
exp: R.Tensor((2, 3, 4, 5), dtype="float32") = R.exp(x)
negative: R.Tensor((2, 3, 4, 5), dtype="float32") = R.negative(x)
softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x, axis=2)
rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(
x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05
Expand Down