Skip to content
Merged
145 changes: 102 additions & 43 deletions onnxruntime/core/providers/cpu/signal/dft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,6 @@ ONNX_CPU_OPERATOR_KERNEL(STFT, 17,
.TypeConstraint("T2", BuildKernelDefConstraints<int32_t, int64_t>()),
STFT);

static bool is_real_valued_signal(const onnxruntime::TensorShape& shape) {
return shape.NumDimensions() == 2 || shape[shape.NumDimensions() - 1] == 1;
}

static bool is_complex_valued_signal(const onnxruntime::TensorShape& shape) {
return shape.NumDimensions() > 2 && shape[shape.NumDimensions() - 1] == 2;
}

constexpr static bool is_power_of_2(size_t size) {
unsigned n_bits = 0;
while (size != 0) {
Expand Down Expand Up @@ -148,6 +140,24 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, s
*(Y_data + i * Y_data_stride) = std::complex<T>(1, 0) * x * window_element;
}

// For IRFFT, fix up the negative frequencies using conjugate symmetry
if (is_onesided && inverse) {
size_t conjugate_end = (dft_length % 2 == 0) ? (number_of_samples - 1) : number_of_samples;
for (size_t k = 1; k < conjugate_end; k++) {
// Find position in bit-reversed array that represents natural index N-k
size_t nk = dft_length - k;
size_t pos_nk = bit_reverse(nk, significant_bits);

// Get the input value at position k and the window for reconstructed position N-k
auto x_k = *(X_data + k * X_stride);
auto window_nk = window_data ? *(window_data + nk) : 1;

// Apply conjugate symmetry to input, then apply the window for X[N-k]
// X[N-k] = conj(X[k]), then multiply by window[N-k]
*(Y_data + pos_nk * Y_data_stride) = std::conj(x_k) * window_nk;
}
}

// Run fft_radix2
unsigned current_significant_bits = 0;
for (size_t i = 2; i <= dft_length; i <<= 1) {
Expand Down Expand Up @@ -179,10 +189,17 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, s
}

if (is_onesided) {
const size_t output_size = (dft_length >> 1) + 1;
auto destination = reinterpret_cast<std::complex<T>*>(Y->MutableDataRaw()) + Y_offset;
for (size_t i = 0; i < output_size; i++) {
*(destination + Y_stride * i) = *(Y_data + i * Y_data_stride);
if (inverse) {
auto destination = reinterpret_cast<T*>(Y->MutableDataRaw()) + Y_offset;
for (size_t i = 0; i < dft_length; i++) {
*(destination + Y_stride * i) = (*(Y_data + i * Y_data_stride)).real();
}
} else {
const size_t output_size = (dft_length >> 1) + 1;
auto destination = reinterpret_cast<std::complex<T>*>(Y->MutableDataRaw()) + Y_offset;
for (size_t i = 0; i < output_size; i++) {
*(destination + Y_stride * i) = *(Y_data + i * Y_data_stride);
}
}
}

Expand All @@ -202,7 +219,7 @@ T next_power_of_2(T in) {
template <typename T, typename U>
static Status dft_bluestein_z_chirp(
OpKernelContext* ctx, const Tensor* X, Tensor* Y, Tensor& b_fft, Tensor& chirp, size_t X_offset, size_t X_stride, size_t Y_offset, size_t Y_stride,
int64_t axis, size_t dft_length, const Tensor* window, bool inverse, InlinedVector<std::complex<T>>& V,
int64_t axis, size_t dft_length, const Tensor* window, bool is_onesided, bool inverse, InlinedVector<std::complex<T>>& V,
InlinedVector<std::complex<T>>& temp_output) {
static constexpr T pi = static_cast<T>(M_PI);

Expand Down Expand Up @@ -255,7 +272,6 @@ static Status dft_bluestein_z_chirp(

// Get data
auto* X_data = const_cast<U*>(reinterpret_cast<const U*>(X->DataRaw())) + X_offset;
auto* Y_data = reinterpret_cast<std::complex<T>*>(Y->MutableDataRaw()) + Y_offset;
U* window_data = nullptr;
if (window) {
window_data = const_cast<U*>(reinterpret_cast<const U*>(window->DataRaw()));
Expand All @@ -281,6 +297,19 @@ static Status dft_bluestein_z_chirp(
a_n *= window_n;
a_n *= chirp_n;
}
if (inverse && is_onesided) {
// IRFFT: fill in negative frequencies using conjugate symmetry
// For the input X: X[N-k] = conj(X[k]) for k=1..floor((N-1)/2)
// We need to apply this BEFORE the chirp multiplication
// So: a[N-k] = conj(X[k]) * window[N-k] * chirp[N-k]
size_t conjugate_end = (N % 2 == 0) ? (number_of_samples - 1) : number_of_samples;
for (size_t k = 1; k < conjugate_end; k++) {
auto x_k = *(X_data + k * X_stride); // Original input at k
auto window_nk = window_data ? *(window_data + N - k) : 1;
std::complex<T>& chirp_nk = *(chirp_data + N - k);
*(a_data + N - k) = std::conj(x_k) * window_nk * chirp_nk;
}
}

// Forward FFT radix2 for the "a" signal
ORT_RETURN_IF_ERROR((fft_radix2<T, std::complex<T>>(ctx, &a, &a_fft, 0, 1, 0, 1, 1, M, nullptr,
Expand All @@ -298,17 +327,33 @@ static Status dft_bluestein_z_chirp(
const auto& Y_shape = Y->Shape();
size_t dft_output_size = static_cast<size_t>(Y_shape[onnxruntime::narrow<size_t>(axis)]);

for (size_t i = 0; i < dft_output_size; i++) {
std::complex<T>& chirp_i = *(chirp_data + i);
std::complex<T>& out = *(Y_data + i * Y_stride);
std::complex<T>& c_i = *(a_data + i);
if (i > 0) {
// The inverse fft is computed using the same cached vandermonde matrix (V) created by the
// forward fft. This reversal causes the output to be reversed as well.
// Therefore we undo the reversal when writing the output back out.
c_i = *(a_data + M - i);
if (inverse && is_onesided) {
// IRFFT: extract real part only
auto* Y_data = reinterpret_cast<T*>(Y->MutableDataRaw()) + Y_offset;
for (size_t i = 0; i < dft_output_size; i++) {
std::complex<T>& chirp_i = *(chirp_data + i);
T& out = *(Y_data + i * Y_stride);
std::complex<T>& c_i = *(a_data + i);
if (i > 0) {
c_i = *(a_data + M - i);
}
out = (c_i * chirp_i * scale).real();
}
} else {
// Standard complex output
auto* Y_data = reinterpret_cast<std::complex<T>*>(Y->MutableDataRaw()) + Y_offset;
for (size_t i = 0; i < dft_output_size; i++) {
std::complex<T>& chirp_i = *(chirp_data + i);
std::complex<T>& out = *(Y_data + i * Y_stride);
std::complex<T>& c_i = *(a_data + i);
if (i > 0) {
// The inverse fft is computed using the same cached vandermonde matrix (V) created by the
// forward fft. This reversal causes the output to be reversed as well.
// Therefore we undo the reversal when writing the output back out.
c_i = *(a_data + M - i);
}
out = c_i * chirp_i * scale;
}
out = c_i * chirp_i * scale;
}
return Status::OK();
}
Expand All @@ -322,15 +367,9 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X,
const auto& X_shape = X->Shape();
const auto& Y_shape = Y->Shape();

auto batch_and_signal_rank = X->Shape().NumDimensions();
auto total_dfts = static_cast<size_t>(X->Shape().Size() / X->Shape()[onnxruntime::narrow<size_t>(axis)]);

auto is_input_real = X->Shape().NumDimensions() == 2 || X->Shape()[X->Shape().NumDimensions() - 1] == 1;
auto complex_input_factor = is_input_real ? 1 : 2;
if (X->Shape().NumDimensions() > 2) {
total_dfts /= onnxruntime::narrow<size_t>(X->Shape()[X->Shape().NumDimensions() - 1]);
batch_and_signal_rank -= 1;
}
auto batch_and_signal_rank = X->Shape().NumDimensions() - 1;
auto complex_input_factor = onnxruntime::narrow<size_t>(X->Shape()[X->Shape().NumDimensions() - 1]);
auto total_dfts = static_cast<size_t>(X->Shape().Size() / X->Shape()[onnxruntime::narrow<size_t>(axis)]) / complex_input_factor;

// Calculate x/y offsets/strides
for (size_t i = 0; i < total_dfts; i++) {
Expand All @@ -349,7 +388,8 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X,
}

size_t Y_offset = 0;
size_t Y_stride = onnxruntime::narrow<size_t>(Y_shape.SizeFromDimension(SafeInt<size_t>(axis) + 1) / 2);
size_t Y_last_dim_size = (inverse && is_onesided) ? 1 : 2;
size_t Y_stride = onnxruntime::narrow<size_t>(Y_shape.SizeFromDimension(SafeInt<size_t>(axis) + 1) / Y_last_dim_size);
cumulative_packed_stride = total_dfts;
temp = i;
for (size_t r = 0; r < batch_and_signal_rank; r++) {
Expand All @@ -359,15 +399,15 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X,
cumulative_packed_stride /= onnxruntime::narrow<size_t>(X_shape[r]);
auto index = temp / cumulative_packed_stride;
temp -= (index * cumulative_packed_stride);
Y_offset += index * SafeInt<size_t>(Y_shape.SizeFromDimension(r + 1)) / 2;
Y_offset += index * SafeInt<size_t>(Y_shape.SizeFromDimension(r + 1)) / Y_last_dim_size;
}

if (is_power_of_2(onnxruntime::narrow<size_t>(dft_length))) {
ORT_RETURN_IF_ERROR((fft_radix2<T, U>(ctx, X, Y, X_offset, X_stride, Y_offset, Y_stride, axis, onnxruntime::narrow<size_t>(dft_length), window,
is_onesided, inverse, V, temp_output)));
} else {
ORT_RETURN_IF_ERROR(
(dft_bluestein_z_chirp<T, U>(ctx, X, Y, b_fft, chirp, X_offset, X_stride, Y_offset, Y_stride, axis, onnxruntime::narrow<size_t>(dft_length), window, inverse, V, temp_output)));
(dft_bluestein_z_chirp<T, U>(ctx, X, Y, b_fft, chirp, X_offset, X_stride, Y_offset, Y_stride, axis, onnxruntime::narrow<size_t>(dft_length), window, is_onesided, inverse, V, temp_output)));
}
}

Expand All @@ -379,11 +419,26 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo
const auto* X = ctx->Input<Tensor>(0);
const auto* dft_length = ctx->Input<Tensor>(1);
const auto& X_shape = X->Shape();
const auto is_real_valued = is_real_valued_signal(X_shape);
const auto is_complex_valued = is_complex_valued_signal(X_shape);
const auto is_real_valued = X_shape[X_shape.NumDimensions() - 1] == 1;
const auto is_complex_valued = X_shape[X_shape.NumDimensions() - 1] == 2;

axis = HandleNegativeAxis(axis, X_shape.NumDimensions());
Comment thread
simonbyrne marked this conversation as resolved.
ORT_RETURN_IF(axis == static_cast<int64_t>(X_shape.NumDimensions()) - 1,
"DFT axis must refer to a signal dimension, not the last (real/imag) dimension.");
Comment thread
simonbyrne marked this conversation as resolved.

Comment thread
simonbyrne marked this conversation as resolved.
// Validate input for IRFFT
if (inverse && is_onesided) {
ORT_RETURN_IF(!is_complex_valued,
"Inverse one-sided DFT (IRFFT) requires complex-valued input (last dimension must be 2)");
} else {
ORT_RETURN_IF(!(is_real_valued || is_complex_valued),
"Input layout not supported. The input tensor must have a last dimension of 1 (real) or 2 (complex).");
}

int64_t number_of_samples = static_cast<int64_t>(X_shape[onnxruntime::narrow<size_t>(axis)]);
if (inverse && is_onesided) {
number_of_samples = (number_of_samples - 1) << 1;
}
if (dft_length) {
Comment thread
simonbyrne marked this conversation as resolved.
const auto& dft_length_shape = dft_length->Shape();
ORT_RETURN_IF(!dft_length_shape.IsScalar(), "dft_length must be a scalar value.");
Comment thread
simonbyrne marked this conversation as resolved.
Expand All @@ -393,13 +448,17 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo

// Get the DFT output size. Onesided will return only the unique values!
// note: x >> 1 === std::floor(x / 2.f)
auto dft_output_size = is_onesided ? ((number_of_samples >> 1) + 1) : number_of_samples;
// For IRFFT (inverse && onesided), output is full-size real signal
// For RFFT (!inverse && onesided), output is one-sided complex spectrum
auto dft_output_size = (is_onesided && !inverse) ? ((number_of_samples >> 1) + 1) : number_of_samples;

// Get output shape
auto Y_shape = onnxruntime::TensorShape(X_shape);
if (X_shape.NumDimensions() == 2) {
Y_shape = onnxruntime::TensorShape({X_shape[0], dft_output_size, 2});
if (inverse && is_onesided) {
// IRFFT: output is real-valued
Y_shape[Y_shape.NumDimensions() - 1] = 1; // Real output
} else {
// Complex output
Y_shape[Y_shape.NumDimensions() - 1] = 2;
}
Y_shape[onnxruntime::narrow<size_t>(axis)] = dft_output_size;
Expand Down Expand Up @@ -567,8 +626,8 @@ Status STFT::Compute(OpKernelContext* ctx) const {
// Get signal shape
const auto* signal = ctx->Input<Tensor>(0);
const auto& signal_shape = signal->Shape();
const auto is_real_valued = is_real_valued_signal(signal_shape);
const auto is_complex_valued = is_complex_valued_signal(signal_shape);
const auto is_real_valued = signal_shape.NumDimensions() == 2 || signal_shape[signal_shape.NumDimensions() - 1] == 1;
const auto is_complex_valued = signal_shape.NumDimensions() > 2 && signal_shape[signal_shape.NumDimensions() - 1] == 2;

// Get data type
auto data_type = signal->DataType();
Expand Down
Loading
Loading