-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathzygote.jl
109 lines (103 loc) · 4.95 KB
/
zygote.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
using SciMLBase: issymbollike, sym_to_index, getobserved
Zygote.@adjoint function Base.getindex(VA::ODESolution, sym, j::Int)
    # Zygote adjoint for `VA[sym, j]`, the value of `sym` at the j-th saved time
    # point. On the observed-function path the cotangent for `VA` is returned as an
    # `ODESolution` whose `u` holds the state cotangents and whose `prob.p` holds the
    # parameter cotangent.
    function ODESolution_getindex_pullback(Δ)
        # NOTE(review): the symbol -> state-index lookup is deliberately disabled, so
        # every access (state, observed, or integer index) is differentiated through
        # the observed function and the `else` branch below is unreachable. This
        # presumably keeps the cotangent an `ODESolution`, the only cotangent type the
        # `EnsembleSolution` getindex pullback in this file accepts; confirm before
        # re-enabling the lookup.
        # i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
        i = nothing
        if i === nothing
            getter = getobserved(VA)
            # Differentiate the observed function at (sym, u_j, p, t_j). Zygote's
            # pullback yields `nothing` when the output depends on none of the
            # inputs; treat that as "no gradient for any argument".
            _, back = Zygote.pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])
            grz = something(back(Δ), (nothing, nothing, nothing, nothing))
            # A `nothing` state gradient (output independent of u_j) becomes an
            # explicit zero so `du` stays a homogeneous vector of state arrays.
            du_j = something(grz[2], zero(VA.u[j]))
            du = [k == j ? du_j : zero(VA.u[1]) for k in 1:length(VA.u)]
            dp = grz[3] # pullback for p; `nothing` if the output ignores p
            dprob = remake(VA.prob, p = dp)
            T = eltype(eltype(VA.u))
            # NOTE(review): N is taken from the parameter count; confirm this is
            # what the `ODESolution` type parameter in this position expects.
            N = length(VA.prob.p)
            # Positional construction tied to the `ODESolution` field layout of the
            # targeted SciMLBase version: besides `du` and `dprob`, only `VA.dense`,
            # `0` and `VA.retcode` are passed; every other field is `nothing`.
            Δ′ = ODESolution{T, N, typeof(du), Nothing, Nothing, Nothing, Nothing,
                typeof(dprob), Nothing, Nothing, Nothing}(du, nothing,
                nothing, nothing, nothing, dprob, nothing, nothing,
                VA.dense, 0, nothing, VA.retcode)
            (Δ′, nothing, nothing)
        else
            # Direct state-index path (currently unreachable, see the note above):
            # Δ goes to state index `i` of time point `j`, zeros everywhere else.
            Δ′ = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] : zero(VA.u[1]) for m in 1:length(VA.u)]
            (Δ′, nothing, nothing)
        end
    end
    VA[sym, j], ODESolution_getindex_pullback
end
Zygote.@adjoint function Base.getindex(VA::ODESolution, sym)
    # Zygote adjoint for `VA[sym]`, the saved time series of `sym`.
    # The cotangent returned for `VA` is a plain vector with one per-time-point
    # vector for each element of `VA.u` (not an `ODESolution`).
    function ODESolution_getindex_pullback(Δ)
        # Resolve a symbolic name to a state index; a non-symbolic `sym` is used as
        # the index directly.
        i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
        if i === nothing
            # `sym` is not a state (presumably an observed quantity); slicing it over
            # all time points has no adjoint here. Raise a proper `ErrorException`
            # instead of throwing a bare `String`.
            error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")
        end
        # Δ[j] is the cotangent of `sym` at time point j: place it at state index `i`
        # of the j-th state vector and zero every other entry.
        Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
              for (j, x) in enumerate(VA.u)]
        return (Δ′, nothing)
    end
    VA[sym], ODESolution_getindex_pullback
