Skip to content

Commit 5031265

Browse files
committed
Merge tag 'v0.6.59' into muse
[Diff since v0.6.58](FluxML/Zygote.jl@v0.6.58...v0.6.59) **Merged pull requests:** - Actually make sure conda env dir is set on Buildkite CI (FluxML#1392) (@ToucheSir) - generated z2d (FluxML#1394) (@chengchingwen) - bump version (FluxML#1398) (@chengchingwen)
2 parents 9cdf7fd + 1c933e9 commit 5031265

30 files changed

+395
-292
lines changed

.buildkite/pipeline.yml

+4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ steps:
44
- JuliaCI/julia#v1:
55
version: "1.6"
66
- JuliaCI/julia-test#v1: ~
7+
command:
8+
- mkdir -p "$${JULIA_DEPOT_PATH}/conda/3/x86_64"
79
agents:
810
queue: "juliagpu"
911
cuda: "*"
@@ -14,6 +16,8 @@ steps:
1416
- JuliaCI/julia#v1:
1517
version: "1"
1618
- JuliaCI/julia-test#v1: ~
19+
command:
20+
- mkdir -p "$${JULIA_DEPOT_PATH}/conda/3/x86_64"
1721
agents:
1822
queue: "juliagpu"
1923
cuda: "*"

.github/workflows/Downstream.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@ jobs:
2727
- {user: SciML, repo: NeuralPDE.jl, group: NNPDE}
2828
- {user: JuliaMolSim, repo: Molly.jl, group: Zygote}
2929
steps:
30-
- uses: actions/checkout@v2
30+
- uses: actions/checkout@v3
3131
- uses: julia-actions/setup-julia@v1
3232
with:
3333
version: ${{ matrix.julia-version }}
3434
arch: x64
3535
- uses: julia-actions/julia-buildpkg@latest
3636
- name: Clone Downstream
37-
uses: actions/checkout@v2
37+
uses: actions/checkout@v3
3838
with:
3939
repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
4040
path: downstream

.github/workflows/TagBot.yml

+16
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,22 @@ on:
44
types:
55
- created
66
workflow_dispatch:
7+
inputs:
8+
lookback:
9+
default: 3
10+
permissions:
11+
actions: read
12+
checks: read
13+
contents: write
14+
deployments: read
15+
issues: read
16+
discussions: read
17+
packages: read
18+
pages: read
19+
pull-requests: read
20+
repository-projects: read
21+
security-events: read
22+
statuses: read
723
jobs:
824
TagBot:
925
if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot'

.github/workflows/ci.yml

+4-4
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ jobs:
3232
# version: '1'
3333
# arch: x64
3434
steps:
35-
- uses: actions/checkout@v2
35+
- uses: actions/checkout@v3
3636
- uses: julia-actions/setup-julia@v1
3737
with:
3838
version: ${{ matrix.version }}
3939
arch: ${{ matrix.arch }}
40-
- uses: actions/cache@v1
40+
- uses: actions/cache@v3
4141
env:
4242
cache-name: cache-artifacts
4343
with:
@@ -58,15 +58,15 @@ jobs:
5858
#continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures
5959
- uses: julia-actions/julia-processcoverage@v1
6060
if: matrix.version == '1' && matrix.os == 'ubuntu-latest'
61-
- uses: codecov/codecov-action@v2
61+
- uses: codecov/codecov-action@v3
6262
if: matrix.version == '1' && matrix.os == 'ubuntu-latest'
6363
with:
6464
file: lcov.info
6565
docs:
6666
name: Documentation
6767
runs-on: ubuntu-latest
6868
steps:
69-
- uses: actions/checkout@v2
69+
- uses: actions/checkout@v3
7070
- uses: julia-actions/setup-julia@v1
7171
with:
7272
version: '1'

.github/workflows/clean_preview.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ jobs:
1010
runs-on: ubuntu-latest
1111
steps:
1212
- name: Checkout gh-pages branch
13-
uses: actions/checkout@v2
13+
uses: actions/checkout@v3
1414
with:
1515
ref: gh-pages
1616
- name: Delete preview and history + push changes

Project.toml

+22-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Zygote"
22
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
3-
version = "0.6.55"
3+
version = "0.6.59"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -20,17 +20,30 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2020
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
2121
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2222
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
23+
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
2324
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2425
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2526
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2627
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2728

29+
[weakdeps]
30+
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
31+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
32+
Tracker= "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
33+
34+
[extensions]
35+
ZygoteColorsExt = "Colors"
36+
ZygoteDistancesExt = "Distances"
37+
ZygoteTrackerExt = "Tracker"
38+
2839
[compat]
29-
AbstractFFTs = "0.5, 1.0"
40+
AbstractFFTs = "1.3.1"
3041
ChainRules = "1.44.1"
3142
ChainRulesCore = "1.9"
3243
ChainRulesTestUtils = "1"
44+
Colors = "0.12"
3345
DiffRules = "1.4"
46+
Distances = "0.10"
3447
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
3548
ForwardDiff = "0.10"
3649
GPUArrays = "8.4.2"
@@ -40,17 +53,23 @@ LogExpFunctions = "0.3.1"
4053
MacroTools = "0.5"
4154
NaNMath = "0.3, 1"
4255
Requires = "1.1"
56+
SnoopPrecompile = "1.0.3"
4357
SpecialFunctions = "1.6, 2"
58+
Tracker = "0.2"
4459
ZygoteRules = "0.2.1"
4560
julia = "1.6"
4661

4762
[extras]
63+
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
64+
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
4865
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4966
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
5067
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
5168
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
5269
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
70+
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
5371
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
72+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
5473

5574
[targets]
56-
test = ["ChainRulesTestUtils", "CUDA", "Distances", "FFTW", "FiniteDifferences", "Test"]
75+
test = ["ChainRulesTestUtils", "Conda", "CUDA", "Distances", "FFTW", "FiniteDifferences", "PyCall", "Test"]

docs/src/utils.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ or a Hessian (by taking a second derivative).
66
```@docs
77
Zygote.jacobian
88
Zygote.hessian
9+
Zygote.hessian_reverse
910
Zygote.diaghessian
1011
```
1112

ext/ZygoteColorsExt.jl

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module ZygoteColorsExt
2+
3+
if isdefined(Base, :get_extension)
4+
using Zygote
5+
using Colors
6+
else
7+
using ..Zygote
8+
using ..Colors
9+
end
10+
11+
Zygote.@non_differentiable Colors.ColorTypes._parameter_upper_bound(::Any...)
12+
13+
end

src/lib/distances.jl renamed to ext/ZygoteDistancesExt.jl

+29-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,16 @@
1-
using .Distances
1+
module ZygoteDistancesExt
2+
3+
if isdefined(Base, :get_extension)
4+
using Zygote
5+
using Distances
6+
using LinearAlgebra
7+
else
8+
using ..Zygote
9+
using ..Distances
10+
using ..LinearAlgebra
11+
end
12+
13+
using Zygote: @adjoint, AContext, _pullback
214

315
@adjoint function (::SqEuclidean)(x::AbstractVector, y::AbstractVector)
416
δ = x .- y
@@ -66,22 +78,34 @@ end
6678

6779
_sqrt_if_positive(d, δ) = d > δ ? sqrt(d) : zero(d)
6880

69-
@adjoint function pairwise(dist::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
81+
function Zygote._pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
82+
kws::@NamedTuple{dims::Int}, ::typeof(pairwise), dist::Euclidean,
83+
X::AbstractMatrix, Y::AbstractMatrix)
7084
# Modify the forwards-pass slightly to ensure stability on the reverse.
85+
dims = kws.dims
7186
function _pairwise_euclidean(sqdist::SqEuclidean, X, Y)
7287
D2 = pairwise(sqdist, X, Y; dims=dims)
7388
δ = eps(eltype(D2))
7489
return _sqrt_if_positive.(D2, δ)
7590
end
76-
return pullback(_pairwise_euclidean, SqEuclidean(dist.thresh), X, Y)
91+
res, back = _pullback(cx, _pairwise_euclidean, SqEuclidean(dist.thresh), X, Y)
92+
pairwise_Euclidean_pullback(Δ) = (nothing, nothing, back(Zygote.unthunk_tangent(Δ))...)
93+
return res, pairwise_Euclidean_pullback
7794
end
7895

79-
@adjoint function pairwise(dist::Euclidean, X::AbstractMatrix; dims=2)
96+
function Zygote._pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
97+
kws::@NamedTuple{dims::Int}, ::typeof(pairwise), dist::Euclidean,
98+
X::AbstractMatrix)
8099
# Modify the forwards-pass slightly to ensure stability on the reverse.
100+
dims = kws.dims
81101
function _pairwise_euclidean(sqdist::SqEuclidean, X)
82102
D2 = pairwise(sqdist, X; dims=dims)
83103
δ = eps(eltype(D2))
84104
return _sqrt_if_positive.(D2, δ)
85105
end
86-
return pullback(_pairwise_euclidean, SqEuclidean(dist.thresh), X)
106+
res, back = _pullback(cx, _pairwise_euclidean, SqEuclidean(dist.thresh), X)
107+
pairwise_Euclidean_pullback(Δ) = (nothing, nothing, back(Zygote.unthunk_tangent(Δ))...)
108+
return res, pairwise_Euclidean_pullback
109+
end
110+
87111
end

