Skip to content

Commit f23d6b2

Browse files
authored
[Relay][Bugfix] fix the wrong implementation of Softplus in OneFlow (#15717)
* Update test_forward.py * fix a bug in softplus * Update oneflow.py
1 parent 7322769 commit f23d6b2

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

python/tvm/relay/frontend/oneflow.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,8 +1119,11 @@ class Softplus(OneFlowOpConverter):
11191119
def _impl_v1(cls, inputs, attrs, params):
11201120
data = inputs[0]
11211121
data_dtype = infer_type(data).checked_type.dtype
1122-
data = _op.exp(data) + _expr.const(1, dtype=data_dtype)
1123-
return _op.log(data)
1122+
beta = _expr.const(float(attrs.get("beta", 1.0)))
1123+
threshold = float(attrs.get("threshold", 20.0))
1124+
threshold_ = _op.full_like(data, fill_value=_expr.const(threshold))
1125+
softplus_value = _op.log(_op.exp(data * beta) + _expr.const(1.0, dtype=data_dtype)) / beta
1126+
return _op.where(_op.greater(data * beta, threshold_), data, softplus_value)
11241127

11251128

11261129
class Softsign(OneFlowOpConverter):

tests/python/frontend/oneflow/test_forward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,7 @@ def forward(self, x):
721721

722722
for device in ["llvm"]:
723723
verify_activation(model1, device=device)
724-
# verify_activation(model2, device=device) # NO PASS
724+
verify_activation(model2, device=device)
725725
verify_activation(model3, device=device)
726726
verify_activation(model4, device=device)
727727
verify_activation(model5, device=device)

0 commit comments

Comments
 (0)