Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/tvm/relax/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,19 @@ struct LeakyReluAttrs : public tvm::AttrsNode<LeakyReluAttrs> {
}
};

/*! \brief Attributes used in softplus operators */
struct SoftplusAttrs : public tvm::AttrsNode<SoftplusAttrs> {
double beta;
double threshold;

TVM_DECLARE_ATTRS(SoftplusAttrs, "relax.attrs.SoftplusAttrs") {
TVM_ATTR_FIELD(beta).describe(
"Scaling factor controlling the sharpness of the Softplus transition.");
TVM_ATTR_FIELD(threshold).describe(
"Value determining when to use linear approximation for numerical stability.");
}
};

/*! \brief Attributes used in batch_norm operator */
struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
int axis;
Expand Down
26 changes: 26 additions & 0 deletions python/tvm/relax/frontend/nn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,32 @@ def softmax(x: Tensor, axis: int = -1, name: str = "softmax") -> Tensor:
return wrap_nested(_op.nn.softmax(x._expr, axis), name)


def softplus(x: Tensor, beta: float = 1.0, threshold: float = 20.0, name: str = "softplus"):
r"""Softplus activation function.

.. math::
\text{Softplus}(x) = \frac{1}{\beta} \log(1 + e^{\beta x})

Parameters
----------
data : relax.Expr
The input data.

beta : float, optional
Controls the smoothness of the transition. Default is 1.0.

threshold : float, optional
The value beyond which the function is approximated as linear
to avoid numerical instability. Default is 20.0.

Returns
-------
result : relax.Expr
The computed result.
"""
return wrap_nested(_op.nn.softplus(x._expr, beta=beta, threshold=threshold), name)


def tanh(x: Tensor, name: str = "tanh") -> Tensor:
r"""Applies the hyperbolic tangent function.

Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,12 @@ def _softmax(self, node: fx.Node) -> relax.Var:
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
return self.block_builder.emit(relax.op.nn.softmax(x, dim))

def _softplus(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
beta = node.args[1] if len(node.args) > 1 else node.kwargs.get("beta", 1.0)
threshold = node.args[2] if len(node.args) > 2 else node.kwargs.get("threshold", 20.0)
return self.block_builder.emit(relax.op.nn.softplus(x, beta, threshold))

def _softshrink(self, node: fx.Node) -> relax.Var:
"""
Applies the Softshrink activation function in Relax.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def create_convert_map(
"sin.default": self._unary_op(relax.op.sin),
"sinh.default": self._unary_op(relax.op.sinh),
"softmax.int": self._softmax,
"softplus.default": self._softplus,
"softshrink.default": self._softshrink,
"sqrt.default": self._unary_op(relax.op.sqrt),
"square.default": self._unary_op(relax.op.square),
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ def _leakyrelu_module(self, node: fx.Node) -> relax.Var:
alpha = module.negative_slope
return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))

def _softplus_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
beta = module.beta
threshold = module.threshold
return self.block_builder.emit(relax.op.nn.softplus(x, beta, threshold))

