Skip to content

Commit 859abf0

Browse files
committed
Add forward-mode rules
1 parent 06bef3a commit 859abf0

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

ext/AbstractFFTsEnzymeCoreExt.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,58 @@
11
module AbstractFFTsEnzymeCoreExt
22

3+
using AbstractFFTs
4+
using AbstractFFTs.LinearAlgebra
5+
using EnzymeCore
6+
using EnzymeCore.EnzymeRules
7+
8+
######################
9+
# Forward-mode rules #
10+
######################
11+
12+
const DuplicatedOrBatchDuplicated{T} = Union{Duplicated{T},BatchDuplicated{T}}
13+
14+
# since FFTs are linear, implement all forward-model rules generically at a low-level
15+
16+
function EnzymeRules.forward(
17+
func::Const{typeof(mul!)},
18+
RT::Type{<:Const},
19+
y::DuplicatedOrBatchDuplicated{<:StridedArray{T}},
20+
p::Const{<:AbstractFFTs.Plan{T}},
21+
x::DuplicatedOrBatchDuplicated{<:StridedArray{T}},
22+
) where {T}
23+
val = func.val(y.val, p.val, x.val)
24+
if x isa Duplicated && y isa Duplicated
25+
dval = func.val(y.dval, p.val, x.dval)
26+
elseif x isa Duplicated && y isa Duplicated
27+
dval = map(y.dval, x.dval) do dy, dx
28+
return func.val(dy, p.val, dx)
29+
end
30+
end
31+
return nothing
32+
end
33+
34+
function EnzymeRules.forward(
35+
func::Const{typeof(*)},
36+
RT::Type{
37+
<:Union{Const,Duplicated,DuplicatedNoNeed,BatchDuplicated,BatchDuplicatedNoNeed}
38+
},
39+
p::Const{<:AbstractFFTs.Plan},
40+
x::DuplicatedOrBatchDuplicated{<:StridedArray},
41+
)
42+
RT <: Const && return func.val(p.val, x.val)
43+
if x isa Duplicated
44+
dval = func.val(p.val, x.dval)
45+
RT <: DuplicatedNoNeed && return dval
46+
val = func.val(p.val, x.val)
47+
RT <: Duplicated && return Duplicated(val, dval)
48+
else # x isa BatchDuplicated
49+
dval = map(x.dval) do dx
50+
return func.val(p.val, dx)
51+
end
52+
RT <: BatchDuplicatedNoNeed && return dval
53+
val = func.val(p.val, x.val)
54+
RT <: BatchDuplicated && return BatchDuplicated(val, dval)
55+
end
56+
end
57+
358
end # module

0 commit comments

Comments
 (0)