Skip to content

Commit ff1a4fb

Browse files
Jian ShengBoblest Sebastian (ETAS-DEV/XPC-Fe1)
authored andcommitted
[Frontend][PyTorch] Add: Relay stft operator (apache#11190)
* Add: Relay stft operator * fix doc * address PR comments * address addtional comments
1 parent ddd4b4b commit ff1a4fb

File tree

13 files changed

+701
-7
lines changed

13 files changed

+701
-7
lines changed

include/tvm/relay/attrs/transform.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,28 @@ struct EinsumAttrs : public tvm::AttrsNode<EinsumAttrs> {
536536
}
537537
}; // struct EinsumAttrs
538538

539+
/*! \brief Attributes used in stft operator */
540+
struct StftAttrs : public tvm::AttrsNode<StftAttrs> {
541+
int n_fft;
542+
int hop_length;
543+
int win_length;
544+
bool normalized;
545+
bool onesided;
546+
547+
TVM_DECLARE_ATTRS(StftAttrs, "relay.attrs.StftAttrs") {
548+
TVM_ATTR_FIELD(n_fft).set_default(-1).describe("The size of Fourier transform");
549+
TVM_ATTR_FIELD(hop_length)
550+
.set_default(-1)
551+
.describe("The distance between neighboring sliding window frames");
552+
TVM_ATTR_FIELD(win_length).set_default(-1).describe("The size of window frame and STFT filter");
553+
TVM_ATTR_FIELD(normalized)
554+
.set_default(false)
555+
.describe("Whether to return the normalized STFT results");
556+
TVM_ATTR_FIELD(onesided).set_default(true).describe(
557+
"Whether to return onesided result or fill with conjugate symmetry");
558+
}
559+
}; // struct StftAttrs
560+
539561
} // namespace relay
540562
} // namespace tvm
541563
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_

python/tvm/relay/frontend/pytorch.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def min_max_common(self, name_elemwise, name_reduce, inputs, input_types):
277277
if len(inputs) == 1:
278278
data = self.pytorch_promote_types(inputs[:1], input_types[:1])
279279
return get_relay_op(name_reduce)(data[0])
280-
elif len(inputs) >= 2 and isinstance(inputs[1], int):
280+
elif len(inputs) >= 2 and isinstance(inputs[1], (list, int)):
281281
data = self.pytorch_promote_types(inputs[:1], input_types[:1])
282282
dim = inputs[1]
283283
keepdims = inputs[2] if len(inputs) > 2 else False
@@ -2188,6 +2188,17 @@ def deform_conv2d(self, inputs, input_types):
21882188

21892189
return _op.nn.bias_add(conv_out, bias)
21902190

2191+
def stft(self, inputs, input_types):
2192+
data = inputs[0]
2193+
n_fft = inputs[1]
2194+
hop_length = inputs[2]
2195+
win_length = inputs[3]
2196+
window = inputs[4]
2197+
normalized = inputs[5]
2198+
onesided = inputs[6]
2199+
2200+
return _op.stft(data, n_fft, hop_length, win_length, window, normalized, onesided)
2201+
21912202
def unbind(self, inputs, input_types):
21922203
data = inputs[0]
21932204
axis = int(inputs[1])
@@ -2996,6 +3007,9 @@ def create_convert_map(self):
29963007
"aten::sub": self.sub,
29973008
"aten::max": self.max,
29983009
"aten::min": self.min,
3010+
"aten::amax": self.max,
3011+
"aten::amin": self.min,
3012+
"aten::stft": self.stft,
29993013
"aten::mul": self.make_elemwise("multiply"),
30003014
"aten::pow": self.make_elemwise("power"),
30013015
"aten::arange": self.arange,

python/tvm/relay/op/_transform.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,50 @@ def compute_reshape(attrs, inputs, output_type):
140140

141141
_reg.register_strategy("sparse_reshape", strategy.sparse_reshape_strategy)
142142

