Skip to content

Commit 46aaf61

Browse files
authored
[BugFix] add the default value for DFT in ONNX frontend (#16659)
1 parent 880af30 commit 46aaf61

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4809,9 +4809,9 @@ class DFT(OnnxOpConverter):
48094809
@classmethod
48104810
def _impl_v17(cls, inputs, attr, params):
48114811
# ************************* Read attrs *************************
4812-
axis = attr.get("axis")
4813-
inverse = attr.get("inverse")
4814-
onesided = attr.get("onesided")
4812+
axis = attr.get("axis", 1)
4813+
inverse = attr.get("inverse", 0)
4814+
onesided = attr.get("onesided", 0)
48154815

48164816
# ************************* Read inputs ************************
48174817
input_tensor = inputs[0]

tests/python/frontend/onnx/test_forward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8238,7 +8238,7 @@ def verify_dft(
82388238
D = 7
82398239

82408240
for axis in list(range(1, n)) + [-2]:
8241-
for inverse, onesided in [(0, 0), (0, 1), (1, 0)]:
8241+
for inverse, onesided in [(0, 0), (0, 1), (1, 0), (None, None)]:
82428242
for n_fft in [D, D - 1, D + 1]:
82438243
for c in [1, 2]:
82448244
input_shape = [batch_size] + n * [D] + [c]

0 commit comments

Comments
 (0)