|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks, unused-argument |
| 18 | +"""STFT operator""" |
| 19 | +from math import pi |
| 20 | +import tvm |
| 21 | +from tvm import te, tir |
| 22 | +from ..utils import ceil_div |
| 23 | + |
| 24 | + |
| 25 | +def _get_max_threads(batch_row): |
| 26 | + max_threads = tvm.target.Target.current(allow_none=False).max_num_threads |
| 27 | + return tir.min(batch_row, max_threads) |
| 28 | + |
| 29 | + |
| 30 | +def stft( |
| 31 | + data, |
| 32 | + n_fft, |
| 33 | + hop_length, |
| 34 | + win_length, |
| 35 | + window, |
| 36 | + normalized, |
| 37 | + onesided, |
| 38 | + output_shape, |
| 39 | +): |
| 40 | + """ |
| 41 | + The STFT computes the Fourier transform of short overlapping windows of the input. |
| 42 | + This gives frequency components of the signal as they change over time. |
| 43 | + Parameters |
| 44 | + ---------- |
| 45 | + data : relay.Expr |
| 46 | + Either a 1-D tensor or a 2-D batch tensor. |
| 47 | + n_fft : int |
| 48 | + The size of Fourier transform |
| 49 | + hop_length : int |
| 50 | + The distance between neighboring sliding window frames |
| 51 | + win_length : int |
| 52 | + The size of window frame and STFT filter |
| 53 | + window : relay.Expr |
| 54 | + A 1-D tensor window frame |
| 55 | + normalized : bool |
| 56 | + Whether to return the normalized STFT results |
| 57 | + onesided : bool |
| 58 | + Whether to return onesided result or fill with conjugate symmetry |
| 59 | + Returns |
| 60 | + ------- |
| 61 | + output : relay.Expr |
| 62 | + Tensor containing the STFT result |
| 63 | + Examples |
| 64 | + -------- |
| 65 | + .. code-block:: python |
| 66 | +
|
| 67 | + data = [1, 2, 3, 4, 5, 6] |
| 68 | + window = [4, 3, 2] |
| 69 | + [n_fft, hop_length, win_length, normalized, onesided] = [3, 3, 3, False, True] |
| 70 | + relay.stft(data, n_fft, hop_length, win_length, window, normalized, onesided) |
| 71 | + -> [[[15.0000, 0.0000], [34.0000, 0.0000]], [[ 4.5000, 0.8660], [ 1.0000, -1.7321]]] |
| 72 | + """ |
| 73 | + |
| 74 | + def gen_ir( |
| 75 | + data_ptr, |
| 76 | + n_fft, |
| 77 | + hop_length, |
| 78 | + win_length, |
| 79 | + window_ptr, |
| 80 | + normalized, |
| 81 | + onesided, |
| 82 | + output_ptr, |
| 83 | + ): |
| 84 | + ib = tir.ir_builder.create() |
| 85 | + data = ib.buffer_ptr(data_ptr) |
| 86 | + window = ib.buffer_ptr(window_ptr) |
| 87 | + output = ib.buffer_ptr(output_ptr) |
| 88 | + max_threads = _get_max_threads(output_ptr.shape[0] * output_ptr.shape[1]) |
| 89 | + output_size = output_ptr.shape[0] * output_ptr.shape[1] * output_ptr.shape[2] |
| 90 | + with ib.new_scope(): |
| 91 | + nthread_tx = max_threads |
| 92 | + nthread_bx = ceil_div(output_size, max_threads) |
| 93 | + tx = te.thread_axis("threadIdx.x") |
| 94 | + bx = te.thread_axis("blockIdx.x") |
| 95 | + ib.scope_attr(tx, "thread_extent", nthread_tx) |
| 96 | + ib.scope_attr(bx, "thread_extent", nthread_bx) |
| 97 | + tid = bx * max_threads + tx |
| 98 | + |
| 99 | + with ib.if_scope(tid < output_size): |
| 100 | + matrix_size = output_ptr.shape[1] * output_ptr.shape[2] |
| 101 | + batch = tir.floordiv(tid, matrix_size) |
| 102 | + row = tir.floordiv(tir.indexmod(tid, matrix_size), output_ptr.shape[2]) |
| 103 | + col = tir.indexmod(tir.indexmod(tid, matrix_size), output_ptr.shape[2]) |
| 104 | + output[batch, row, col, 0] = tir.Cast(data_ptr.dtype, 0) |
| 105 | + output[batch, row, col, 1] = tir.Cast(data_ptr.dtype, 0) |
| 106 | + with ib.for_range(0, win_length) as wlen: |
| 107 | + output[batch, row, col, 0] += ( |
| 108 | + window[wlen] |
| 109 | + * data[batch, col * hop_length + wlen] |
| 110 | + * tir.cos(2 * pi * row * wlen / win_length) |
| 111 | + ) |
| 112 | + output[batch, row, col, 1] -= ( |
| 113 | + window[wlen] |
| 114 | + * data[batch, col * hop_length + wlen] |
| 115 | + * tir.sin(2 * pi * row * wlen / win_length) |
| 116 | + ) |
| 117 | + with ib.if_scope(normalized): |
| 118 | + output[batch, row, col, 0] /= tir.sqrt(tir.const(n_fft, "float32")) |
| 119 | + output[batch, row, col, 1] /= tir.sqrt(tir.const(n_fft, "float32")) |
| 120 | + |
| 121 | + return ib.get() |
| 122 | + |
| 123 | + output_buf = tir.decl_buffer(output_shape, data.dtype, "output_buf") |
| 124 | + |
| 125 | + return te.extern( |
| 126 | + output_shape, |
| 127 | + [data, window], |
| 128 | + lambda ins, outs: gen_ir( |
| 129 | + ins[0], n_fft, hop_length, win_length, ins[1], normalized, onesided, outs[0] |
| 130 | + ), |
| 131 | + dtype=[data.dtype], |
| 132 | + out_buffers=[output_buf], |
| 133 | + name="stft_cuda", |
| 134 | + tag="stft_cuda", |
| 135 | + ) |
0 commit comments