def _log2(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
return self.block_builder.emit(
Expand Down Expand Up @@ -653,6 +660,7 @@ def create_convert_map(
nn.SELU: self._unary_op(relax.op.nn.selu),
nn.SiLU: self._unary_op(relax.op.nn.silu),
nn.Softmax: self._softmax_module,
nn.Softplus: self._softplus_module,
nn.Tanh: self._unary_op(relax.op.tanh),
# neural network
nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module,
Expand Down Expand Up @@ -717,6 +725,7 @@ def create_convert_map(
"sin": self._unary_op(relax.op.sin),
"sinh": self._unary_op(relax.op.sinh),
"softmax": self._softmax,
"softplus": self._softplus,
"sqrt": self._unary_op(relax.op.sqrt),
"square": self._unary_op(relax.op.square),
"tan": self._unary_op(relax.op.tan),
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,5 @@
selu,
silu,
softmax,
softplus,
)
25 changes: 25 additions & 0 deletions python/tvm/relax/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,31 @@ def softmax(data: Expr, axis: int = -1) -> Expr:
return _ffi_api.softmax(data, axis) # type: ignore


def softplus(data: Expr, beta: float = 1.0, threshold: float = 20.0) -> Expr:
r"""Softplus activation function.

.. math:: \text{Softplus}(x) = \frac{1}{\beta} \log(1 + e^{\beta x})

Parameters
----------
data : relax.Expr
The input data.

beta : float, optional
Controls the smoothness of the transition. Default is 1.0.

threshold : float, optional
The value beyond which the function is approximated as linear
to avoid numerical instability. Default is 20.0.

Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.softplus(data, beta, threshold)


def log_softmax(data: Expr, axis: int = -1) -> Expr:
r"""Computes log softmax.

Expand Down
10 changes: 10 additions & 0 deletions python/tvm/relax/transform/legalize_ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,16 @@ def te_silu(x: te.Tensor):
return bb.call_te(te_silu, call.args[0], primfunc_name_hint="silu")


@register_legalize("relax.nn.softplus")
def _nn_softplus(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
topi.nn.softplus,
call.args[0],
call.attrs.beta,
call.attrs.threshold,
)


@register_legalize("relax.nn.softmax")
def _nn_softmax(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.nn.softmax, call.args[0], call.attrs.axis)
Expand Down
33 changes: 33 additions & 0 deletions python/tvm/topi/nn/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,39 @@ def _compute(*indices):
return te.compute(x.shape, _compute)


@tvm.te.tag_scope(tag=tag.ELEMWISE)
def softplus(x, beta=1.0, threshold=20.0):
"""Compute Softplus activation for input x with numerical stability.

Parameters
----------
x : tvm.te.Tensor
Input tensor.

beta : float, optional
The scaling factor β in the Softplus formula (default is 1.0).

threshold : float, optional
The threshold value for numerical stability (default is 20.0).

Returns
-------
y : tvm.te.Tensor
The result.
"""

def _compute(*indices):
value = x(*indices)
b = tvm.tir.const(beta, value.dtype)
t = tvm.tir.const(threshold, value.dtype)

return tvm.tir.Select(
b * value > t, value, (1 / b) * tvm.tir.log(1 + tvm.tir.exp(b * value))
)

return te.compute(x.shape, _compute)


@tvm.te.tag_scope(tag=tag.BROADCAST)
def prelu(x, slope, axis=1):
"""PReLU.
Expand Down
21 changes: 21 additions & 0 deletions src/relax/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,27 @@ TVM_REGISTER_OP("relax.nn.leakyrelu")
InferStructInfoUnaryArith</*require_float_dtype=*/true>)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.nn.softplus */
TVM_REGISTER_NODE_TYPE(SoftplusAttrs);

Expr softplus(Expr data, double beta, double threshold) {
auto attrs = make_object<SoftplusAttrs>();
attrs->beta = beta;
attrs->threshold = threshold;
static const Op& op = Op::Get("relax.nn.softplus");
return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.nn.softplus").set_body_typed(softplus);

TVM_REGISTER_OP("relax.nn.softplus")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_attrs_type<SoftplusAttrs>()
.set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoUnaryArith</*require_float_dtype=*/true>)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.nn.softmax */
TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);

Expand Down
3 changes: 3 additions & 0 deletions src/relax/op/nn/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ Expr silu(Expr data);
/*! \brief Softmax function. */
Expr softmax(Expr data, int axis);

/*! \brief Softplus function. */
Expr softplus(Expr data, double beta, double threshold);

/*! \brief LogSoftmax function. */
Expr log_softmax(Expr data, int axis);

Expand Down
40 changes: 40 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,9 @@ def main(
# leakyrelu
test_leakyrelu()

# softplus
test_softplus()

# log2
class Log2(Module):
def forward(self, x):
Expand Down Expand Up @@ -655,6 +658,43 @@ def main(
verify_model(Hardtanh2(), example_args, {}, expected1)


def test_softplus():
import torch
from torch.nn import Module

torch.set_grad_enabled(False)

class Softplus0(torch.nn.Module):
def __init__(self):
super().__init__()
self.softplus = torch.nn.Softplus(1.0, 20.0)

def forward(self, x):
return self.softplus(x)

class Softplus1(Module):
def forward(self, input):
return torch.nn.functional.softplus(input, 1.0, 20.0)

@tvm.script.ir_module
class expected:
@R.function
def main(
x: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softplus(
x, beta=1.0, threshold=20.0
)
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
verify_model(Softplus0(), example_args, {}, expected)
verify_model(Softplus1(), example_args, {}, expected)


def test_leakyrelu():
import torch
from torch.nn import Module
Expand Down
40 changes: 40 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,43 @@ def main(
verify_model(Einsum2(), [([5], "float32"), ([4], "float32")], {}, Expected2)


@tvm.testing.requires_gpu
def test_softplus():
import torch
from torch.nn import Module

torch.set_grad_enabled(False)

class Softplus0(torch.nn.Module):
def __init__(self):
super().__init__()
self.softplus = torch.nn.Softplus(1.0, 20.0)

def forward(self, x):
return self.softplus(x)

class Softplus1(Module):
def forward(self, input):
return torch.nn.functional.softplus(input, 1.0, 20.0)

@tvm.script.ir_module
class expected:
@R.function
def main(inp_0: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((10, 10), dtype="float32") = R.nn.softplus(
inp_0, beta=1.0, threshold=20.0
)
gv: R.Tensor((10, 10), dtype="float32") = lv
R.output(gv)
return gv

input_info = [([10, 10], "float32")]
verify_model(Softplus0(), input_info, {}, expected)
verify_model(Softplus1(), input_info, {}, expected)


@tvm.testing.requires_gpu
def test_leakyrelu():
import torch
Expand Down Expand Up @@ -2226,6 +2263,9 @@ def main(
# leaky_relu
test_leakyrelu()

# softplus
test_softplus()

# log2
class Log2(Module):
def forward(self, x):
Expand Down
4 changes: 4 additions & 0 deletions tests/python/relax/test_frontend_nn_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def test(self, x: Tensor, weight: Tensor, bias: Tensor):
tanh_out = op.tanh(x)
exp_out = op.exp(x)
negative_out = op.negative(x)
softplus_out = op.softplus(x, beta=1.0, threshold=20.0)
softmax_out = op.softmax(x, axis=2)
rms_norm_out = op.rms_norm(x, weight, axes=[-2, -1])
rms_norm_with_bias_out = op.rms_norm(x, weight, axes=[-2, -1])
Expand All @@ -413,6 +414,9 @@ def test(
tanh: R.Tensor((2, 3, 4, 5), dtype="float32") = R.tanh(x)
exp: R.Tensor((2, 3, 4, 5), dtype="float32") = R.exp(x)
negative: R.Tensor((2, 3, 4, 5), dtype="float32") = R.negative(x)
softplus: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softplus(
x, beta=1.0, threshold=20.0
)
softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x, axis=2)
rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(
x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05
Expand Down
5 changes: 5 additions & 0 deletions tests/python/relax/test_op_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_op_correctness():
x = relax.Var("x", R.Tensor((2, 3), "float32"))
assert relax.op.nn.relu(x).op == Op.get("relax.nn.relu")
assert relax.op.nn.leakyrelu(x).op == Op.get("relax.nn.leakyrelu")
assert relax.op.nn.softplus(x).op == Op.get("relax.nn.softplus")
assert relax.op.nn.gelu(x).op == Op.get("relax.nn.gelu")
assert relax.op.nn.silu(x).op == Op.get("relax.nn.silu")
assert relax.op.nn.softmax(x).op == Op.get("relax.nn.softmax")
Expand Down Expand Up @@ -75,6 +76,8 @@ def test_linear_unit_infer_struct_info():
_check_inference(bb, relax.op.nn.gelu(x4), relax.TensorStructInfo(dtype=""))
_check_inference(bb, relax.op.nn.leakyrelu(x0), relax.TensorStructInfo((2, 3), "float32"))
_check_inference(bb, relax.op.nn.leakyrelu(x5), relax.TensorStructInfo((3, 4), dtype=""))
_check_inference(bb, relax.op.nn.softplus(x0), relax.TensorStructInfo((2, 3), "float32"))
_check_inference(bb, relax.op.nn.softplus(x5), relax.TensorStructInfo((3, 4), dtype=""))


def test_linear_unit_infer_struct_info_shape_symbolic():
Expand All @@ -87,6 +90,7 @@ def test_linear_unit_infer_struct_info_shape_symbolic():
_check_inference(bb, relax.op.nn.silu(x0), relax.TensorStructInfo((m, n), "float32"))
_check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((4, n), "float32"))
_check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo((4, n), "float32"))
_check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo((4, n), "float32"))


def test_linear_unit_infer_struct_info_shape_var():
Expand All @@ -99,6 +103,7 @@ def test_linear_unit_infer_struct_info_shape_var():
_check_inference(bb, relax.op.nn.gelu(x0), relax.TensorStructInfo(s0, "float32"))
_check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo(s1, "float32"))
_check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo(s1, "float32"))
_check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo(s1, "float32"))


def test_linear_unit_infer_struct_info_more_input_dtype():
Expand Down