
Commit 0e2f869

logsoftmax reusing the softmax function (#11141)

Authored by czh978 and caizihua
Co-authored-by: caizihua <[email protected]>
1 parent 0705bd7 · commit 0e2f869

File tree

1 file changed: +7 −18 lines changed


python/tvm/relay/frontend/onnx.py

Lines changed: 7 additions & 18 deletions
@@ -2412,30 +2412,18 @@ class LogSoftmax(OnnxOpConverter):
     """Operator converter for Softmax."""
 
     @classmethod
-    def run_calculation(cls, x, axes):
+    def run_calculation(cls, inputs, attr, params, opset):
         """Run the calculation for Log Softmax calculation."""
-        m = _op.max(x, axes, keepdims=True)
-        e = _op.exp(x - m)
-        s = _op.sum(e, axes, keepdims=True)
-        return x - m - _op.log(s)
+        res = Softmax.get_converter(opset)(inputs, attr, params)
+        return _op.log(res)
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        axis = attr.get("axis", 1)
-        ndim = len(infer_shape(inputs[0]))
-        if axis < 0:
-            axis += ndim
-        axes = list(range(axis, ndim))
-        return cls.run_calculation(inputs[0], axes)
+        return cls.run_calculation(inputs, attr, params, opset=1)
 
     @classmethod
     def _impl_v13(cls, inputs, attr, params):
-        axis = attr.get("axis", -1)
-        ndim = len(infer_shape(inputs[0]))
-        if axis < 0:
-            axis += ndim
-        axes = [axis]
-        return cls.run_calculation(inputs[0], axes)
+        return cls.run_calculation(inputs, attr, params, opset=13)
 
 
 class Hardmax(OnnxOpConverter):
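
For context, a minimal NumPy sketch (not TVM code; the helper names below are illustrative) of why the rewrite is equivalent: log(softmax(x)) reduces to the direct formulation x - m - log(sum(exp(x - m))) that this commit deletes, where m is the maximum along the reduction axis.

import numpy as np

def softmax(x, axis=-1):
    # subtract the max for numerical stability, as the removed code did
    m = x.max(axis=axis, keepdims=True)
    e = np.exp(x - m)
    return e / e.sum(axis=axis, keepdims=True)

def log_softmax_direct(x, axis=-1):
    # the formulation removed by this commit, written in NumPy
    m = x.max(axis=axis, keepdims=True)
    s = np.exp(x - m).sum(axis=axis, keepdims=True)
    return x - m - np.log(s)

x = np.random.randn(2, 5).astype("float32")
assert np.allclose(np.log(softmax(x)), log_softmax_direct(x), atol=1e-6)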
@@ -4852,7 +4840,8 @@ def _impl_v13(cls, inputs, attr, params):
         weight_tensor = None
 
         get_log_prob = attr["tvm_custom"]["num_outputs"] == 2
-        log_softmax_tensor = LogSoftmax.run_calculation(input_tensor, axes=[1])
+        log_softmax_attr = {"axis": 1}
+        log_softmax_tensor = LogSoftmax.get_converter(13)([input_tensor], log_softmax_attr, None)
 
         loss, weight_total = NegativeLogLikelihoodLoss.run_calculation(
             log_softmax_tensor,
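
Below is an illustrative NumPy sketch (not the TVM implementation; the shapes, labels, and mean reduction are assumptions) of the pattern this second hunk sets up: compute log-softmax over axis 1 and feed the result into a negative log-likelihood reduction.

import numpy as np

def log_softmax(x, axis=1):
    m = x.max(axis=axis, keepdims=True)
    return x - m - np.log(np.exp(x - m).sum(axis=axis, keepdims=True))

logits = np.random.randn(4, 3).astype("float32")  # (batch, classes)
labels = np.array([0, 2, 1, 2])                   # hypothetical targets

log_probs = log_softmax(logits, axis=1)
nll = -log_probs[np.arange(labels.shape[0]), labels].mean()
print(nll)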
