44
55ChainRules. @non_differentiable (:: Type{T} where {T<: Array })(:: UndefInitializer , args... )
66
7+ function frule ((_, ẋ), :: Type{T} , x:: AbstractArray ) where {T<: Array }
8+ return T (x), T (ẋ)
9+ end
10+
11+ function frule ((_, ẋ), :: Type{AbstractArray{T}} , x:: AbstractArray ) where {T}
12+ return AbstractArray {T} (x), AbstractArray {T} (ẋ)
13+ end
14+
715function rrule (:: Type{T} , x:: AbstractArray ) where {T<: Array }
816 project_x = ProjectTo (x)
917 Array_pullback (ȳ) = (NoTangent (), project_x (ȳ))
1018 return T (x), Array_pullback
1119end
1220
21+ # This abstract one is used for `float(x)` and other float conversion purposes:
22+ function rrule (:: Type{AbstractArray{T}} , x:: AbstractArray ) where {T}
23+ project_x = ProjectTo (x)
24+ AbstractArray_pullback (ȳ) = (NoTangent (), project_x (ȳ))
25+ return AbstractArray {T} (x), AbstractArray_pullback
26+ end
27+
1328# ####
1429# #### `vect`
1530# ####
1631
1732@non_differentiable Base. vect ()
1833
34+ function frule ((_, ẋs... ), :: typeof (Base. vect), xs:: Number... )
35+ return Base. vect (xs... ), Base. vect (_instantiate_zeros (ẋs, xs)... )
36+ end
37+
1938# Case of uniform type `T`: the data passes straight through,
2039# so no projection should be required.
2140function rrule (:: typeof (Base. vect), X:: Vararg{T, N} ) where {T, N}
@@ -43,32 +62,84 @@ function rrule(::typeof(Base.vect), X::Vararg{Any,N}) where {N}
4362 return Base. vect (X... ), vect_pullback
4463end
4564
65+ """
66+ _instantiate_zeros(ẋs, xs)
67+
68+ Forward rules for `vect`, `cat` etc may receive a mixture of data and `ZeroTangent`s.
69+ To avoid `vect(1, ZeroTangent(), 3)` or worse `vcat([1,2], ZeroTangent(), [6,7])`, this
70+ materialises each zero `ẋ` to be `zero(x)`.
71+ """
72+ _instantiate_zeros (ẋs, xs) = map (_i_zero, ẋs, xs)
73+ _i_zero (ẋ, x) = ẋ
74+ _i_zero (ẋ:: AbstractZero , x) = zero (x)
75+ # Possibly this won't work for partly non-diff arrays, sometihng like `gradient(x -> ["abc", x][end], 1)`
76+ # may give a MethodError for `zero` but won't be wrong.
77+
78+ # Fast paths. Should it also collapse all-Zero cases?
79+ _instantiate_zeros (ẋs:: Tuple{Vararg{<:Number}} , xs) = ẋs
80+ _instantiate_zeros (ẋs:: Tuple{Vararg{<:AbstractArray}} , xs) = ẋs
81+ _instantiate_zeros (ẋs:: AbstractArray{<:Number} , xs) = ẋs
82+ _instantiate_zeros (ẋs:: AbstractArray{<:AbstractArray} , xs) = ẋs
83+
84+ # ####
85+ # #### `copyto!`
86+ # ####
87+
88+ function frule ((_, ẏ, ẋ), :: typeof (copyto!), y:: AbstractArray , x)
89+ return copyto! (y, x), copyto! (ẏ, ẋ)
90+ end
91+
92+ function frule ((_, ẏ, _, ẋ), :: typeof (copyto!), y:: AbstractArray , i:: Integer , x, js:: Integer... )
93+ return copyto! (y, i, x, js... ), copyto! (ẏ, i, ẋ, js... )
94+ end
95+
4696# ####
4797# #### `reshape`
4898# ####
4999
50- function rrule (:: typeof (reshape), A:: AbstractArray , dims:: Tuple{Vararg{Union{Colon,Int}}} )
51- A_dims = size (A)
52- function reshape_pullback (Ȳ)
53- return (NoTangent (), reshape (Ȳ, A_dims), NoTangent ())
54- end
55- return reshape (A, dims), reshape_pullback
100+ function frule ((_, ẋ), :: typeof (reshape), x:: AbstractArray , dims... )
101+ return reshape (x, dims... ), reshape (ẋ, dims... )
56102end
57103
58- function rrule (:: typeof (reshape), A:: AbstractArray , dims:: Union{Colon,Int} ...)
59- A_dims = size (A)
60- function reshape_pullback (Ȳ)
61- ∂A = reshape (Ȳ, A_dims)
62- ∂dims = broadcast (Returns (NoTangent ()), dims)
63- return (NoTangent (), ∂A, ∂dims... )
64- end
104+ function rrule (:: typeof (reshape), A:: AbstractArray , dims... )
105+ ax = axes (A)
106+ project = ProjectTo (A) # Projection is here for e.g. reshape(::Diagonal, :)
107+ ∂dims = broadcast (Returns (NoTangent ()), dims)
108+ reshape_pullback (Ȳ) = (NoTangent (), project (reshape (Ȳ, ax)), ∂dims... )
65109 return reshape (A, dims... ), reshape_pullback
66110end
67111
112+ # ####
113+ # #### `dropdims`
114+ # ####
115+
116+ function frule ((_, ẋ), :: typeof (dropdims), x:: AbstractArray ; dims)
117+ return dropdims (x; dims), dropdims (ẋ; dims)
118+ end
119+
120+ function rrule (:: typeof (dropdims), A:: AbstractArray ; dims)
121+ ax = axes (A)
122+ project = ProjectTo (A)
123+ dropdims_pullback (Ȳ) = (NoTangent (), project (reshape (Ȳ, ax)))
124+ return dropdims (A; dims), dropdims_pullback
125+ end
126+
68127# ####
69128# #### `permutedims`
70129# ####
71130
131+ function frule ((_, ẋ), :: typeof (permutedims), x:: AbstractArray , perm... )
132+ return permutedims (x, perm... ), permutedims (ẋ, perm... )
133+ end
134+
135+ function frule ((_, ẏ, ẋ), :: typeof (permutedims!), y:: AbstractArray , x:: AbstractArray , perm... )
136+ return permutedims! (y, x, perm... ), permutedims! (ẏ, ẋ, perm... )
137+ end
138+
139+ function frule ((_, ẋ), :: Type{<:PermutedDimsArray} , x:: AbstractArray , perm)
140+ return PermutedDimsArray (x, perm), PermutedDimsArray (ẋ, perm)
141+ end
142+
72143function rrule (:: typeof (permutedims), x:: AbstractVector )
73144 project = ProjectTo (x)
74145 permutedims_pullback_1 (dy) = (NoTangent (), project (permutedims (unthunk (dy))))
91162# #### `repeat`
92163# ####
93164
165+ function frule ((_, ẋs), :: typeof (repeat), xs:: AbstractArray , cnt... ; kw... )
166+ return repeat (xs, cnt... ; kw... ), repeat (ẋs, cnt... ; kw... )
167+ end
168+
94169function rrule (:: typeof (repeat), xs:: AbstractArray ; inner= ntuple (Returns (1 ), ndims (xs)), outer= ntuple (Returns (1 ), ndims (xs)))
95170
96171 project_Xs = ProjectTo (xs)
130205# #### `hcat`
131206# ####
132207
208+ function frule ((_, ẋs... ), :: typeof (hcat), xs... )
209+ return hcat (xs... ), hcat (_instantiate_zeros (ẋs, xs)... )
210+ end
211+
133212function rrule (:: typeof (hcat), Xs:: Union{AbstractArray, Number} ...)
134213 Y = hcat (Xs... ) # note that Y always has 1-based indexing, even if X isa OffsetArray
135214 ndimsY = Val (ndims (Y)) # this avoids closing over Y, Val() is essential for type-stability
@@ -164,6 +243,10 @@ function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...)
164243 return Y, hcat_pullback
165244end
166245
246+ function frule ((_, _, Ȧs), :: typeof (reduce), :: typeof (hcat), As:: AbstractVector{<:AbstractVecOrMat} )
247+ return reduce (hcat, As), reduce (hcat, _instantiate_zeros (Ȧs, As))
248+ end
249+
167250function rrule (:: typeof (reduce), :: typeof (hcat), As:: AbstractVector{<:AbstractVecOrMat} )
168251 widths = map (A -> size (A,2 ), As)
169252 function reduce_hcat_pullback_2 (dY)
192275# #### `vcat`
193276# ####
194277
278+ function frule ((_, ẋs... ), :: typeof (vcat), xs... )
279+ return vcat (xs... ), vcat (_instantiate_zeros (ẋs, xs)... )
280+ end
281+
195282function rrule (:: typeof (vcat), Xs:: Union{AbstractArray, Number} ...)
196283 Y = vcat (Xs... )
197284 ndimsY = Val (ndims (Y))
@@ -224,6 +311,10 @@ function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...)
224311 return Y, vcat_pullback
225312end
226313
314+ function frule ((_, _, Ȧs), :: typeof (reduce), :: typeof (vcat), As:: AbstractVector{<:AbstractVecOrMat} )
315+ return reduce (vcat, As), reduce (vcat, _instantiate_zeros (Ȧs, As))
316+ end
317+
227318function rrule (:: typeof (reduce), :: typeof (vcat), As:: AbstractVector{<:AbstractVecOrMat} )
228319 Y = reduce (vcat, As)
229320 ndimsY = Val (ndims (Y))
247338
248339_val (:: Val{x} ) where {x} = x
249340
341+ function frule ((_, ẋs... ), :: typeof (cat), xs... ; dims)
342+ return cat (xs... ; dims), cat (_instantiate_zeros (ẋs, xs)... ; dims)
343+ end
344+
250345function rrule (:: typeof (cat), Xs:: Union{AbstractArray, Number} ...; dims)
251346 Y = cat (Xs... ; dims= dims)
252347 cdims = dims isa Val ? Int (_val (dims)) : dims isa Integer ? Int (dims) : Tuple (dims)
285380# #### `hvcat`
286381# ####
287382
383+ function frule ((_, _, ẋs... ), :: typeof (hvcat), rows, xs... )
384+ return hvcat (rows, xs... ), hvcat (rows, _instantiate_zeros (ẋs, xs)... )
385+ end
386+
288387function rrule (:: typeof (hvcat), rows, values:: Union{AbstractArray, Number} ...)
289388 Y = hvcat (rows, values... )
290389 cols = size (Y,2 )
321420# 1-dim case allows start/stop, N-dim case takes dims keyword
322421# whose defaults changed in Julia 1.6... just pass them all through:
323422
324- function frule ((_, xdot), :: typeof (reverse), x:: Union{AbstractArray, Tuple} , args... ; kw... )
325- return reverse (x, args... ; kw... ), reverse (xdot, args... ; kw... )
423+ function frule ((_, ẋ), :: typeof (reverse), x:: Union{AbstractArray, Tuple} , args... ; kw... )
424+ return reverse (x, args... ; kw... ), reverse (ẋ, args... ; kw... )
425+ end
426+
427+ function frule ((_, ẋ), :: typeof (reverse!), x:: Union{AbstractArray, Tuple} , args... ; kw... )
428+ return reverse! (x, args... ; kw... ), reverse! (ẋ, args... ; kw... )
326429end
327430
328431function rrule (:: typeof (reverse), x:: Union{AbstractArray, Tuple} , args... ; kw... )
338441# #### `circshift`
339442# ####
340443
341- function frule ((_, xdot), :: typeof (circshift), x:: AbstractArray , shifts)
342- return circshift (x, shifts), circshift (xdot, shifts)
444+ function frule ((_, ẋ), :: typeof (circshift), x:: AbstractArray , shifts)
445+ return circshift (x, shifts), circshift (ẋ, shifts)
446+ end
447+
448+ function frule ((_, ẏ, ẋ), :: typeof (circshift!), y:: AbstractArray , x:: AbstractArray , shifts)
449+ return circshift! (y, x, shifts), circshift! (ẏ, ẋ, shifts)
343450end
344451
345452function rrule (:: typeof (circshift), x:: AbstractArray , shifts)
355462# #### `fill`
356463# ####
357464
358- function frule ((_, xdot), :: typeof (fill), x:: Any , dims... )
359- return fill (x, dims... ), fill (xdot, dims... )
465+ function frule ((_, ẋ), :: typeof (fill), x:: Any , dims... )
466+ return fill (x, dims... ), fill (ẋ, dims... )
467+ end
468+
469+ function frule ((_, ẏ, ẋ), :: typeof (fill!), y:: AbstractArray , x:: Any )
470+ return fill! (y, x), fill! (ẏ, ẋ)
360471end
361472
362473function rrule (:: typeof (fill), x:: Any , dims... )
370481# #### `filter`
371482# ####
372483
373- function frule ((_, _, xdot ), :: typeof (filter), f, x:: AbstractArray )
484+ function frule ((_, _, ẋ ), :: typeof (filter), f, x:: AbstractArray )
374485 inds = findall (f, x)
375- return x[inds], xdot [inds]
486+ return x[inds], ẋ [inds]
376487end
377488
378489function rrule (:: typeof (filter), f, x:: AbstractArray )
392503for findm in (:findmin , :findmax )
393504 findm_pullback = Symbol (findm, :_pullback )
394505
395- @eval function frule ((_, xdot ), :: typeof ($ findm), x; dims= :)
506+ @eval function frule ((_, ẋ ), :: typeof ($ findm), x; dims= :)
396507 y, ind = $ findm (x; dims= dims)
397- return (y, ind), Tangent {typeof((y, ind))} (xdot [ind], NoTangent ())
508+ return (y, ind), Tangent {typeof((y, ind))} (ẋ [ind], NoTangent ())
398509 end
399510
400511 @eval function rrule (:: typeof ($ findm), x:: AbstractArray ; dims= :)
441552# Allow for second derivatives, by writing rules for `_zerolike_writeat`;
442553# these rules are the reason it takes a `dims` argument.
443554
444- function frule ((_, _, dydot ), :: typeof (_zerolike_writeat), x, dy, dims, inds... )
445- return _zerolike_writeat (x, dy, dims, inds... ), _zerolike_writeat (x, dydot , dims, inds... )
555+ function frule ((_, _, dẏ ), :: typeof (_zerolike_writeat), x, dy, dims, inds... )
556+ return _zerolike_writeat (x, dy, dims, inds... ), _zerolike_writeat (x, dẏ , dims, inds... )
446557end
447558
448559function rrule (:: typeof (_zerolike_writeat), x, dy, dims, inds... )
457568
458569# These rules for `maximum` pick the same subgradient as `findmax`:
459570
460- function frule ((_, xdot ), :: typeof (maximum), x; dims= :)
571+ function frule ((_, ẋ ), :: typeof (maximum), x; dims= :)
461572 y, ind = findmax (x; dims= dims)
462- return y, xdot [ind]
573+ return y, ẋ [ind]
463574end
464575
465576function rrule (:: typeof (maximum), x:: AbstractArray ; dims= :)
@@ -468,9 +579,9 @@ function rrule(::typeof(maximum), x::AbstractArray; dims=:)
468579 return y, maximum_pullback
469580end
470581
471- function frule ((_, xdot ), :: typeof (minimum), x; dims= :)
582+ function frule ((_, ẋ ), :: typeof (minimum), x; dims= :)
472583 y, ind = findmin (x; dims= dims)
473- return y, xdot [ind]
584+ return y, ẋ [ind]
474585end
475586
476587function rrule (:: typeof (minimum), x:: AbstractArray ; dims= :)
0 commit comments