end
# Zygote.@adjoint function Base.getindex(VA::ODESolution, sym, j::Int)
# function ODESolution_getindex_pullback(Δ)
# i = SciMLBase.issymbollike(sym) ? SciMLBase.sym_to_index(sym, VA) : sym
# if i === nothing
#
# zerou = zero(VA.prob.u0)
# _Δ = @. ifelse(Δ == nothing,(zerou,),Δ)
#
# #return (DiffEqBase.build_solution(VA.prob,VA.alg,VA.t,_Δ), nothing, nothing)
# return (DiffEqBase.build_solution(VA.prob,VA.alg,VA.t,_Δ), nothing, nothing)
#
# getter = SciMLBase.getobserved(VA)
# # @show getter
# # getter = VA.prob.f.observed
# grz = Zygote.pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
# # @show grz
# du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)]
# dp = grz[3] # pullback for p
# # @show size(dp)
# # dp = dp == nothing ? zeros(eltype(eltype(VA.u)), length(VA.prob.p)) : dp
# dprob = remake(VA.prob, p = dp)
# T = eltype(eltype(VA.u))
# N = length(VA.prob.p)
# Δ′ = ODESolution{T, N, typeof(du), Nothing, Nothing, Nothing, Nothing,
# typeof(dprob), Nothing, Nothing, Nothing}(du, nothing,
# nothing, nothing, nothing, dprob, nothing, nothing,
# VA.dense, 0, nothing, VA.retcode)
# (Δ′, nothing, nothing)
# else
# Δ′ = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] : zero(VA.u[1]) for m in 1:length(VA.u)]
# (Δ′, nothing, nothing)
# end
# end
# VA[sym, j], ODESolution_getindex_pullback
# end
#
# Zygote.@adjoint function Base.getindex(VA::DiffEqBase.ODESolution, sym)
# function ODESolution_getindex_pullback(Δ)
# i = SciMLBase.issymbollike(sym) ? SciMLBase.sym_to_index(sym, VA) : sym
# if i === nothing
# throw("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")
# else
# Δ′ = [ [i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] for (x, j) in zip(VA.u, 1:length(VA))]
# (Δ′, nothing)
# end
# end
# VA[sym], ODESolution_getindex_pullback
# end
#
# Zero cotangent with the same structure as `x`, used for ensemble members that were
# not indexed. `nothing` (Zygote's "no gradient") is kept as `nothing`.
_ensemble_zero_cotangent(x) = zero(x)
_ensemble_zero_cotangent(::Nothing) = nothing
_ensemble_zero_cotangent(x::Union{Tuple, NamedTuple}) = map(_ensemble_zero_cotangent, x)

Zygote.@adjoint function Base.getindex(sim::DiffEqBase.EnsembleSolution, i::Int)
    # Zygote adjoint for `sim[i]`, the i-th trajectory of an ensemble solution.
    function EnsembleSolution_getindex_pullback(Δ::ODESolution)
        # Only trajectory i was read, so it alone receives Δ and every other
        # trajectory gets a zero cotangent. Copying Δ into every slot would credit
        # trajectory i's gradient to all ensemble members.
        # NOTE(review): assumes Δ has the positional field layout of the
        # `ODESolution` cotangent built by the `getindex(::ODESolution, sym, j)`
        # adjoint in this file; building the zero with `typeof(Δ)` keeps `arr`
        # concretely typed.
        du = [_ensemble_zero_cotangent(x) for x in Δ.u]
        dprob = remake(Δ.prob, p = _ensemble_zero_cotangent(Δ.prob.p))
        Δzero = typeof(Δ)(du, nothing, nothing, nothing, nothing, dprob, nothing,
            nothing, Δ.dense, 0, nothing, Δ.retcode)
        arr = [t == i ? Δ : Δzero for t in 1:length(sim)]
        (EnsembleSolution(arr, 0.0, true), nothing)
    end
    sim[i], EnsembleSolution_getindex_pullback
end