143+
# stft
144+
@_reg.register_compute("stft")
145+
def compute_stft(attrs, inputs, output_type):
146+
"""Compute definition of stft"""
147+
return topi.stft(
148+
inputs[0],
149+
attrs.n_fft,
150+
attrs.hop_length,
151+
attrs.win_length,
152+
attrs.window,
153+
attrs.normalized,
154+
attrs.onesided,
155+
output_type.shape,
156+
)
157+
158+
159+
_reg.register_strategy("stft", strategy.stft_strategy)
160+
161+
162+
@script
163+
def _stft_shape_func(data, n_fft, hop_length, onesided):
164+
output_shape = output_tensor((4,), "int64")
165+
output_shape[0] = int64(data.shape[0])
166+
if onesided:
167+
output_shape[1] = int64(int64(n_fft) // int64(2)) + int64(1)
168+
else:
169+
output_shape[1] = int64(n_fft)
170+
output_shape[2] = int64(int64(data.shape[1] - n_fft) // int64(hop_length)) + int64(1)
171+
output_shape[3] = int64(2)
172+
return output_shape
173+
174+
175+
@_reg.register_shape_func("stft", True)
176+
def stft_shape_func(attrs, inputs, _):
177+
"""
178+
Shape func for stft.
179+
"""
180+
return [
181+
_stft_shape_func(
182+
inputs[0], convert(attrs.n_fft), convert(attrs.hop_length), convert(attrs.onesided)
183+
)
184+
]
185+
186+
143187
# scatter_add
144188
@_reg.register_compute("scatter_add")
145189
def compute_scatter_add(attrs, inputs, output_type):

python/tvm/relay/op/strategy/cuda.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,3 +1313,14 @@ def einsum_strategy_cuda(attrs, inputs, out_type, target):
13131313
name="einsum.cuda",
13141314
)
13151315
return strategy
1316+
1317+
1318+
@stft_strategy.register(["cuda", "gpu"])
1319+
def stft_strategy_cuda(attrs, inputs, out_type, target):
1320+
strategy = _op.OpStrategy()
1321+
strategy.add_implementation(
1322+
wrap_compute_stft(topi.cuda.stft),
1323+
wrap_topi_schedule(topi.generic.schedule_extern),
1324+
name="stft.cuda",
1325+
)
1326+
return strategy

python/tvm/relay/op/strategy/generic.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,6 +1375,39 @@ def _compute_sparse_reshape(attrs, inputs, output_type):
13751375
return _compute_sparse_reshape
13761376

13771377

1378+
# stft
1379+
@override_native_generic_func("stft_strategy")
1380+
def stft_strategy(attrs, outs, out_type, target):
1381+
"""stft generic strategy"""
1382+
strategy = _op.OpStrategy()
1383+
strategy.add_implementation(
1384+
wrap_compute_stft(topi.stft),
1385+
wrap_topi_schedule(topi.generic.schedule_extern),
1386+
name="stft.generic",
1387+
)
1388+
return strategy
1389+
1390+
1391+
def wrap_compute_stft(topi_compute):
1392+
"""Wrap stft compute"""
1393+
1394+
def _compute_stft(attrs, inputs, output_type):
1395+
return [
1396+
topi_compute(
1397+
inputs[0],
1398+
attrs.n_fft,
1399+
attrs.hop_length,
1400+
attrs.win_length,
1401+
inputs[1],
1402+
attrs.normalized,
1403+
attrs.onesided,
1404+
output_type.shape,
1405+
)
1406+
]
1407+
1408+
return _compute_stft
1409+
1410+
13781411
# roi_pool
13791412
@generic_func
13801413
def schedule_roi_pool(attrs, outs, target):

python/tvm/relay/op/transform.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1829,3 +1829,63 @@ def invert_permutation(data):
18291829
relay.invert_permutation(data) = [2, 4, 3, 0, 1]
18301830
"""
18311831
return _make.invert_permutation(data)
1832+
1833+
1834+
def stft(
1835+
data, n_fft, hop_length=None, win_length=None, window=None, normalized=False, onesided=True
1836+
):
1837+
"""
1838+
The STFT computes the Fourier transform of short overlapping windows of the input.
1839+
This gives frequency components of the signal as they change over time.
1840+
1841+
Parameters
1842+
----------
1843+
data : relay.Expr
1844+
Either a 1-D tensor or a 2-D batch tensor.
1845+
1846+
n_fft : int
1847+
The size of Fourier transform
1848+
1849+
hop_length : int, optional
1850+
The distance between neighboring sliding window frames. If is None,
1851+
it is treated as equal to floor(n_fft / 4).
1852+
1853+
win_length : int, optional
1854+
The size of window frame and STFT filter. If is None, it is treated as equal to n_fft.
1855+
1856+
window : relay.Expr, optional
1857+
A 1-D tensor window frame. If is None (default), it is treated as if
1858+
having 1 everywhere in the window.
1859+
1860+
normalized : bool, optional
1861+
Whether to return the normalized STFT results. Default value is False.
1862+
1863+
onesided : bool, optional
1864+
Whether to return onesided result or fill with conjugate symmetry. Default value is True.
1865+
1866+
Returns
1867+
-------
1868+
output : relay.Expr
1869+
Tensor containing the STFT result with shape [batch, N, T, 2], where N is the
1870+
number of frequencies where STFT is applied and T is the total number of frames used.
1871+
1872+
Examples
1873+
--------
1874+
.. code-block:: python
1875+
1876+
data = [1, 2, 3, 4, 5, 6]
1877+
window = [4, 3, 2]
1878+
[n_fft, hop_length, win_length, normalized, onesided] = [3, 3, 3, False, True]
1879+
relay.stft(data, n_fft, hop_length, win_length, window, normalized, onesided)
1880+
-> [[[15.0000, 0.0000], [34.0000, 0.0000]], [[ 4.5000, 0.8660], [ 1.0000, -1.7321]]]
1881+
"""
1882+
if hop_length is None:
1883+
hop_length = n_fft // 4
1884+
1885+
if win_length is None:
1886+
win_length = n_fft
1887+
1888+
if window is None:
1889+
window = _make.ones([n_fft], "int32")
1890+
1891+
return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided)

python/tvm/topi/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from .einsum import *
4747
from .unique import *
4848
from .searchsorted import *
49+
from .stft import *
4950
from . import generic
5051
from . import nn
5152
from . import x86

python/tvm/topi/cuda/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,4 @@
6060
from .transform import *
6161
from .unique import *
6262
from .searchsorted import *
63+
from .stft import *

python/tvm/topi/cuda/stft.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks, unused-argument
18+
"""STFT operator"""
19+
from math import pi
20+
import tvm
21+
from tvm import te, tir
22+
from ..utils import ceil_div
23+
24+
25+
def _get_max_threads(batch_row):
26+
max_threads = tvm.target.Target.current(allow_none=False).max_num_threads
27+
return tir.min(batch_row, max_threads)
28+
29+
30+
def stft(
31+
data,
32+
n_fft,
33+
hop_length,
34+
win_length,
35+
window,
36+
normalized,
37+
onesided,
38+
output_shape,
39+
):
40+
"""
41+
The STFT computes the Fourier transform of short overlapping windows of the input.
42+
This gives frequency components of the signal as they change over time.
43+
Parameters
44+
----------
45+
data : relay.Expr
46+
Either a 1-D tensor or a 2-D batch tensor.
47+
n_fft : int
48+
The size of Fourier transform
49+
hop_length : int
50+
The distance between neighboring sliding window frames
51+
win_length : int
52+
The size of window frame and STFT filter
53+
window : relay.Expr
54+
A 1-D tensor window frame
55+
normalized : bool
56+
Whether to return the normalized STFT results
57+
onesided : bool
58+
Whether to return onesided result or fill with conjugate symmetry
59+
Returns
60+
-------
61+
output : relay.Expr
62+
Tensor containing the STFT result
63+
Examples
64+
--------
65+
.. code-block:: python
66+
67+
data = [1, 2, 3, 4, 5, 6]
68+
window = [4, 3, 2]
69+
[n_fft, hop_length, win_length, normalized, onesided] = [3, 3, 3, False, True]
70+
relay.stft(data, n_fft, hop_length, win_length, window, normalized, onesided)
71+
-> [[[15.0000, 0.0000], [34.0000, 0.0000]], [[ 4.5000, 0.8660], [ 1.0000, -1.7321]]]
72+
"""
73+
74+
def gen_ir(
75+
data_ptr,
76+
n_fft,
77+
hop_length,
78+
win_length,
79+
window_ptr,
80+
normalized,
81+
onesided,
82+
output_ptr,
83+
):
84+
ib = tir.ir_builder.create()
85+
data = ib.buffer_ptr(data_ptr)
86+
window = ib.buffer_ptr(window_ptr)
87+
output = ib.buffer_ptr(output_ptr)
88+
max_threads = _get_max_threads(output_ptr.shape[0] * output_ptr.shape[1])
89+
output_size = output_ptr.shape[0] * output_ptr.shape[1] * output_ptr.shape[2]
90+
with ib.new_scope():
91+
nthread_tx = max_threads
92+
nthread_bx = ceil_div(output_size, max_threads)
93+
tx = te.thread_axis("threadIdx.x")
94+
bx = te.thread_axis("blockIdx.x")
95+
ib.scope_attr(tx, "thread_extent", nthread_tx)
96+
ib.scope_attr(bx, "thread_extent", nthread_bx)
97+
tid = bx * max_threads + tx
98+
99+
with ib.if_scope(tid < output_size):
100+
matrix_size = output_ptr.shape[1] * output_ptr.shape[2]
101+
batch = tir.floordiv(tid, matrix_size)
102+
row = tir.floordiv(tir.indexmod(tid, matrix_size), output_ptr.shape[2])
103+
col = tir.indexmod(tir.indexmod(tid, matrix_size), output_ptr.shape[2])
104+
output[batch, row, col, 0] = tir.Cast(data_ptr.dtype, 0)
105+
output[batch, row, col, 1] = tir.Cast(data_ptr.dtype, 0)
106+
with ib.for_range(0, win_length) as wlen:
107+
output[batch, row, col, 0] += (
108+
window[wlen]
109+
* data[batch, col * hop_length + wlen]
110+
* tir.cos(2 * pi * row * wlen / win_length)
111+
)
112+
output[batch, row, col, 1] -= (
113+
window[wlen]
114+
* data[batch, col * hop_length + wlen]
115+
* tir.sin(2 * pi * row * wlen / win_length)
116+
)
117+
with ib.if_scope(normalized):
118+
output[batch, row, col, 0] /= tir.sqrt(tir.const(n_fft, "float32"))
119+
output[batch, row, col, 1] /= tir.sqrt(tir.const(n_fft, "float32"))
120+
121+
return ib.get()
122+
123+
output_buf = tir.decl_buffer(output_shape, data.dtype, "output_buf")
124+
125+
return te.extern(
126+
output_shape,
127+
[data, window],
128+
lambda ins, outs: gen_ir(
129+
ins[0], n_fft, hop_length, win_length, ins[1], normalized, onesided, outs[0]
130+
),
131+
dtype=[data.dtype],
132+
out_buffers=[output_buf],
133+
name="stft_cuda",
134+
tag="stft_cuda",
135+
)

0 commit comments

Comments
 (0)