Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/FFTW.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
module FFTW

using LinearAlgebra, Reexport, Preferences
@reexport using AbstractFFTs
using Base.Threads

import AbstractFFTs: Plan, ScaledPlan, AbstractFFTBackend,
fft, ifft, bfft, fft!, ifft!, bfft!,
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft,
import AbstractFFTs: AbstractFFTs, Plan, ScaledPlan, AbstractFFTBackend,
fftshift, ifftshift,
rfft_output_size, brfft_output_size,
plan_inv, normalization
Expand All @@ -19,7 +15,15 @@ include("providers.jl")
export FFTWBackend
struct FFTWBackend <: AbstractFFTBackend end
backend() = FFTWBackend()
activate!() = AbstractFFTs.set_active_backend!(FFTW)

for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft, :brfft, :irfft)
pf = Symbol("plan_", f)
@eval begin
$f(x::AbstractArray, args...; kws...) = AbstractFFTs.$f(FFTWBackend(), x, args...; kws...)
$pf(x::AbstractArray, args...; kws...) = AbstractFFTs.$pf(FFTWBackend(), x, args...; kws...)
end
end


function fftw_init_check()
# If someone is trying to set the provider via the old environment variable, warn them that they
Expand Down Expand Up @@ -65,7 +69,6 @@ end
const libfftw3 = FakeLazyLibrary(:libfftw3_no_init, fftw_init_check, C_NULL)
const libfftw3f = FakeLazyLibrary(:libfftw3f_no_init, fftw_init_check, C_NULL)
function __init__()
activate!()
end
else
@static if fftw_provider == "fftw"
Expand All @@ -81,7 +84,6 @@ elseif fftw_provider == "mkl"
end
function __init__()
fftw_init_check()
activate!()
end
end

Expand Down
28 changes: 14 additions & 14 deletions src/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -771,36 +771,36 @@ for (f,direction) in ((:fft,FORWARD), (:bfft,BACKWARD))
plan_f! = Symbol("plan_",f,"!")
idirection = -direction
@eval begin
function $plan_f(b::FFTWBackend, X::StridedArray{T,N}, region;
function AbstractFFTs.$plan_f(b::FFTWBackend, X::StridedArray{T,N}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT,
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwComplex,N}
if num_threads !== nothing
plan = set_num_threads(num_threads) do
$plan_f(b, X, region; flags = flags, timelimit = timelimit)
AbstractFFTs.$plan_f(b, X, region; flags = flags, timelimit = timelimit)
end
return plan
end
cFFTWPlan{T,$direction,false,N}(X, fakesimilar(flags, X, T),
region, flags, timelimit)
end

function $plan_f!(::FFTWBackend, X::StridedArray{T,N}, region;
function AbstractFFTs.$plan_f!(::FFTWBackend, X::StridedArray{T,N}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT,
num_threads::Union{Nothing, Integer} = nothing ) where {T<:fftwComplex,N}
if num_threads !== nothing
plan = set_num_threads(num_threads) do
$plan_f!(b, X, region; flags = flags, timelimit = timelimit)
AbstractFFTs.$plan_f!(b, X, region; flags = flags, timelimit = timelimit)
end
return plan
end
cFFTWPlan{T,$direction,true,N}(X, X, region, flags, timelimit)
end
$plan_f(b::FFTWBackend, X::StridedArray{<:fftwComplex}; kws...) =
$plan_f(b, X, ntuple(identity, ndims(X)); kws...)
$plan_f!(b, ::FFTWBackend, X::StridedArray{<:fftwComplex}; kws...) =
$plan_f!(b, X, ntuple(identity, ndims(X)); kws...)
AbstractFFTs.$plan_f(b::FFTWBackend, X::StridedArray{<:fftwComplex}; kws...) =
AbstractFFTs.$plan_f(b, X, ntuple(identity, ndims(X)); kws...)
AbstractFFTs.$plan_f!(b, ::FFTWBackend, X::StridedArray{<:fftwComplex}; kws...) =
AbstractFFTs.$plan_f!(b, X, ntuple(identity, ndims(X)); kws...)

function plan_inv(p::cFFTWPlan{T,$direction,inplace,N};
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwComplex,N,inplace}
Expand Down Expand Up @@ -843,13 +843,13 @@ end
for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
# Note: use $FORWARD and $BACKWARD below because of issue #9775
@eval begin
function plan_rfft(b::FFTWBackend, X::StridedArray{$Tr,N}, region;
function AbstractFFTs.plan_rfft(b::FFTWBackend, X::StridedArray{$Tr,N}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT,
num_threads::Union{Nothing, Integer} = nothing) where N
if num_threads !== nothing
plan = set_num_threads(num_threads) do
plan_rfft(b, X, region; flags = flags, timelimit = timelimit)
AbstractFFTs.plan_rfft(b, X, region; flags = flags, timelimit = timelimit)
end
return plan
end
Expand All @@ -858,13 +858,13 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
rFFTWPlan{$Tr,$FORWARD,false,N}(X, Y, region, flags, timelimit)
end

function plan_brfft(::FFTWBackend, X::StridedArray{$Tc,N}, d::Integer, region;
function AbstractFFTs.plan_brfft(::FFTWBackend, X::StridedArray{$Tc,N}, d::Integer, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT,
num_threads::Union{Nothing, Integer} = nothing) where N
if num_threads !== nothing
plan = set_num_threads(num_threads) do
plan_brfft(b, X, d, region; flags = flags, timelimit = timelimit)
AbstractFFTs.plan_brfft(b, X, d, region; flags = flags, timelimit = timelimit)
end
return plan
end
Expand All @@ -884,8 +884,8 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
end
end

plan_rfft(b::FFTWBackend, X::StridedArray{$Tr};kws...)=plan_rfft(b, X,ntuple(identity, ndims(X));kws...)
plan_brfft(b::FFTWBackend, X::StridedArray{$Tr};kws...)=plan_brfft(b, X,ntuple(identity, ndims(X));kws...)
AbstractFFTs.plan_rfft(b::FFTWBackend, X::StridedArray{$Tr};kws...)=AbstractFFTs.plan_rfft(b, X,ntuple(identity, ndims(X));kws...)
AbstractFFTs.plan_brfft(b::FFTWBackend, X::StridedArray{$Tr};kws...)=AbstractFFTs.plan_brfft(b, X,ntuple(identity, ndims(X));kws...)

function plan_inv(p::rFFTWPlan{$Tr,$FORWARD,false,N},
num_threads::Union{Nothing, Integer} = nothing) where N
Expand Down