ext/ZygoteTrackerExt.jl

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
module ZygoteTrackerExt
2+
3+
if isdefined(Base, :get_extension)
4+
using Zygote
5+
using Tracker: Tracker, TrackedArray, TrackedReal
6+
else
7+
using ..Zygote
8+
using ..Tracker: Tracker, TrackedArray, TrackedReal
9+
end
10+
11+
Zygote.unwrap(x::Union{TrackedArray,TrackedReal}) = Tracker.data(x)
12+
13+
Zygote.pullback(f, ps::Tracker.Params) = pullback(f, ZygtParams(ps))
14+
Tracker.forward(f, ps::Params) = Tracker.forward(f, Tracker.Params(ps))
15+
Tracker.gradient_(f, ps::Params) = Tracker.gradient_(f, Tracker.Params(ps))
16+
17+
end

src/Zygote.jl

+11-10
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using LinearAlgebra, Statistics
44
using LinearAlgebra: copytri!, AbstractTriangular
55

66
import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
7-
literal_getproperty, literal_getfield
7+
literal_getproperty, literal_getfield, unthunk_tangent
88

99
using ChainRulesCore
1010
using ChainRules: ChainRules, rrule, unthunk, canonicalize
@@ -43,7 +43,6 @@ include("lib/forward.jl")
4343
include("lib/utils.jl")
4444
include("lib/range.jl")
4545
include("lib/logexpfunctions.jl")
46-
@init @require Distances="b4f34e82-e78d-54a5-968a-f98e89d6e8f7" include("lib/distances.jl")
4746

