diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bd71a08..25d89b6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,7 +47,6 @@ jobs: using Pkg Pkg.add(PackageSpec(path=pwd(), subdir="AbstractNFFTs")) Pkg.add(PackageSpec(path=pwd(), subdir="NFFTTools")) - Pkg.add(PackageSpec(path=pwd(), subdir="CuNFFT")) - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: @@ -70,7 +69,6 @@ jobs: using Pkg Pkg.add(PackageSpec(path=pwd(), subdir="AbstractNFFTs")) Pkg.add(PackageSpec(path=pwd(), subdir="NFFTTools")) - Pkg.add(PackageSpec(path=pwd(), subdir="CuNFFT")) - uses: julia-actions/julia-buildpkg@latest - uses: julia-actions/julia-docdeploy@latest env: diff --git a/AbstractNFFTs/Project.toml b/AbstractNFFTs/Project.toml index 010f333..e754d5d 100644 --- a/AbstractNFFTs/Project.toml +++ b/AbstractNFFTs/Project.toml @@ -1,12 +1,13 @@ name = "AbstractNFFTs" uuid = "7f219486-4aa7-41d6-80a7-e08ef20ceed7" author = ["Tobias Knopp "] -version = "0.8.3" +version = "0.9.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -16,4 +17,5 @@ AbstractNFFTsChainRulesCoreExt = "ChainRulesCore" [compat] julia = "1.6" -ChainRulesCore = "1" \ No newline at end of file +ChainRulesCore = "1" +ScopedValues = "1" \ No newline at end of file diff --git a/AbstractNFFTs/src/AbstractNFFTs.jl b/AbstractNFFTs/src/AbstractNFFTs.jl index d1da6b0..7b41638 100644 --- a/AbstractNFFTs/src/AbstractNFFTs.jl +++ b/AbstractNFFTs/src/AbstractNFFTs.jl @@ -3,7 +3,16 @@ module AbstractNFFTs using LinearAlgebra using Printf +# Remove this difference once 1.11 or higher becomes lower bound +if VERSION >= v"1.11" + using Base.ScopedValues +else + using ScopedValues +end + + # interface +export AbstractNFFTBackend, nfft_backend, with export AbstractFTPlan, AbstractRealFTPlan, AbstractComplexFTPlan, AbstractNFFTPlan, AbstractNFCTPlan, AbstractNFSTPlan, AbstractNNFFTPlan, plan_nfft, plan_nfct, plan_nfst, mul!, size_in, size_out, nodes! @@ -25,6 +34,7 @@ include("misc.jl") include("interface.jl") include("derived.jl") + @static if !isdefined(Base, :get_extension) import Requires end diff --git a/AbstractNFFTs/src/derived.jl b/AbstractNFFTs/src/derived.jl index 0501843..635d682 100644 --- a/AbstractNFFTs/src/derived.jl +++ b/AbstractNFFTs/src/derived.jl @@ -10,64 +10,189 @@ planfunc = Symbol("plan_"*"$op") # The following automatically call the plan_* version for type Array -$(planfunc)(k::AbstractArray, N::Union{Integer,NTuple{D,Int}}, args...; kargs...) where {D} = - $(planfunc)(Array, k, N, args...; kargs...) +$(planfunc)(b::AbstractNFFTBackend, k::AbstractArray, N::Union{Integer,NTuple{D,Int}}, args...; kargs...) where {D} = + $(planfunc)(b, Array, k, N, args...; kargs...) -$(planfunc)(k::AbstractArray, y::AbstractArray, args...; kargs...) = - $(planfunc)(Array, k, y, args...; kargs...) +$(planfunc)(b::AbstractNFFTBackend, k::AbstractArray, y::AbstractArray, args...; kargs...) = + $(planfunc)(b, Array, k, y, args...; kargs...) + +$(planfunc)(k::AbstractArray, args...; kargs...) = $(planfunc)(active_backend(), k, args...; kargs...) # The follow convert 1D parameters into the format required by the plan -$(planfunc)(Q::Type, k::AbstractVector, N::Integer, rest...; kwargs...) = - $(planfunc)(Q, collect(reshape(k,1,length(k))), (N,), rest...; kwargs...) +$(planfunc)(b::AbstractNFFTBackend, Q::Type, k::AbstractVector, N::Integer, rest...; kwargs...) = + $(planfunc)(b, Q, collect(reshape(k,1,length(k))), (N,), rest...; kwargs...) + +$(planfunc)(b::AbstractNFFTBackend, Q::Type, k::AbstractVector, N::NTuple{D,Int}, rest...; kwargs...) where {D} = + $(planfunc)(b, Q, collect(reshape(k,1,length(k))), N, rest...; kwargs...) -$(planfunc)(Q::Type, k::AbstractVector, N::NTuple{D,Int}, rest...; kwargs...) where {D} = - $(planfunc)(Q, collect(reshape(k,1,length(k))), N, rest...; kwargs...) +$(planfunc)(b::AbstractNFFTBackend, Q::Type, k::AbstractMatrix, N::NTuple{D,Int}, rest...; kwargs...) where {D} = + $(planfunc)(b, Q, collect(k), N, rest...; kwargs...) -$(planfunc)(Q::Type, k::AbstractMatrix, N::NTuple{D,Int}, rest...; kwargs...) where {D} = - $(planfunc)(Q, collect(k), N, rest...; kwargs...) +$(planfunc)(Q::Type, args...; kwargs...) = $(planfunc)(active_backend(), Q, args...; kwargs...) +$(planfunc)(::Missing, args...; kwargs...) = no_backend_error() end end ## NNFFT constructor -plan_nnfft(Q::Type, k::AbstractVector, y::AbstractVector, rest...; kwargs...) = - plan_nnfft(Q, collect(reshape(k,1,length(k))), collect(reshape(y,1,length(k))), rest...; kwargs...) - +plan_nnfft(Q::Type, args...; kwargs...) = plan_nnfft(active_backend(), Q, args...; kwargs...) +plan_nnfft(b::AbstractNFFTBackend, Q::Type, k::AbstractVector, y::AbstractVector, rest...; kwargs...) = + plan_nnfft(b, Q, collect(reshape(k,1,length(k))), collect(reshape(y,1,length(k))), rest...; kwargs...) +plan_nnfft(::Missing, args...; kwargs...) = no_backend_error() ############################################### # Allocating trafo functions with plan creation ############################################### +""" + nfft(k, f, rest...; kwargs...) + nfft(backend, k, f, rest...; kwargs...) + +calculates the nfft of the array `f` for the nodes contained in the matrix `k` +The output is a vector of length M=`size(nodes,2)`. + +Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`. +Backends can also be set with a scoped value overriding the current active backend within a scope: + +```julia +julia> NFFT.activate!() + +julia> nfft(k, f, rest...; kwargs...) # uses NFFT + +julia> with(nfft_backend => NonuniformFFTs.backend()) do + nfft(k, f, rest...; kwargs...) # uses NonuniformFFTs + end +``` +""" +nfft +""" + nfft_adjoint(k, N, fHat, rest...; kwargs...) + nfft_adjoint(backend, k, N, fHat, rest...; kwargs...) + +calculates the adjoint nfft of the vector `fHat` for the nodes contained in the matrix `k`. +The output is an array of size `N`. + +Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`. +Backends can also be set with a scoped value overriding the current active backend within a scope: + +```julia +julia> NFFT.activate!() + +julia> nfft_adjoint(k, N, fHat, rest...; kwargs...) # uses NFFT + +julia> with(nfft_backend => NonuniformFFTs.backend()) do + nfft_adjoint(k, N, fHat, rest...; kwargs...) # uses NonuniformFFTs + end +``` +""" +nfft_adjoint +""" + nfft_transpose(k, N, fHat, rest...; kwargs...) + nfft_transpose(backend, k, N, fHat, rest...; kwargs...) + +calculates the transpose nfft of the vector `fHat` for the nodes contained in the matrix `k`. +The output is an array of size `N`. + +Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`. +Backends can also be set with a scoped value overriding the current active backend within a scope: + +```julia +julia> NFFT.activate!() + +julia> nfft_transpose(k, N, fHat, rest...; kwargs...) # uses NFFT + +julia> with(nfft_backend => NonuniformFFTs.backend()) do + nfft_transpose(k, N, fHat, rest...; kwargs...) # uses NonuniformFFTs + end +``` +""" +nfft_transpose + +""" + nfct(k, f, rest...; kwargs...) + nfct(backend, k, f, rest...; kwargs...) + +calculates the nfct of the array `f` for the nodes contained in the matrix `k` +The output is a vector of length M=`size(nodes,2)`. + +Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`. +""" +nfct +""" + nfct_adjoint(k, N, fHat, rest...; kwargs...) + nfct_adjoint(backend, k, N, fHat, rest...; kwargs...) + +calculates the adjoint nfct of the vector `fHat` for the nodes contained in the matrix `k`. +The output is an array of size `N`. + +Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`. +""" +nfct_adjoint +""" + nfct_transpose(k, N, fHat, rest...; kwargs...) + nfct_transpose(backend, k, N, fHat, rest...; kwargs...) + +calculates the transpose nfct of the vector `fHat` for the nodes contained in the matrix `k`. +The output is an array of size `N`. + +Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`. +""" +nfct_transpose + +""" + nfst(k, f, rest...; kwargs...) + nfst(backend, k, f, rest...; kwargs...) + +calculates the nfst of the array `f` for the nodes contained in the matrix `k` +The output is a vector of length M=`size(nodes,2)`. + +Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`. +""" +nfst +""" + nfst_adjoint(k, N, fHat, rest...; kwargs...) + nfst_adjoint(backend, k, N, fHat, rest...; kwargs...) + +calculates the adjoint nfst of the vector `fHat` for the nodes contained in the matrix `k`. +The output is an array of size `N`. + +Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`. +""" +nfst_adjoint +""" + nfst_transpose(k, N, fHat, rest...; kwargs...) + nfst_transpose(backend, k, N, fHat, rest...; kwargs...) + +calculates the transpose nfst of the vector `fHat` for the nodes contained in the matrix `k`. +The output is an array of size `N`. + +Uses the active AbstractNFFTs `backend` if no `backend` argument is provided. Backends can be activated with `BackendModule.activate!()`. +""" +nfst_transpose + for (op,trans) in zip([:nfft, :nfct, :nfst], [:adjoint, :transpose, :transpose]) planfunc = Symbol("plan_$(op)") tfunc = Symbol("$(op)_$(trans)") @eval begin -# TODO fix comments (how?) -""" -nfft(k, f, rest...; kwargs...) - -calculates the nfft of the array `f` for the nodes contained in the matrix `k` -The output is a vector of length M=`size(nodes,2)` -""" -function $(op)(k, f::AbstractArray; kargs...) +$(op)(k, f::AbstractArray; kargs...) = $(op)(active_backend(), k, f::AbstractArray; kargs...) +function $(op)(b::AbstractNFFTBackend, k, f::AbstractArray; kargs...) p = $(planfunc)(k, size(f); kargs... ) return p * f end +$(op)(::Missing, k, f::AbstractArray; kargs...) = no_backend_error() -""" -nfft_adjoint(k, N, fHat, rest...; kwargs...) -calculates the adjoint nfft of the vector `fHat` for the nodes contained in the matrix `k`. -The output is an array of size `N` -""" -function $(tfunc)(k, N, fHat; kargs...) +$(tfunc)(k, N, fHat; kargs...) = $(tfunc)(active_backend(), k, N, fHat; kargs...) +function $(tfunc)(b::AbstractNFFTBackend, k, N, fHat; kargs...) p = $(planfunc)(k, N; kargs...) return $(trans)(p) * fHat end +$(tfunc)(::Missing, k, N, fHat; kargs...) = no_backend_error() + end end diff --git a/AbstractNFFTs/src/interface.jl b/AbstractNFFTs/src/interface.jl index 0c75c02..83299fa 100644 --- a/AbstractNFFTs/src/interface.jl +++ b/AbstractNFFTs/src/interface.jl @@ -1,3 +1,33 @@ +abstract type AbstractNFFTBackend end +struct BackendReference + ref::Ref{Union{Missing, AbstractNFFTBackend}} + BackendReference(val::Union{Missing, AbstractNFFTBackend}) = new(Ref{Union{Missing, AbstractNFFTBackend}}(val)) +end +Base.setindex!(ref::BackendReference, val::Union{Missing, AbstractNFFTBackend}) = ref.ref[] = val +Base.setindex!(ref::BackendReference, val::Module) = setindex!(ref, val.backend()) +Base.getindex(ref::BackendReference) = getindex(ref.ref)::Union{Missing, AbstractNFFTBackend} +Base.convert(::Type{BackendReference}, val::AbstractNFFTBackend) = BackendReference(val) +const nfft_backend = ScopedValue(BackendReference(missing)) + +""" + set_active_backend!(back::Union{Missing, Module, AbstractNFFTBackend}) + +Set the default NFFT plan backend. A module `back` must implement `back.backend()`. +""" +set_active_backend!(back::Module) = set_active_backend!(back.backend()) +function set_active_backend!(back::Union{Missing, AbstractNFFTBackend}) + nfft_backend[][] = back +end +active_backend() = nfft_backend[][] +function no_backend_error() + error( + """ + No default backend available! + Make sure to also "import/using" an NFFT backend such as NFFT or NonuniformFFTs. + """ + ) +end + """ AbstractFTPlan{T,D,R} diff --git a/NFFTTools/Project.toml b/NFFTTools/Project.toml index 54c3a59..5112c3e 100644 --- a/NFFTTools/Project.toml +++ b/NFFTTools/Project.toml @@ -12,5 +12,5 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" [compat] julia = "1.6" AbstractFFTs = "1.0" -AbstractNFFTs = "0.6, 0.7, 0.8" +AbstractNFFTs = "0.6, 0.7, 0.8, 0.9" FFTW = "1" diff --git a/Project.toml b/Project.toml index 422c0a7..d224a82 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NFFT" uuid = "efe261a4-0d2b-5849-be55-fc731d526b0d" authors = ["Tobias Knopp "] -version = "0.13.7" +version = "0.14" [deps] AbstractNFFTs = "7f219486-4aa7-41d6-80a7-e08ef20ceed7" @@ -19,7 +19,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] Adapt = "3, 4" -AbstractNFFTs = "0.8" +AbstractNFFTs = "0.9" BasicInterpolators = "0.6.5, 0.7" DataFrames = "1.3.1, 1.4.1" FFTW = "1.5" diff --git a/docs/src/abstract.md b/docs/src/abstract.md index 38c55d3..94a9b28 100644 --- a/docs/src/abstract.md +++ b/docs/src/abstract.md @@ -13,14 +13,27 @@ An overview about the current packages and their dependencies is shown in the fo ## Implementations Currently, there are four implementations of the `AbstractNFFTs` interface: -1. **NFFT.jl**: This is the reference implementation running und the CPU. -2. **CuNFFT.jl**: An implementation running on graphics hardware of Nvidia exploiting CUDA.jl -3. **NFFT3.jl**: In the `Wrapper` directory of `NFFT.jl` there is a wrapper around the `NFFT3.jl` package following the `AbstractNFFTs` interface. `NFFT3.jl` is itself a wrapper around the high performance C library [NFFT3](http://www.nfft.org). -4. **FINUFFT.jl**: In the `Wrapper` directory of `NFFT.jl` there is a wrapper around the `FINUFFT.jl` package. `FINUFFT.jl` is itself a wrapper around the high performance C++ library [FINUFFT](https://finufft.readthedocs.io). +1. **NFFT.jl**: This is the reference implementation running on the CPU and with configurations on the GPU. +2. **NFFT3.jl**: In the `Wrapper` directory of `NFFT.jl` there is a wrapper around the `NFFT3.jl` package following the `AbstractNFFTs` interface. `NFFT3.jl` is itself a wrapper around the high performance C library [NFFT3](http://www.nfft.org). +3. **FINUFFT.jl**: In the `Wrapper` directory of `NFFT.jl` there is a wrapper around the `FINUFFT.jl` package. `FINUFFT.jl` is itself a wrapper around the high performance C++ library [FINUFFT](https://finufft.readthedocs.io). +4. **NonuniformFFTs.jl**: Pure Julia package written with generic and fast GPU kernels written with KernelAbstractions.jl. !!! note Right now one needs to install `NFFT.jl` and manually include the wrapper files. In the future we hope to integrate the wrappers in `NFFT3.jl` and `FINUFFT.jl` directly such that it is much more convenient to switch libraries. +It's possible to change between different implementation backends. Each backend has to implement a backend type, which by convention can be accessed via for example `NFFT.backend()`. There are several ways to activate a backend: +```julia +# Actively setting a backend: +AbstractNFFTs.set_active_backend!(NFFT.backend()) +# Activating a backend: +NFFT.activate!() +# and creating a new dynamic scope which uses a different backend: +with(nfft_backend => NonuniformFFTs.backend()) do + # Uses NonuniformFFTs as implementation backend +end +# It's also possible to directly pass backends to functions: +nfft(NonuniformFFTs.backend(), ...) +``` ## Interface @@ -30,14 +43,24 @@ Here * `D` is the size of the input vector * `R` is the size of the output vector. Usually this will be `R=1` unless a directional NFFT is implemented. -For instance the `CuNFFTPlan` is defined like this +For instance the `NFFTPlan` is defined like this ```julia -mutable struct CuNFFTPlan{T,D} <: AbstractNFFTPlan{T,D,1} +mutable struct NFFTPlan{T,D,R} <: AbstractNFFTPlan{T,D,R} ... end ``` -In addition to the plan, the following functions need to be implemented: +Furthermore, a package needs to implement its own backend type to dispatch on +```julia +struct MyBackend <: AbstractNFFTBackend +``` +and it should allow a user to activate the package, which by convention can be done with (unexported) functions: +```julia +activate!() = AbstractNFFTs.set_active_backend!(MyBackend()) +backend() = MyBackend() +``` + +In addition to the plan and backend, the following functions need to be implemented: ```julia size_out(p) size_out(p) @@ -99,20 +122,16 @@ All parameters are put into keyword arguments that have to match as well. We des Additionally, to the type-specific constructor one can provide the factory ``` -plan_nfft(Q::Type, k::Matrix{T}, N::NTuple{D,Int}; kargs...) where {D} +plan_nfft(b::MyBackend, Q::Type, k::Matrix{T}, N::NTuple{D,Int}; kargs...) where {D} ``` where `Q` is the Array type, e.g. `Array`. The reason to require the array type is, that this allows for GPU implementations, which would use for instance `CuArray` here. The package `AbstractNFFTs` provides a convenient constructor ``` -plan_nfft(k::Matrix{T}, N::NTuple{D,Int}; kargs...) where {D} +plan_nfft(b::MyBackend, k::Matrix{T}, N::NTuple{D,Int}; kargs...) where {D} ``` defaulting to the `Array` type. -!!! note - Different packages implementing `plan_nfft` will conflict if the same `Q` is implemented. In case of `NFFT.jl` and `CuNFFT.jl` there is no conflict since the array type is different. - - ## Derived Interface Based on the core low-level interface that an `AbstractNFFTPlan` needs to provide, the package diff --git a/docs/src/index.md b/docs/src/index.md index 0b08f73..c54a92b 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -5,7 +5,7 @@ ## Introduction This package provides a Julia implementation of the Non-equidistant Fast Fourier Transform (NFFT). -For a detailed introduction into the NFFT and its application please have a look at the [software paper](https://arxiv.org/pdf/2208.00049.pdf) on the `NFFT.jl`. Further resources are [nfft.org](http://www.nfft.org) and [finufft.readthedocs.io](https://finufft.readthedocs.io). You +For a detailed introduction into the NFFT and its application please have a look at the [software paper](https://arxiv.org/pdf/2208.00049.pdf) on the `NFFT.jl`. Further resources are [nfft.org](http://www.nfft.org) and [finufft.readthedocs.io](https://finufft.readthedocs.io). The NFFT is a fast implementation of the Non-equidistant Discrete Fourier Transform (NDFT) that is basically a Discrete Fourier Transform (DFT) with non-equidistant sampling nodes in either Fourier or time/space domain. @@ -25,7 +25,7 @@ add NFFT ``` This will install the packages `NFFT.jl` and all its dependencies. Most importantly it installs the abstract interface package `AbstractNFFTs.jl`, which `NFFT.jl` implements. -Additional NFFT related tools can be obtained by adding the package `NFFTTools.jl`. If you need support for `CUDA` you also need to install the package `CuNFFT.jl`. +Additional NFFT related tools can be obtained by adding the package `NFFTTools.jl`. If you need support for `CUDA` or other GPU backends, you only need to install the respective GPU backend and a GPU compatible plan will be available via a package extension. In case you want to use an alternative NFFT implementation such as [NFFT3.jl](https://github.com/NFFT/NFFT3.jl) or [FINUFFT.jl](https://github.com/ludvigak/FINUFFT.jl) we provide wrapper types allowing to use them as `AbstractNFFTs` implementations. They can be used like this: @@ -34,7 +34,9 @@ julia> using AbstractNFFTs julia> include(joinpath(dirname(pathof(AbstractNFFTs)), "..", "..", "Wrappers", "FINUFFT.jl")) julia> include(joinpath(dirname(pathof(AbstractNFFTs)), "..", "..", "Wrappers", "NFFT3.jl")) ``` -This requires that you first `add` the package you want to use. +This requires that you first `add` the package you want to use. + +A related package is [NonuniformFFTs.jl](https://github.com/jipolanco/NonuniformFFTs.jl) which provides a pure Julia implementation using KernelAbstractions.jl. It also features an implementation of the `AbstractNFFTs.jl` interface. ## Guide diff --git a/ext/NFFTGPUArraysExt/implementation.jl b/ext/NFFTGPUArraysExt/implementation.jl index 4367929..4371298 100644 --- a/ext/NFFTGPUArraysExt/implementation.jl +++ b/ext/NFFTGPUArraysExt/implementation.jl @@ -16,7 +16,7 @@ mutable struct GPU_NFFTPlan{T,D, arrTc <: AbstractGPUArray{Complex{T}, D}, vecI B::SM end -function AbstractNFFTs.plan_nfft(arr::Type{<:AbstractGPUArray}, k::Matrix{T}, N::NTuple{D,Int}, rest...; +function AbstractNFFTs.plan_nfft(::NFFTBackend, arr::Type{<:AbstractGPUArray}, k::Matrix{T}, N::NTuple{D,Int}, rest...; timing::Union{Nothing,TimingStats} = nothing, kargs...) where {T,D} t = @elapsed begin p = GPU_NFFTPlan(arr, k, N, rest...; kargs...) diff --git a/src/NFFT.jl b/src/NFFT.jl index 99ad9c6..f8c54a4 100644 --- a/src/NFFT.jl +++ b/src/NFFT.jl @@ -16,7 +16,7 @@ using BasicInterpolators @reexport using AbstractNFFTs -export NDFTPlan, NDCTPlan, NDSTPlan, NNDFTPlan, +export NFFTBackend, NDFTPlan, NDCTPlan, NDSTPlan, NNDFTPlan, NFFTPlan, NFFTParams @@ -37,13 +37,16 @@ include("precomputation.jl") ################# # factory methods ################# +struct NFFTBackend <: AbstractNFFTBackend end +activate!() = AbstractNFFTs.set_active_backend!(NFFT) +backend() = NFFTBackend() """ - plan_nfft(k::Matrix{T}, N::NTuple{D,Int}, rest...; kargs...) + NFFT.plan_nfft(k::Matrix{T}, N::NTuple{D,Int}, rest...; kargs...) compute a plan for the NFFT of a size-`N` array at the nodes contained in `k`. """ -function AbstractNFFTs.plan_nfft(::Type{<:Array}, k::Matrix{T}, N::NTuple{D,Int}, rest...; +function AbstractNFFTs.plan_nfft(::NFFTBackend, ::Type{<:Array}, k::Matrix{T}, N::NTuple{D,Int}, rest...; timing::Union{Nothing,TimingStats} = nothing, kargs...) where {T,D} t = @elapsed begin p = NFFTPlan(k, N, rest...; kargs...) @@ -60,6 +63,7 @@ include("convolution.jl") function __init__() NFFT._use_threads[] = (Threads.nthreads() > 1) + activate!() end include("precompile.jl") diff --git a/src/precompile.jl b/src/precompile.jl index d760a62..fc796df 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -6,7 +6,7 @@ using PrecompileTools J, N = 8, 16 k = range(-0.4, stop=0.4, length=J) # nodes at which the NFFT is evaluated f = randn(ComplexF64, J) # data to be transformed - p = plan_nfft(k, N, reltol=1e-9) # create plan + p = plan_nfft(NFFTBackend(), k, N, reltol=1e-9) # create plan fHat = adjoint(p) * f # calculate adjoint NFFT y = p * fHat # calculate forward NFFT diff --git a/test/constructors.jl b/test/constructors.jl index 3007b16..e5a5782 100644 --- a/test/constructors.jl +++ b/test/constructors.jl @@ -1,5 +1,17 @@ @testset "Constructors" begin + @testset "Backend" begin + @test NFFT.backend() isa NFFTBackend + AbstractNFFTs.set_active_backend!(missing) + @test ismissing(AbstractNFFTs.active_backend()) + with(nfft_backend => NFFT.backend()) do + @test AbstractNFFTs.active_backend() isa NFFTBackend + end + @test ismissing(AbstractNFFTs.active_backend()) + NFFT.activate!() + @test AbstractNFFTs.active_backend() isa NFFTBackend + end + @test_throws ArgumentError NFFTPlan(zeros(1,4), (2,2)) p = NFFTPlan(zeros(2,4), (2,2))