4847
# we need to define this late, so that the genfuncs see lib.jl
4948
# Move using statements out of this file to help with sysimage building
@@ -53,12 +52,11 @@ include("compiler/interface2.jl")
5352

5453
include("profiler/Profile.jl")
5554

56-
@init @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
57-
include("flux.jl")
58-
end
5955

60-
@init @require Colors="5ae59095-9a9b-59fe-a467-6f913c188581" begin
61-
@non_differentiable Colors.ColorTypes._parameter_upper_bound(::Any...)
56+
if !isdefined(Base, :get_extension)
57+
@init @require Distances="b4f34e82-e78d-54a5-968a-f98e89d6e8f7" include("../ext/ZygoteDistancesExt.jl")
58+
@init @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("../ext/ZygoteTrackerExt.jl")
59+
@init @require Colors="5ae59095-9a9b-59fe-a467-6f913c188581" include("../ext/ZygoteColorsExt.jl")
6260
end
6361

6462
using InteractiveUtils
@@ -79,8 +77,11 @@ macro profile(ex)
7977
end
8078
end
8179

82-
## reverted due to https://github.com/SciML/DiffEqFlux.jl/issues/783
83-
# using SnoopPrecompile
84-
# @precompile_all_calls precompile()
80+
using SnoopPrecompile
81+
# This caused freezes on early 1.8 patch versions,
82+
# see https://github.com/SciML/DiffEqFlux.jl/issues/783
83+
@static if VERSION < v"1.8" || VERSION >= v"1.8.5"
84+
@precompile_all_calls precompile()
85+
end
8586

8687
end # module

src/compiler/chainrules.jl

+26-3
Original file line numberDiff line numberDiff line change
@@ -312,16 +312,39 @@ end
312312
# Right now it uses a NamedTuple but not for fields of the AbstractDict struct
313313
z2d(dx::NamedTuple, primal::AbstractDict) = dx
314314

315-
function z2d(delta::NamedTuple, primal::T) where T # arbitrart struct
315+
function _z2d_struct_fallback(delta::NamedTuple, primal::T) where T
316316
fnames = fieldnames(T)
317317
deltas = map(n -> get(delta, n, nothing), fnames)
318318
primals = map(n -> getfield(primal, n), fnames)
319319
inner = map(z2d, deltas, primals) # recurse into fields
320-
if inner isa Tuple{Vararg{AbstractZero}}
320+
if inner isa Tuple{Vararg{AbstractZero}}
321321
return NoTangent() # collapse all-zero case
322322
else
323323
backing = NamedTuple{fnames}(inner)
324-
return canonicalize(Tangent{T, typeof(backing)}(backing))
324+
return Tangent{T, typeof(backing)}(backing)
325+
end
326+
end
327+
328+
function z2d(delta::NamedTuple, primal::T) where T # arbitrart struct
329+
if @generated
330+
fnames = fieldnames(T)
331+
N = length(fnames)
332+
deltas = [ :($(Symbol(:delta_, fname)) = get(delta, $(QuoteNode(fname)), nothing)) for fname in fnames ]
333+
primals = [ :($(Symbol(:primal_, fname)) = getfield(primal, $(QuoteNode(fname)))) for fname in fnames ]
334+
inner = Expr(:tuple, [ :(z2d($(Symbol(:delta_, fname)), $(Symbol(:primal_, fname)))) for fname in fnames ]...)
335+
return quote
336+
$(deltas...)
337+
$(primals...)
338+
inner = $inner
339+
if inner isa Tuple{Vararg{AbstractZero}}
340+
return NoTangent() # collapse all-zero case
341+
else
342+
backing = NamedTuple{$fnames}(inner)
343+
return Tangent{T, typeof(backing)}(backing)
344+
end
345+
end
346+
else
347+
return _z2d_struct_fallback(delta, primal)
325348
end
326349
end
327350

0 commit comments

Comments
 (0)