diff --git a/.gitignore b/.gitignore
index c289756be2..dd9aa9ad8e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,7 +6,7 @@ docs/build/
docs/site/
deps
.vscode
-Manifest.toml
+Manifest*.toml
LocalPreferences.toml
.DS_Store
docs/mymodel.bson
diff --git a/NEWS.md b/NEWS.md
index 4e2a3316c1..c9056742df 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -2,11 +2,71 @@
See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.
-## v0.16.1 (25 December 2025)
+## v0.16.8 (January 2026)
-The default `init_score` value for `early_stopping` has been set to `Inf` (instead of `0`) in order to prevent unexpected behavior if the defaults were not modified. Documentation has been updated to explain that, if the user needs to track a metric where improvement is shown by increasing values, then the `init_score` needs to be adjusted, for example, to `-Inf`). Tests corresponding to `early_stopping` have been reorganized and extended to be more detailed and illustrative of `early_stopping`'s behavior.
+This release includes the following changes:
+- Added support in `Flux.gradient` and `Flux.withgradient` for alternative AD backends such as `AutoEnzyme()` and `AutoMooncake()`.
+- The default `init_score` value for `early_stopping` has been set to `Inf` (instead of `0`) in order to prevent unexpected behavior if the defaults were not modified.
-## v0.16.0 (15 December 2025)
+## v0.16.7 (10 December 2025)
+
+This patch release includes:
+
+* Minor documentation fixes and housekeeping commits.
+* Compatibility updates for downstream packages.
+
+
+## v0.16.6 (8 December 2025)
+
+This patch release includes:
+
+* Minor dependency bumps and CI updates.
+* Preparatory changes ahead of v0.16.7.
+
+## v0.16.5 (23 July 2025)
+
+This release includes:
+
+* Fix typos in legacy tutorials documentation.
+* Bump compatibility for `AMDGPU` in weak dependencies.
+* **Fix** for `unsafe_free!` failure with certain `CuArray` configurations.
+
+## v0.16.4 (2 June 2025)
+
+This release includes:
+
+* Fix missing imports in `FluxMPIExt`.
+* Add shape validation for convolution weight tensors.
+* Disable and fix intermittent Reactant tests.
+* Fix recurrent docstrings and pooling layer loading.
+* Small test updates and miscellaneous doc fixes.
+
+## v0.16.3 (6 February 2025)
+
+This release includes:
+
+* **Fix** for `cpu(dataloader)` behavior.
+* Addressed data loading and preprocessing pipeline issues.
+* Resolved “Infinite time of gradient” edge case.
+
+## v0.16.2 (21 January 2025)
+
+This release includes:
+
+* Updated dependencies and bumped to v0.16.1 as a base.
+* **Fixes** around new gradients, precompilation on Julia 1.12, and export issues.
+
+## v0.16.1 (13 January 2025)
+
+This release includes:
+
+* Added references to recurrent layers in `ecosystem.md`.
+* Fixed typo in recurrence documentation.
+* Added “return state” option to recurrent layers.
+* Updated schedulers docs, collapsed docstrings in layers docs.
+* Test fixes for Enzyme and Reactant forward/reverse passes.
+
+## v0.16.0 (15 December 2024)
This release has a single **breaking change**:
- The recurrent cells `RNNCell`, `LSTMCell`, and `GRUCell` forward has been changed to
diff --git a/Project.toml b/Project.toml
index 23fe826646..0445c4d43f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -2,10 +2,8 @@ name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.16.7"
-[workspace]
-projects = ["test", "docs"]
-
[deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -33,7 +31,9 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
@@ -42,10 +42,13 @@ FluxAMDGPUExt = "AMDGPU"
FluxCUDAExt = "CUDA"
FluxCUDAcuDNNExt = ["CUDA", "cuDNN"]
FluxEnzymeExt = "Enzyme"
+FluxFiniteDifferencesExt = "FiniteDifferences"
FluxMPIExt = "MPI"
FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
+FluxMooncakeExt = "Mooncake"
[compat]
+ADTypes = "1"
AMDGPU = "1, 2"
Adapt = "4"
CUDA = "5"
@@ -53,12 +56,14 @@ ChainRulesCore = "1.12"
Compat = "4.10.0"
Enzyme = "0.13"
EnzymeCore = "0.7.7, 0.8.4"
+FiniteDifferences = "0.12"
Functors = "0.5"
MLCore = "1.0.0"
MLDataDevices = "1.4.2"
MLUtils = "0.4"
MPI = "0.20.19"
MacroTools = "0.5"
+Mooncake = "0.4"
NCCL = "0.1.1"
NNlib = "0.9.22"
OneHotArrays = "0.2.4"
@@ -72,3 +77,6 @@ Statistics = "1"
Zygote = "0.6.67, 0.7"
cuDNN = "1"
julia = "1.10"
+
+[workspace]
+projects = ["test", "docs"]
diff --git a/docs/Project.toml b/docs/Project.toml
index 0acc886a07..e9675ae73e 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -10,6 +10,7 @@ MLCore = "c2834f40-e789-41da-a90e-33b280584a8c"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
@@ -17,5 +18,8 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+[sources]
+Flux = {path = ".."}
+
[compat]
Documenter = "1.3"
diff --git a/docs/make.jl b/docs/make.jl
index 75f30494eb..f8d7d6cfb0 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -1,6 +1,7 @@
-using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers,
+using Documenter
+using Flux, NNlib, Functors, MLUtils, BSON, Optimisers,
OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics,
- DataFrames, JLD2, MLDataDevices, MLCore
+ DataFrames, JLD2, MLDataDevices, MLCore, Mooncake
using MLCore: numobs, getobs, getobs!
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
@@ -21,7 +22,8 @@ makedocs(
sidebar_sitename = false,
analytics = "UA-36890222-9",
assets = ["assets/flux.css"],
- prettyurls = get(ENV, "CI", nothing) == "true"
+ prettyurls = get(ENV, "CI", nothing) == "true",
+ size_threshold=1_000_000,
),
pages = [
"Welcome" => "index.md",
@@ -50,8 +52,7 @@ makedocs(
"Shape Inference" => "reference/outputsize.md",
"Flat vs. Nested" => "reference/destructure.md",
"Callback Helpers" => "reference/training/callbacks.md",
- "Gradients -- Zygote.jl" => "reference/training/zygote.md",
- "Gradients -- Enzyme.jl" => "reference/training/enzyme.md",
+ "Gradients" => "reference/training/gradients.md",
"Transfer Data to GPU -- MLDataDevices.jl" => "reference/data/mldatadevices.md",
"Batching Data -- MLUtils.jl" => "reference/data/mlutils.md",
"OneHotArrays.jl" => "reference/data/onehot.md",
diff --git a/docs/src/guide/models/basics.md b/docs/src/guide/models/basics.md
index 27a958074c..9750c034d9 100644
--- a/docs/src/guide/models/basics.md
+++ b/docs/src/guide/models/basics.md
@@ -181,9 +181,7 @@ These matching nested structures are at the core of how Flux works.
This method of `gradient` takes a zero-argument function, which only *implicitly*
depends on `θ`.
-```@raw html
-
-```
+## Automatic Differentiation
Flux's [`gradient`](@ref Flux.gradient) function by default calls a companion package called [Zygote](https://github.com/FluxML/Zygote.jl).
Zygote performs source-to-source automatic differentiation, meaning that `gradient(f, x)`
@@ -198,7 +196,7 @@ Flux can also be used with other automatic differentiation (AD) packages.
It was originally written using [Tracker](https://github.com/FluxML/Tracker.jl), a more traditional operator-overloading approach.
The future might be [Enzyme](https://github.com/EnzymeAD/Enzyme.jl), and Flux now builds in an easy way to use this instead, turned on by wrapping the model in `Duplicated`. (For details, see the [Enzyme page](@ref autodiff-enzyme) in the manual.)
-```julia
+```julia-repl
julia> using Enzyme: Const, Duplicated
julia> grad3e = Flux.gradient((x,p) -> p(x), Const(5.0), Duplicated(poly3s))
@@ -210,13 +208,33 @@ Here, this is because `Const(5.0)` is explicitly constant.
Below, we will see an example where `nothing` shows up because the model struct has fields containing things other than parameters, such as an activation function.
(It also adopts the convention that `gradient(f, x, y)` returns a tuple `(∂f/∂x, ∂f/∂y)`, without a "`∂f/∂f`" term for the function. This is why we had to write `gradient(|>, 5, poly4)` above, not just `gradient(poly4, 5)`.)
-Finally, the function [`withgradient`](@ref) works the same way, but also returns the value of the function:
+The function [`withgradient`](@ref) works the same way, but also returns the value of the function:
```jldoctest poly
julia> Flux.withgradient((x,p) -> p(x), 5.0, poly3s)
(val = 17.5, grad = (2.0, (θ3 = [1.0, 5.0, 25.0],)))
```
+One can also specify which AD backend to use directly, by passing one of the supported adtypes
+(`AutoMooncake`, `AutoEnzyme`, `AutoZygote`, `AutoFiniteDifferences`) as the second argument.
+The corresponding AD package has to be loaded first.
+
+Here is an example using [Mooncake](https://github.com/chalk-lab/Mooncake.jl):
+```jldoctest poly
+julia> using Mooncake
+
+julia> Flux.withgradient((x,p) -> p(x), AutoMooncake(), 5.0, poly3s)
+(val = 17.5, grad = (2.0, Poly3{Vector{Float64}}([1.0, 5.0, 25.0])))
+```
+
+and here is the same example using Enzyme:
+```julia-repl
+julia> using Enzyme
+
+julia> Flux.withgradient((x,p) -> p(x), AutoEnzyme(), 5.0, poly3s)
+(val = 17.5, grad = (2.0, Poly3{Vector{Float64}}([1.0, 5.0, 25.0])))
+```
+
## Simple Neural Networks
The polynomial functions above send a number `x` to another number `y`.
diff --git a/docs/src/reference/data/mlutils.md b/docs/src/reference/data/mlutils.md
index 3d4c56f0d0..12dbd1c00f 100644
--- a/docs/src/reference/data/mlutils.md
+++ b/docs/src/reference/data/mlutils.md
@@ -25,6 +25,7 @@ these functions help create inputs for your models or batch your dataset.
MLUtils.batch
MLUtils.batchsize
MLUtils.batchseq
+MLUtils.batch_sequence
MLUtils.BatchView
MLUtils.chunk
MLUtils.eachobs
diff --git a/docs/src/reference/training/enzyme.md b/docs/src/reference/training/gradients.md
similarity index 66%
rename from docs/src/reference/training/enzyme.md
rename to docs/src/reference/training/gradients.md
index 77875c9da3..bdda5535d9 100644
--- a/docs/src/reference/training/enzyme.md
+++ b/docs/src/reference/training/gradients.md
@@ -1,5 +1,64 @@
+```@meta
+CollapsedDocStrings = true
+```
+
+# Automatic Differentiation in Flux
+
+Flux's `gradient` function uses [Zygote](https://github.com/FluxML/Zygote.jl) by default, and also uses this function within [`train!`](@ref Flux.train!) to differentiate the model.
+Zygote has its own [documentation](https://fluxml.ai/Zygote.jl/dev/), in particular listing some [important limitations](https://fluxml.ai/Zygote.jl/dev/limitations/).
-# [Automatic Differentiation using Enzyme.jl](@id autodiff-enzyme)
+Flux also supports Enzyme.jl, documented [below](@ref autodiff-enzyme), as well as Mooncake.jl.
+
+
+## Generic Gradient Interface
+
+```@docs
+Flux.gradient(f, adtype::AbstractADType, args::Any...)
+Flux.withgradient(f, adtype::AbstractADType, args::Any...)
+```
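+
+For instance (a minimal sketch, mirroring the docstring examples above; any backend other than Zygote has to be loaded first):
+
+```julia-repl
+julia> using Flux, Mooncake
+
+julia> f(w, x) = sum(w .* x);
+
+julia> Flux.gradient(f, AutoMooncake(), [1.0, 2.0], [3.0, 4.0])
+([3.0, 4.0], [1.0, 2.0])
+```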
+
+## [Automatic Differentiation using Zygote.jl](@id autodiff-zygote)
+
+The default AD backend in Flux is Zygote. Besides gradient calculation, Zygote also supports
+higher-order derivatives, Jacobians, Hessians, and pullbacks.
+
+```@docs
+Zygote.jacobian(f, args...)
+Zygote.withjacobian(f, args...)
+Zygote.hessian
+Zygote.hessian_reverse
+Zygote.diaghessian
+Zygote.pullback
+```
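+
+For example (small self-contained calls, not tied to any model):
+
+```julia-repl
+julia> using Zygote
+
+julia> Zygote.jacobian(x -> x.^2, [1.0, 2.0, 3.0])[1]
+3×3 Matrix{Float64}:
+ 2.0  0.0  0.0
+ 0.0  4.0  0.0
+ 0.0  0.0  6.0
+
+julia> Zygote.hessian(x -> sum(x.^3), [1.0, 2.0])
+2×2 Matrix{Float64}:
+ 6.0   0.0
+ 0.0  12.0
+```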
+
+## ChainRules for Zygote
+
+Zygote uses [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) to define how to differentiate functions.
+
+Sometimes it is necessary to exclude some code, or a whole function, from automatic differentiation.
+This can be done using the following methods:
+
+```@docs
+ChainRulesCore.ignore_derivatives
+ChainRulesCore.@non_differentiable
+```
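+
+For example (an illustrative sketch; the skipped code here is just a logging call):
+
+```julia
+using Flux, ChainRulesCore
+
+function loss(x)
+    ChainRulesCore.ignore_derivatives() do
+        @info "evaluating loss" n = length(x)   # runs in the forward pass, but is not differentiated
+    end
+    return sum(abs2, x)
+end
+
+Flux.gradient(loss, [1.0, 2.0, 3.0])   # ([2.0, 4.0, 6.0],)
+```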
+
+To manually supply the gradient for one function, you should define a method of `rrule`. ChainRules has [detailed documentation](https://juliadiff.org/ChainRulesCore.jl/stable/) on how this works.
+
+```@docs
+ChainRulesCore.rrule
+ChainRulesCore.frule
+ChainRulesCore.@scalar_rule
+ChainRulesCore.NoTangent
+ChainRulesCore.ZeroTangent
+ChainRulesCore.RuleConfig
+ChainRulesCore.Tangent
+ChainRulesCore.canonicalize
+```
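+
+For instance (a minimal sketch of the mechanism, not a rule defined by Flux itself):
+
+```julia
+using ChainRulesCore
+
+myrelu(x::Real) = max(x, zero(x))
+
+function ChainRulesCore.rrule(::typeof(myrelu), x::Real)
+    y = myrelu(x)
+    myrelu_pullback(ȳ) = (NoTangent(), ȳ * (x > 0))  # no tangent for the function itself; ∂y/∂x is (x > 0)
+    return y, myrelu_pullback
+end
+```
+
+Zygote will then use this rule whenever `myrelu` appears inside a differentiated function.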
+
+Gradient customization for other AD packages such as Enzyme and Mooncake has to be done according to their own documentation.
+
+## [Automatic Differentiation using Enzyme.jl](@id autodiff-enzyme)
[Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) is a new package for automatic differentiation.
Like Zygote.jl, calling `gradient(f, x)` causes it to hook into the compiler and transform code that is executed while calculating `f(x)`, in order to produce code for `∂f/∂x`.
@@ -71,25 +130,16 @@ true
Note that what `Enzyme.gradient` returns is an object like `deepcopy(model)` of the same type, `grads_e[1] isa Chain`.
But its fields contain the same gradient.
-There is also a method of `train!` which similarly takes `Duplicated(model)`:
-
-```julia-repl
-julia> opt_state = Flux.setup(Adam(0), model);
-
-julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_state)
-```
-
-## Second-order AD
-
-If you calculate a gradient within the loss function, then training will involve 2nd derivatives.
-While this is in principle supported by Zygote.jl, there are many bugs, and Enzyme.jl is probably a better choice.
-
-## Listing
```@docs
Flux.gradient(f, args::Union{Flux.EnzymeCore.Const, Flux.EnzymeCore.Duplicated}...)
Flux.withgradient(f, args::Union{Flux.EnzymeCore.Const, Flux.EnzymeCore.Duplicated}...)
-Flux.train!(loss, model::Flux.EnzymeCore.Duplicated, data, opt)
```
Enzyme.jl has [its own extensive documentation](https://enzymead.github.io/Enzyme.jl/stable/).
+
+
+## Second-order AD
+
+If you calculate a gradient within the loss function, then training will involve 2nd derivatives.
+While this is in principle supported by Zygote.jl, there are many bugs, and Enzyme.jl is probably a better choice.
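+
+For example (a minimal sketch of the pattern only, with no claim about which backend combinations succeed), a gradient-penalty loss contains an inner `gradient` call, so training it requires a second derivative:
+
+```julia
+using Flux
+
+model = Dense(3 => 2)
+x = randn(Float32, 3, 4)
+
+# The loss value depends on an inner gradient with respect to the input ...
+penalty(m, x) = sum(abs2, Flux.gradient(x̂ -> sum(abs2, m(x̂)), x)[1])
+
+# ... so this outer gradient with respect to the model is second-order.
+grads = Flux.gradient(penalty, model, x)
+```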
diff --git a/docs/src/reference/training/reference.md b/docs/src/reference/training/reference.md
index a9ee35e4c0..6c825709e5 100644
--- a/docs/src/reference/training/reference.md
+++ b/docs/src/reference/training/reference.md
@@ -28,6 +28,19 @@ Optimisers.setup
To see one in a terminal, you will need to install [TerminalLoggers.jl](https://github.com/JuliaLogging/TerminalLoggers.jl)
and follow its setup instructions.
+
+There is also a method of `train!` which takes `Duplicated(model)` and uses Enzyme.jl for differentiation (see [Automatic Differentiation using Enzyme.jl](@ref autodiff-enzyme)):
+```julia-repl
+julia> opt_state = Flux.setup(Adam(0), model);
+
+julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_state)
+```
+
+```@docs
+Flux.train!(loss, model::Flux.EnzymeCore.Duplicated, data, opt)
+```
+
+
## Optimisation Modifiers
The state returned by `setup` can be modified to temporarily prevent training of
diff --git a/docs/src/reference/training/zygote.md b/docs/src/reference/training/zygote.md
deleted file mode 100644
index ddf65917a1..0000000000
--- a/docs/src/reference/training/zygote.md
+++ /dev/null
@@ -1,48 +0,0 @@
-```@meta
-CollapsedDocStrings = true
-```
-
-# [Automatic Differentiation using Zygote.jl](@id autodiff-zygote)
-
-Flux's `gradient` function uses [Zygote](https://github.com/FluxML/Zygote.jl) by default, and also uses this function within [`train!`](@ref Flux.train!) to differentiate the model.
-Zygote has its own [documentation](https://fluxml.ai/Zygote.jl/dev/), in particular listing some [important limitations](https://fluxml.ai/Zygote.jl/dev/limitations/).
-
-Flux also has support for Enzyme.jl, documented [on its own page](@ref autodiff-enzyme).
-
-## Explicit style
-
-The preferred way of using Zygote, and the only way of using most other AD packages,
-is to explicitly provide a function and its arguments.
-
-```@docs
-Zygote.gradient(f, args...)
-Zygote.withgradient(f, args...)
-Zygote.jacobian(f, args...)
-Zygote.withjacobian(f, args...)
-Zygote.hessian
-Zygote.hessian_reverse
-Zygote.diaghessian
-Zygote.pullback
-```
-
-## ChainRules
-
-Sometimes it is necessary to exclude some code, or a whole function, from automatic differentiation. This can be done using [ChainRules](https://github.com/JuliaDiff/ChainRules.jl):
-
-```@docs
-ChainRulesCore.ignore_derivatives
-ChainRulesCore.@non_differentiable
-```
-
-To manually supply the gradient for one function, you should define a method of `rrule`. ChainRules has [detailed documentation](https://juliadiff.org/ChainRulesCore.jl/stable/) on how this works.
-
-```@docs
-ChainRulesCore.rrule
-ChainRulesCore.frule
-ChainRulesCore.@scalar_rule
-ChainRulesCore.NoTangent
-ChainRulesCore.ZeroTangent
-ChainRulesCore.RuleConfig
-ChainRulesCore.Tangent
-ChainRulesCore.canonicalize
-```
diff --git a/ext/FluxEnzymeExt.jl b/ext/FluxEnzymeExt.jl
new file mode 100644
index 0000000000..0969c854d8
--- /dev/null
+++ b/ext/FluxEnzymeExt.jl
@@ -0,0 +1,115 @@
+module FluxEnzymeExt
+
+using Flux
+import Flux.Train: _enzyme_train!
+
+import Optimisers
+import Functors
+import Enzyme
+using Enzyme: EnzymeCore, EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal, DuplicatedNoNeed
+using Enzyme: autodiff_thunk, Reverse, ReverseSplitWithPrimal
+using ProgressLogging: @withprogress, @logprogress
+
+EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true
+
+### gradient & withgradient
+function Flux.gradient(f::F, adtype::AutoEnzyme, x::Vararg{Any,N}; zero::Bool=true) where {F,N}
+ return _enzyme_gradient(f, map(_trymake_duplicated, x)...; zero)
+end
+
+function Flux.withgradient(f::F, adtype::AutoEnzyme, x::Vararg{Any,N}; zero::Bool=true) where {F,N}
+ return _enzyme_withgradient(f, map(_trymake_duplicated, x)...; zero)
+end
+
+_trymake_duplicated(x::EnzymeCore.Duplicated) = x
+_trymake_duplicated(x::EnzymeCore.Const) = x
+_trymake_duplicated(x::EnzymeCore.Active) = throw(ArgumentError("Enzyme's `Active` type not supported in `Flux.gradient` or `Flux.withgradient`."))
+_trymake_duplicated(x) = EnzymeCore.Duplicated(x, EnzymeCore.make_zero(x))
+
+
+function _enzyme_gradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
+ for x in args
+ zero && x isa Duplicated && EnzymeCore.remake_zero!(x.dval)
+ _check_mutable(x)
+ end
+ ad = Enzyme.set_runtime_activity(Reverse)
+ Enzyme.autodiff(ad, Const(f), Active, args...)
+ return map(_grad_or_nothing, args)
+end
+
+_check_mutable(x::Const) = nothing
+_check_mutable(x::Duplicated) = Functors.anymutable(x) || error(
+ """`Flux.gradient(f, Duplicated(x), ...)` expects `x` to contain mutable parameter arrays."""
+)
+
+# This function strips the returned gradient to be Zygote-like:
+_grad_or_nothing(dup::Duplicated) = Flux.fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
+_grad_or_nothing(::Const) = nothing
+_grad_or_nothing(x) = Optimisers.isnumeric(x) ? x : nothing
+
+function _enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
+ for x in args
+ zero && x isa Duplicated && EnzymeCore.remake_zero!(x.dval)
+ _check_mutable(x)
+ end
+
+    # In order to support auxiliary outputs, we try different ways.
+
+ ## Take I, doesn't allow for aux at all.
+ ad = Enzyme.set_runtime_activity(ReverseWithPrimal)
+ _, result = Enzyme.autodiff(ReverseWithPrimal, Const(f), Active, args...)
+
+ ## Take II, using split mode.
+ ## This fails with RNNs https://github.com/EnzymeAD/Enzyme.jl/issues/2897
+ # forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
+ # tape, result, shadow_result = forward(Const(f), args...)
+ # reverse(Const(f), args..., _sensitivity(result), tape)
+
+ ## Take III, it may be more efficient to have the function write the loss into Ref(0.0)?
+ ## This doesn't work with Reactant
+ # dup_loss = DuplicatedNoNeed(Ref(0f0), Ref(1f0))
+ # ad = Enzyme.set_runtime_activity(ReverseWithPrimal)
+ # _, result = autodiff(ad, Const(_ref_loss!), Const, dup_loss, Const(f), args...)
+
+ return (; val = result, grad = map(_grad_or_nothing, args))
+end
+
+## for Take II above
+# @inline _sensitivity(y::Real) = one(y)
+# @inline _sensitivity(ys::Tuple{Real,Vararg}) = (one(ys[1]), Enzyme.make_zero(Base.tail(ys))...)
+# @inline _sensitivity(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = NamedTuple{S}(_sensitivity(Tuple(ys)))
+# _sensitivity(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
+# or else a Tuple or NamedTuple whose first element is a real number.""")
+
+# for Take III above
+# function _ref_loss!(out::Ref, f, args...)
+# val = f(args...)
+# out[] = _get_loss(val) # saves loss by mutation
+# val # returns the whole thing
+# end
+# @inline _get_loss(y::Number) = y
+# @inline _get_loss(ys::Tuple{Number,Vararg}) = ys[1]
+# @inline _get_loss(ys::NamedTuple{S, <:Tuple{Number,Vararg}}) where S = ys[1]
+# _get_loss(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
+# or else a Tuple or NamedTuple whose first element is a real number.""")
+
+
+### Flux.Train, for train!
+
+function _enzyme_train!(loss, model::Duplicated, data, opt; cb = nothing)
+ isnothing(cb) || error("""train! does not support callback functions.
+ For more control use a loop with `gradient` and `update!`.""")
+ @withprogress for (i,d) in enumerate(data)
+ d_splat = d isa Tuple ? d : (d,)
+ l, gs = Flux.withgradient(loss, AutoEnzyme(), model, map(Const, d_splat)...)
+ if !isfinite(l)
+ throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
+ end
+ opt, model2 = Optimisers.update!(opt, model.val, model.dval)
+ model = Duplicated(model2, model.dval)
+
+ @logprogress Base.haslength(data) ? i/length(data) : nothing
+ end
+end
+
+end # FluxEnzymeExt
diff --git a/ext/FluxEnzymeExt/FluxEnzymeExt.jl b/ext/FluxEnzymeExt/FluxEnzymeExt.jl
deleted file mode 100644
index b124a4a973..0000000000
--- a/ext/FluxEnzymeExt/FluxEnzymeExt.jl
+++ /dev/null
@@ -1,129 +0,0 @@
-module FluxEnzymeExt
-
-using Flux
-import Flux.Train: _enzyme_train!
-
-import Optimisers
-import Functors
-import Enzyme
-using Enzyme: EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal, DuplicatedNoNeed
-using Enzyme: autodiff_thunk, Reverse, ReverseSplitWithPrimal
-using ProgressLogging: @withprogress, @logprogress
-
-EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true
-
-### gradient & withgradient
-
-# We can't use Enzyme.make_zero! to reset Duplicated, as it complains about e.g. LayerNorm having immutable differentiable fields
-# After https://github.com/EnzymeAD/Enzyme.jl/pull/1961 probably this can be `make_zero!(Ref(dup.dval))`
-_make_zero!(model) = Functors.fmapstructure(_make_zero_inner!, model)
-function _make_zero_inner!(x::AbstractArray{<:Number})
- Optimisers.isnumeric(x) || return
- Optimisers.maywrite(x) || error("can't handle this")
- fill!(x, zero(eltype(x)))
- nothing
-end
-_make_zero_inner!(x) = nothing # any other Functors leaf type
-
-#= # This _make_zero! matches what Flux allows elsewhere:
-julia> Flux.setup(Adam(), (1:3.)')
-ERROR: model must be fully mutable for `train!` to work, got `x::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}`.
-If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}) = true`
-=#
-# Perhaps canonical way for Enzyme is more like this:
-# function _make_zero!(x::AbstractArray{<:Number})
-# if Enzyme.guess_activity(typeof(x), Reverse) <: Duplicated
-# fill!(x, zero(eltype(x)))
-# elseif Enzyme.guess_activity(typeof(x), Reverse) <: Const
-# # that's OK
-# else
-# error("not sure what it should do for Active?")
-# end
-# end
-
-function Flux._enzyme_gradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
- for x in args
- zero && x isa Duplicated && _make_zero!(x.dval)
- _check_mutable(x)
- end
- ad = Enzyme.set_runtime_activity(Reverse)
- Enzyme.autodiff(ad, Const(f), Active, args...)
- map(_grad_or_nothing, args)
-end
-
-_check_mutable(x::Const) = nothing
-_check_mutable(x::Duplicated) = Functors.anymutable(x) || error(
- """`Flux.gradient(f, Duplicatged(x), ...)` expects `x` to contain mutable parameter arrays."""
-)
-
-# This function strips the returned gradient to be Zygote-like:
-_grad_or_nothing(dup::Duplicated) = Flux.fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
-_grad_or_nothing(::Const) = nothing
-_grad_or_nothing(x) = Optimisers.isnumeric(x) ? x : nothing
-
-function Flux._enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
- for x in args
- zero && x isa Duplicated && _make_zero!(x.dval)
- _check_mutable(x)
- end
-
- # Take I, doesn't allow for aux at all.
- # _, val = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
-
- # Take II, using split mode.
- forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
- tape, result, shadow_result = forward(Const(f), args...)
- reverse(Const(f), args..., _sensitivity(result), tape)
-
- # Take III, it may be more efficient to have the function write the loss into Ref(0.0)?
- # dup_loss = DuplicatedNoNeed(Ref(0f0), Ref(1f0))
- # # result = autodiff(Reverse, Const(_ref_loss!), Const, dup_loss, Const(f), args...)
- # _, result = autodiff(ReverseWithPrimal, Const(_ref_loss!), Const, dup_loss, Const(f), args...)
-
- (; val = result, grad = map(_grad_or_nothing, args))
-end
-
-@inline _sensitivity(y::Real) = one(y)
-@inline _sensitivity(ys::Tuple{Real,Vararg}) = (one(ys[1]), Enzyme.make_zero(Base.tail(ys))...)
-@inline _sensitivity(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = NamedTuple{S}(_sensitivity(Tuple(ys)))
-_sensitivity(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real numnber,
- or else a Tuple or NamedTuple whose first element is a real number.""")
-
-function _ref_loss!(out::Ref, f, args...) # for Take III above
- val = f(args...)
- out[] = _get_loss(val) # saves loss by mutation
- val # returns the whole thing
-end
-
-@inline _get_loss(y::Real) = y
-@inline _get_loss(ys::Tuple{Real,Vararg}) = ys[1]
-@inline _get_loss(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = ys[1]
-_get_loss(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real numnber,
- or else a Tuple or NamedTuple whose first element is a real number.""")
-
-### Flux.Train, for train!
-
-_applyloss(loss, model, d...) = loss(model, d...)
-
-function _enzyme_train!(loss, model::Duplicated, data, opt; cb = nothing)
- isnothing(cb) || error("""train! does not support callback functions.
- For more control use a loop with `gradient` and `update!`.""")
- @withprogress for (i,d) in enumerate(data)
- d_splat = d isa Tuple ? d : (d,)
-
- _make_zero!(model.dval)
- ad = Enzyme.set_runtime_activity(ReverseWithPrimal)
- _, l = Enzyme.autodiff(ad, _applyloss,
- Active, Const(loss), model, map(Const, d_splat)...)
-
- if !isfinite(l)
- throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
- end
- opt, model2 = Optimisers.update!(opt, model.val, model.dval)
- model = Duplicated(model2, model.dval)
-
- @logprogress Base.haslength(data) ? i/length(data) : nothing
- end
-end
-
-end # FluxEnzymeExt
diff --git a/ext/FluxFiniteDifferencesExt.jl b/ext/FluxFiniteDifferencesExt.jl
new file mode 100644
index 0000000000..836f401201
--- /dev/null
+++ b/ext/FluxFiniteDifferencesExt.jl
@@ -0,0 +1,33 @@
+module FluxFiniteDifferencesExt
+
+using Flux
+using ADTypes: AutoFiniteDifferences
+using FiniteDifferences
+
+function Flux.gradient(f::F, adtype::AutoFiniteDifferences, x) where F
+    # Flatten the (possibly nested) argument to a parameter vector for finite differencing.
+    ps, re = Flux.destructure(x)
+    gs = FiniteDifferences.grad(adtype.fdm, p -> f(re(p)), ps)[1]
+    return (re(gs),)
+end
+
+function Flux.gradient(f::F, adtype::AutoFiniteDifferences, x::Vararg{Any,N}) where {F, N}
+ ps, re = Flux.destructure(x)
+ gs = FiniteDifferences.grad(adtype.fdm, p -> f(re(p)...), ps)[1]
+ return re(gs)
+end
+
+function Flux.withgradient(f::F, adtype::AutoFiniteDifferences, x) where F
+    ps, re = Flux.destructure(x)
+    y = f(re(ps))
+    gs = FiniteDifferences.grad(adtype.fdm, p -> f(re(p)), ps)[1]
+    return y, (re(gs),)
+end
+
+function Flux.withgradient(f::F, adtype::AutoFiniteDifferences, x::Vararg{Any,N}) where {F, N}
+ ps, re = Flux.destructure(x)
+ y = f(re(ps)...)
+ gs = FiniteDifferences.grad(adtype.fdm, p -> f(re(p)...), ps)[1]
+ return y, re(gs)
+end
+
+end # module
diff --git a/ext/FluxMooncakeExt.jl b/ext/FluxMooncakeExt.jl
new file mode 100644
index 0000000000..d371bfb952
--- /dev/null
+++ b/ext/FluxMooncakeExt.jl
@@ -0,0 +1,17 @@
+module FluxMooncakeExt
+
+using ADTypes: AutoMooncake
+using Mooncake: Mooncake
+import Flux
+
+function Flux.gradient(f::F, adtype::AutoMooncake, args::Vararg{Any,N}) where {F,N}
+ return Flux.withgradient(f, adtype, args...)[2]
+end
+
+function Flux.withgradient(f::F, adtype::AutoMooncake, args::Vararg{Any,N}) where {F,N}
+ cache = Mooncake.prepare_gradient_cache(f, args...; friendly_tangents=true)
+ val, grads = Mooncake.value_and_gradient!!(cache, f, args...)
+    # `grads` also contains the tangent of `f` itself in the first slot; drop it to match Flux's convention.
+    return (val=val, grad=grads[2:end])
+end
+
+end # module
diff --git a/src/Flux.jl b/src/Flux.jl
index 1fb4f0ef6a..657d9b9a1c 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -25,6 +25,8 @@ using Zygote: @adjoint, pullback
using Zygote.ForwardDiff: value
using EnzymeCore: EnzymeCore
+@reexport using ADTypes # AutoZygote, AutoMooncake, etc...
+
@reexport using MLDataDevices: MLDataDevices, supported_gpu_backends, reset_gpu_device!,
default_device_rng,
gpu_device, cpu_device, xla_device,
@@ -64,8 +66,6 @@ export Chain, Dense, Embedding, EmbeddingBag,
freeze!, thaw!, adjust!, update!, trainable,
# from Zygote.jl
hessian, diaghessian, jacobian, withjacobian, pullback,
- # AD functions
- withgradient,
# init
glorot_uniform,
glorot_normal,
@@ -99,7 +99,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
))
include("gradient.jl")
-export gradient
+export gradient, withgradient
include("train.jl")
using .Train
diff --git a/src/gradient.jl b/src/gradient.jl
index aa88828545..8ee4d1eebf 100644
--- a/src/gradient.jl
+++ b/src/gradient.jl
@@ -1,6 +1,7 @@
+const SUPPORTED_AD_BACKENDS = (:Zygote, :Enzyme, :Mooncake, :FiniteDifferences)
"""
- gradient(f, args...)
+ gradient(f, [adtype,] args...)
Returns a tuple containing `∂f/∂x` for each argument `x`,
the derivative (for scalar `x`) or the gradient.
@@ -8,13 +9,21 @@ If no gradient is defined, `∂f/∂x` will be `nothing`.
`f(args...)` must be a real number, see [`Zygote.jacobian`](@ref) for array output.
-By default, `Flux.gradient` calls Zygote. If you load Enzyme, then other methods become available.
+The optional argument `adtype` allows specifying the automatic differentiation backend.
+
+We provide specific support and testing for the following backends:
+`AutoZygote`, `AutoEnzyme`, `AutoMooncake`, and `AutoFiniteDifferences`.
+
+The package corresponding to any chosen backend (except Zygote) must be loaded in advance.
+
+If no `adtype` is given, then Zygote.jl is used by default, unless at least one argument
+is of type `Duplicated` from Enzyme.jl, in which case Enzyme.jl is used.
See also [`withgradient`](@ref) to keep the value `f(args...)`.
# Examples
-```
+```julia-repl
julia> Flux.gradient(*, 2.0, 3.0, 5.0)
(15.0, 10.0, 6.0)
@@ -27,10 +36,29 @@ julia> Flux.gradient([7, 11], 0, 1) do x, y, d
end
([14.0, 22.0], 2.0, nothing)
```
+Specifying other AD backends:
+
+```julia-repl
+julia> using Mooncake
+
+julia> f(x) = sum(2 .* x)
+f (generic function with 1 method)
+
+julia> Flux.gradient(f, AutoMooncake(), [1.0, 2.0, 3.0])
+([2.0, 2.0, 2.0],)
+```
"""
+function gradient(f, adtype::ADTypes.AbstractADType, args...)
+ error("AD backend has to be loaded to use `gradient(f, AutoXXX(), args...)`.
+ Make sure to `using` the corresponding package, e.g. `using Mooncake` for `AutoMooncake()`.
+ Supported backends are $SUPPORTED_AD_BACKENDS.")
+end
+
+
+# Default gradient using Zygote
function gradient(f, args...; zero::Bool=true)
for a in args
- a isa EnzymeCore.Duplicated && return _enzyme_gradient(f, map(_ensure_enzyme, args)...; zero)
+ a isa Union{EnzymeCore.Duplicated, EnzymeCore.Const} && return gradient(f, AutoEnzyme(), args...; zero)
end
for a in args
_ensure_noenzyme(a)
@@ -41,17 +69,9 @@ function gradient(f, args...; zero::Bool=true)
If you are writing new code, then Zygote over Zygote is heavily discouraged.
""")
end
- Zygote.gradient(f, args...)
+ return Zygote.gradient(f, args...)
end
-# Given one Duplicated, we wrap everything else in Const before calling Enzyme
-_ensure_enzyme(x::EnzymeCore.Duplicated) = x
-_ensure_enzyme(x::EnzymeCore.Const) = x
-_ensure_enzyme(x) = EnzymeCore.Const(x)
-_ensure_enzyme(x::EnzymeCore.Active) = throw(ArgumentError(
- "The method `gradient(f, xs...)` using Enzyme.jl does not support `Active`, only `Duplicated` and ``Const`."
-))
-
# Without any Duplicated, check for no stray Enzyme types before calling Zygote
_ensure_noenzyme(::EnzymeCore.Const) = throw(ArgumentError(
"The method `gradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`."
@@ -62,7 +82,7 @@ _ensure_noenzyme(::EnzymeCore.Active) = throw(ArgumentError(
_ensure_noenzyme(_) = nothing
"""
- gradient(f, args::Union{Const,Duplicated}...)
+ gradient(f, args::Union{Any,EnzymeCore.Duplicated}...)
This should return the same answer as `gradient(f, args...)`,
but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.
@@ -70,7 +90,7 @@ but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.
Only available when Enzyme is loaded!
This method is used when at least one argument is of type `Duplicated`,
-and all unspecified aguments are wrapped in `Const`.
+and all non-`Duplicated` arguments are treated as `Const`.
Note that Enzyme's `Active` is not supported.
Besides returning the gradient, this is also stored within the `Duplicated` object.
@@ -78,12 +98,9 @@ Calling `Enzyme.Duplicated(model)` allocates space for the gradient,
which is zero'd before use when calling `gradient`.
With the keyword `zero=false`, the new gradient will instead be added to what is already stored.
-!!! warning "Experimental"
- Enzyme support like this is new and somewhat experimental.
- This method was added in Flux 0.15.
+# Examples
-# Example
-```
+```julia-repl
julia> using Flux
julia> model = Chain(Dense([3.0;;]));
@@ -119,28 +136,26 @@ julia> Flux.gradient(dup_model, [1]; zero=false) do m, x # implict Const([1]),
((layers = ((weight = [12.0;;], bias = [12.0], σ = nothing),),), nothing)
```
"""
-gradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...; zero::Bool=true) = _enzyme_gradient(f, args...; zero)
-
-gradient(f, args::EnzymeCore.Const...; zero::Bool=true) = throw(ArgumentError(
- "The method `gradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`."
-))
-
-# FluxEnzymeExt defines more specific _enzyme_gradient(f, args::Union{Const, Duplicated}...; zero)
-_enzyme_gradient(f, args...; zero) = throw(ArgumentError(
- "Methods like `gradient(f, x::Duplicated)` are only available when Enzyme is loaded."
-))
+gradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...; zero::Bool=true) = gradient(f, AutoEnzyme(), args...; zero)
"""
- withgradient(f, args...)
+ withgradient(f, [adtype,] args...)
Returns both the value of the function and the [`gradient`](@ref), as a named tuple.
-By default, `Flux.withgradient` calls Zygote. If you load Enzyme, then other methods become available.
+The optional argument `adtype` allows specifying the automatic differentiation backend
+among the supported ones: `AutoZygote`, `AutoEnzyme`, `AutoMooncake`, and `AutoFiniteDifferences`.
+The package corresponding to the chosen backend must be loaded in advance.
-# Example
+If no `adtype` is given, then Zygote.jl is used by default, unless at least one argument
+is of type `Duplicated` from Enzyme.jl, in which case Enzyme.jl is used.
-```
+See also [`gradient`](@ref) to get just the gradient.
+
+# Examples
+
+```jldoctest
julia> y, ∇ = withgradient(/, 1, 2)
(val = 0.5, grad = (0.5, -0.25))
@@ -148,12 +163,12 @@ julia> ∇ == gradient(/, 1, 2)
true
```
-Allows you to capture auxillary outputs, in addition to the scalar
+`withgradient` allows you to capture auxiliary outputs, in addition to the scalar
used by `gradient`. To do this, `f` must return a Tuple or NamedTuple.
Then it calculates `grad = gradient(first∘f, args...)`
but returns the whole `val = f(args...)`:
-```jldoctest; setup=:(using Zygote)
+```jldoctest
julia> withgradient([1,2,4]) do x
z = 1 ./ x
sum(z), z # here z is an auxillary output
@@ -165,10 +180,29 @@ julia> withgradient(3.0, 4.0) do x, y
end
(val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875))
```
+
+Different AD backends can be specified:
+```julia-repl
+julia> using Mooncake
+
+julia> f(x) = sum(2 .* x)
+f (generic function with 1 method)
+
+julia> Flux.withgradient(f, AutoMooncake(), [1.0, 2.0, 3.0])
+(val = 12.0, grad = ([2.0, 2.0, 2.0],))
+```
"""
+function withgradient(f, adtype::ADTypes.AbstractADType, args...)
+ error("AD backend has to be loaded to use `withgradient(f, AutoXXX(), args...)`.
+ Make sure to `using` the corresponding package, e.g. `using Mooncake` for `AutoMooncake()`.
+ Supported backends are $SUPPORTED_AD_BACKENDS.")
+end
+
+
+# Default withgradient using Zygote
function withgradient(f, args...; zero::Bool=true)
for a in args
- a isa EnzymeCore.Duplicated && return _enzyme_withgradient(f, map(_ensure_enzyme, args)...; zero)
+ a isa Union{EnzymeCore.Duplicated, EnzymeCore.Const} && return withgradient(f, AutoEnzyme(), args...; zero)
end
for a in args
_ensure_noenzyme(a)
@@ -179,22 +213,28 @@ function withgradient(f, args...; zero::Bool=true)
If you are writing new code, then Zygote over Zygote is heavily discouraged.
""")
end
- Zygote.withgradient(f, args...)
+ return Zygote.withgradient(f, args...)
+end
+
+## Zygote version, supporting aux output too.
+function withgradient(f::F, adtype::AutoZygote, x::Vararg{Any,N}) where {F,N}
+ return Zygote.withgradient(f, x...)
end
"""
- withgradient(f, args::Union{Const,Duplicated}...)
+ withgradient(f, args::Union{Any,EnzymeCore.Duplicated}...)
This should return the same answer as `withgradient(f, model, args...)`,
but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.
Only available when Enzyme is loaded!
-!!! warning "Experimental"
- Enzyme support like this is new and somewhat experimental.
- This method was added in Flux 0.15.
+This method is used when at least one argument is of type `Duplicated`.
+All non-`Duplicated` arguments will be differentiated as well;
+mark them as `Const` to avoid this.
+Note that Enzyme's `Active` is not supported.
-# Example
+# Examples
```julia-repl
julia> using Flux, Enzyme
@@ -210,26 +250,18 @@ julia> Flux.withgradient(m -> m(3), model) # this uses Zygote
julia> Flux.withgradient(m -> m(3), Duplicated(model)) # this uses Enzyme
(val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
```
-
-The function `f` may return Tuple or NamedTuple, with the loss as the first element.
-The gradient is then `grad = gradient(first∘f, args...)`
-but the returned value is `val = f(args...)`:
-
-```julia-repl
-julia> Flux.withgradient(m -> (m(3), "aux"), Duplicated(model))
-(val = (14.52, "aux"), grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
-
-julia> Flux.withgradient(m -> (loss=m(3), aux=round.(m.(1:3); digits=3)), Duplicated(model))
-(val = (loss = 14.52, aux = [4.84, 9.68, 14.52]), grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
-```
"""
-withgradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...; zero::Bool=true) = _enzyme_withgradient(f, args...; zero)
+withgradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...; zero::Bool=true) = withgradient(f, AutoEnzyme(), args...; zero)
-withgradient(f, args::EnzymeCore.Const...; zero::Bool=true) = throw(ArgumentError(
- "The method `withgradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`."
-))
+## ADD BACK TO withgradient docstring above when AUX is SUPPORTED
+# The function `f` may return Tuple or NamedTuple, with the loss as the first element.
+# The gradient is then `grad = gradient(first∘f, args...)`
+# but the returned value is `val = f(args...)`:
-# FluxEnzymeExt defines more specific _enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero)
-_enzyme_withgradient(f, args...; zero) = throw(ArgumentError(
- "Methods like `withgradient(f, x::Duplicated)` are only available when Enzyme is loaded."
-))
+# ```julia-repl
+# julia> Flux.withgradient(m -> (m(3), "aux"), Duplicated(model))
+# (val = (14.52, "aux"), grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
+
+# julia> Flux.withgradient(m -> (loss=m(3), aux=round.(m.(1:3); digits=3)), Duplicated(model))
+# (val = (loss = 14.52, aux = [4.84, 9.68, 14.52]), grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
+# ```
diff --git a/test/Project.toml b/test/Project.toml
index cffbc43306..83882844d5 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -2,6 +2,7 @@
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -12,6 +13,7 @@ GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -21,9 +23,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+[sources]
+Flux = {path = ".."}
+
[compat]
FiniteDifferences = "0.12"
Tracker = "0.2.33"
-
-[sources.Flux]
-path = ".."
diff --git a/test/ext_amdgpu/basic.jl b/test/ext_amdgpu/basic.jl
index 8962c7bedb..6926213af4 100644
--- a/test/ext_amdgpu/basic.jl
+++ b/test/ext_amdgpu/basic.jl
@@ -19,7 +19,7 @@ end
@testset "Chain of Dense layers" begin
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
x = rand(Float32, 10, 10)
- test_gradients(m, x, test_gpu=true, compare_finite_diff=false)
+ test_gradients(m, x, test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
end
@testset "Convolution" begin
@@ -75,14 +75,14 @@ end
@testset "Chain(Conv)" begin
m = Chain(Conv((3, 3), 3 => 3))
x = rand(Float32, 5, 5, 3, 2)
- test_gradients(m, x, test_gpu=true, compare_finite_diff=false, test_grad_f=false)
+ test_gradients(m, x, test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
md = m |> gpu |> cpu
@test md[1].weight ≈ m[1].weight atol=1f-3
m = Chain(ConvTranspose((3, 3), 3 => 3))
x = rand(Float32, 5, 5, 3, 2)
- test_gradients(m, x, test_gpu=true, compare_finite_diff=false, test_grad_f=false)
+ test_gradients(m, x, test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
md = m |> gpu |> cpu
@test md[1].weight ≈ m[1].weight atol=1f-3
@@ -91,7 +91,7 @@ end
@testset "Cross-correlation" begin
m = CrossCor((2, 2), 3 => 4)
x = rand(Float32, 5, 5, 3, 2)
- test_gradients(m, x, test_gpu=true, compare_finite_diff=false)
+ test_gradients(m, x, test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
end
@testset "Restructure" begin
@@ -131,7 +131,7 @@ end
bn = BatchNorm(3, σ)
for nd in 1:3
x = rand(Float32, fill(2, nd - 1)..., 3, 4)
- test_gradients(bn, x; test_gpu=true, compare_finite_diff=false)
+ test_gradients(bn, x; test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
end
end
diff --git a/test/ext_common/recurrent_gpu_ad.jl b/test/ext_common/recurrent_gpu_ad.jl
index 35fa983806..cdbd67c33e 100644
--- a/test/ext_common/recurrent_gpu_ad.jl
+++ b/test/ext_common/recurrent_gpu_ad.jl
@@ -15,12 +15,12 @@ end
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
# Single Step
- @test test_gradients(r, x[1], h; test_gpu=true,
- compare_finite_diff=false,
+ @test test_gradients(r, x[1], h; test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing,
loss=cell_loss) broken = :rnncell_single ∈ BROKEN_TESTS
# Multiple Steps
- @test test_gradients(r, x, h; test_gpu=true,
- compare_finite_diff=false,
+ @test test_gradients(r, x, h; test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing,
loss=recurrent_cell_loss) broken = :rnncell_multiple ∈ BROKEN_TESTS
end
@@ -37,9 +37,11 @@ end
d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelRNN(RNN(d_in => d_out), zeros(Float32, d_out))
x_nobatch = randn(Float32, d_in, len)
- @test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false) broken = :rnn_nobatch ∈ BROKEN_TESTS
+ @test test_gradients(model, x_nobatch; test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing) broken = :rnn_nobatch ∈ BROKEN_TESTS
x = randn(Float32, d_in, batch_size)
- @test test_gradients(model, x, test_gpu=true, compare_finite_diff=false) broken = :rnn ∈ BROKEN_TESTS
+ @test test_gradients(model, x, test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing) broken = :rnn ∈ BROKEN_TESTS
end
@testset "LSTMCell" begin
@@ -49,11 +51,12 @@ end
h = zeros(Float32, d_out)
c = zeros(Float32, d_out)
# Single Step
- @test test_gradients(cell, x[1], (h, c); test_gpu=true, compare_finite_diff=false,
+ @test test_gradients(cell, x[1], (h, c); test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing,
loss = cell_loss) broken = :lstmcell_single ∈ BROKEN_TESTS
# Multiple Steps
- @test test_gradients(cell, x, (h, c); test_gpu=true,
- compare_finite_diff = false,
+ @test test_gradients(cell, x, (h, c); test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing,
loss = recurrent_cell_loss) broken = :lstmcell_multiple ∈ BROKEN_TESTS
end
@@ -71,11 +74,11 @@ end
d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelLSTM(LSTM(d_in => d_out), zeros(Float32, d_out), zeros(Float32, d_out))
x_nobatch = randn(Float32, d_in, len)
- @test test_gradients(model, x_nobatch; test_gpu=true,
- compare_finite_diff=false) broken = :lstm_nobatch ∈ BROKEN_TESTS
+ @test test_gradients(model, x_nobatch; test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing) broken = :lstm_nobatch ∈ BROKEN_TESTS
x = randn(Float32, d_in, len, batch_size)
- @test test_gradients(model, x; test_gpu=true,
- compare_finite_diff=false) broken = :lstm ∈ BROKEN_TESTS
+ @test test_gradients(model, x; test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing) broken = :lstm ∈ BROKEN_TESTS
end
@testset "GRUCell" begin
@@ -83,11 +86,11 @@ end
r = GRUCell(d_in => d_out)
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
- @test test_gradients(r, x[1], h; test_gpu=true,
- compare_finite_diff=false,
+ @test test_gradients(r, x[1], h; test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing,
loss = cell_loss) broken = :grucell_single ∈ BROKEN_TESTS
- @test test_gradients(r, x, h; test_gpu=true,
- compare_finite_diff = false,
+ @test test_gradients(r, x, h; test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing,
loss = recurrent_cell_loss) broken = :grucell_multiple ∈ BROKEN_TESTS
end
@@ -104,11 +107,11 @@ end
d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelGRU(GRU(d_in => d_out), zeros(Float32, d_out))
x_nobatch = randn(Float32, d_in, len)
- @test test_gradients(model, x_nobatch; test_gpu=true,
- compare_finite_diff=false) broken = :gru_nobatch ∈ BROKEN_TESTS
+ @test test_gradients(model, x_nobatch; test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing) broken = :gru_nobatch ∈ BROKEN_TESTS
x = randn(Float32, d_in, len, batch_size)
- @test test_gradients(model, x; test_gpu=true,
- compare_finite_diff=false) broken = :gru ∈ BROKEN_TESTS
+ @test test_gradients(model, x; test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing) broken = :gru ∈ BROKEN_TESTS
end
@testset "GRUv3Cell GPU AD" begin
@@ -116,11 +119,11 @@ end
r = GRUv3Cell(d_in => d_out)
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
- @test test_gradients(r, x[1], h; test_gpu=true,
- compare_finite_diff=false,
+ @test test_gradients(r, x[1], h; test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing,
loss=cell_loss) broken = :gruv3cell_single ∈ BROKEN_TESTS
- @test test_gradients(r, x, h; test_gpu=true,
- compare_finite_diff=false,
+ @test test_gradients(r, x, h; test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing,
loss = recurrent_cell_loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS
end
@@ -137,9 +140,9 @@ end
d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelGRUv3(GRUv3(d_in => d_out), zeros(Float32, d_out))
x_nobatch = randn(Float32, d_in, len)
- @test test_gradients(model, x_nobatch; test_gpu=true,
- compare_finite_diff=false) broken = :gruv3_nobatch ∈ BROKEN_TESTS
+ @test test_gradients(model, x_nobatch; test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing) broken = :gruv3_nobatch ∈ BROKEN_TESTS
x = randn(Float32, d_in, len, batch_size)
- @test test_gradients(model, x; test_gpu=true,
- compare_finite_diff=false) broken = :gruv3 ∈ BROKEN_TESTS
+ @test test_gradients(model, x; test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing) broken = :gruv3 ∈ BROKEN_TESTS
end
diff --git a/test/ext_cuda/layers.jl b/test/ext_cuda/layers.jl
index b7b456bcf1..f6b44e74e1 100644
--- a/test/ext_cuda/layers.jl
+++ b/test/ext_cuda/layers.jl
@@ -13,19 +13,14 @@ end
const ACTIVATIONS = [identity, tanh]
function gpu_gradtest(name::String, layers::Vector, x_cpu, args...;
- test_mode=false, test_grad_x=true,
+ test_mode=false,
atol=1e-4, rtol=1e-4)
@testset "$name GPU grad tests" begin
for layer in layers
@testset "$layer Layer GPU grad test" begin
-
- # compute output and grad of parameters
l_cpu = layer(args...)
- if test_mode
- testmode!(l_cpu)
- end
-
- test_gradients(l_cpu, x_cpu; test_gpu=true, compare_finite_diff=false, test_grad_x, atol, rtol)
+ test_gradients(l_cpu, x_cpu; test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing,
+ atol, rtol, test_mode)
end
end
end
@@ -90,19 +85,22 @@ gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
embedding = [Flux.Embedding]
-gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2, test_grad_x=false)
-gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2, test_grad_x=false)
-gpu_gradtest("Embedding integer index", embedding, 1, 5, 2, test_grad_x=false)
-gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2, test_grad_x=false)
-gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2, test_grad_x=false)
-gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2, test_grad_x=false)
-gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2, test_grad_x=false)
+gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2)
+gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2)
+gpu_gradtest("Embedding integer index", embedding, 1, 5, 2)
+gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2)
+gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2)
+gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2)
+gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2)
@testset "function layers" begin
x = rand(Float32, 3, 3)
- test_gradients(x -> sum(Flux.normalise(x; dims=1)), x, test_gpu=true, compare_finite_diff=false)
- test_gradients(x -> sum(Flux.normalise(x; dims=2)), x, test_gpu=true, compare_finite_diff=false)
- test_gradients(x -> sum(Flux.normalise(x)), x, test_gpu=true, compare_finite_diff=false)
+ test_gradients(x -> sum(Flux.normalise(x; dims=1)), x, test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing)
+ test_gradients(x -> sum(Flux.normalise(x; dims=2)), x, test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing)
+ test_gradients(x -> sum(Flux.normalise(x)), x, test_gpu=true, test_cpu=false,
+ reference=AutoZygote(), compare=nothing)
end
@testset "Zeros mapped for $cl" for cl in (Conv, ConvTranspose, CrossCor, DepthwiseConv)
@@ -183,7 +181,7 @@ end
@test size(b(x, y)) == (3,9)
@test sum(abs2, b(x, y)) ≈ 0f0
test_gradients(b |> cpu, x |> cpu, y |> cpu,
- test_gpu=true, compare_finite_diff=false, loss=(m, x, y) -> mean(abs2, m(x, y)))
+ test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
end
@testset "Two-streams Bilinear" begin
@@ -193,7 +191,7 @@ end
@test size(b(x, y)) == (3,9)
@test sum(abs2, b(x, y)) ≈ 0f0
test_gradients(b |> cpu, x |> cpu, y |> cpu,
- test_gpu=true, compare_finite_diff=false, loss=(m, x, y) -> mean(abs2, m(x, y)))
+ test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
end
@testset "Parallel" begin
@@ -213,7 +211,7 @@ end
@testset "gradient" begin
layer_cpu = Parallel(+, x -> zero(x), identity)
test_gradients(layer_cpu, randn(2, 2, 2, 2),
- test_gpu=true, compare_finite_diff=false, loss=(m, x) -> mean(abs2, m(x)))
+ test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
end
end
@@ -294,5 +292,5 @@ end
return sum(y.^2) + sum(α.^2)
end
test_gradients(mha_cpu, x_cpu; loss,
- test_gpu=true, compare_finite_diff=false)
+ test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
end
diff --git a/test/ext_cuda/losses.jl b/test/ext_cuda/losses.jl
index 11e14981b7..514b71634c 100644
--- a/test/ext_cuda/losses.jl
+++ b/test/ext_cuda/losses.jl
@@ -30,7 +30,7 @@ y = [1 0 0 0 1
y = 0.1f0 .+ 0.8f0 .* rand(Float32, 3, 4)
@test loss(x, y) ≈ loss(gpu(x), gpu(y))
- test_gradients(loss, x, y, test_gpu=true, test_grad_f=false, compare_finite_diff=false)
+ test_gradients(loss, x, y, test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
# Float16 tests
@test loss(f16(x), f16(y)) ≈ loss(gpu(f16(x)), gpu(f16(y)))
diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl
index e93568afe8..68645445f7 100644
--- a/test/ext_enzyme/enzyme.jl
+++ b/test/ext_enzyme/enzyme.jl
@@ -1,52 +1,9 @@
# ENZYME CPU TESTS
-@testset "Models" begin
- function loss(model, x)
- mean(model(x))
- end
-
- models_xs = [
- (Dense(2=>4), randn(Float32, 2), "Dense"),
- (Chain(Dense(2=>4, tanh), Dense(4=>3)), randn(Float32, 2), "Chain(Dense, Dense)"),
- (f64(Chain(Dense(2=>4), Dense(4=>2))), randn(Float64, 2, 1), "f64(Chain(Dense, Dense))"),
- (Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"),
- (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"),
- (Chain(Conv((3, 3), 2 => 3, ), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"),
- (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
- (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"),
- (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
- (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
- (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
- (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"),
- (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"),
- (first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"),
- ]
-
- for (model, x, name) in models_xs
+@testset "enzyme gradients" begin
+ for (model, x, name) in TEST_MODELS
@testset "Enzyme grad check $name" begin
- println("testing $name with Enzyme")
- @test test_gradients(model, x; loss, compare_finite_diff=false, test_enzyme=true)
- end
- end
-end
-
-@testset "Recurrent Layers" begin
- function loss(model, x)
- mean(model(x))
- end
-
- models_xs = [
- (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
- (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
- (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
- (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
- (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
- ]
-
- for (model, x, name) in models_xs
- @testset "check grad $name" begin
- println("testing $name")
- test_gradients(model, x; loss, compare_finite_diff=false, test_enzyme=true)
+ @test test_gradients(model, x; reference=AutoZygote(), compare=AutoEnzyme())
end
end
end
@@ -61,25 +18,26 @@ end
@test g1.bias == [1, 1]
@test m1.dval.bias == [1, 1]
- g2 = Flux.withgradient((m,x) -> sum(m(x)), m1, [1,2,3f0])
+ g2 = Flux.withgradient((m,x) -> sum(m(x)), m1, Const([1,2,3f0]))
@test g2.val ≈ sum(m1([1,2,3f0]))
@test g2.grad[1].weight ≈ [1 2 3; 1 2 3]
- @test g2.grad[2] === nothing # implicitly Const
+ @test g2.grad[2] === nothing
- g3 = Flux.withgradient(Duplicated([1,2,4.], zeros(3))) do x
- z = 1 ./ x
- sum(z), z # here z is an auxillary output
- end
- @test g3.grad[1] ≈ [-1.0, -0.25, -0.0625]
- @test g3.val[1] ≈ 1.75
- @test g3.val[2] ≈ [1.0, 0.5, 0.25]
- g4 = Flux.withgradient(Duplicated([1,2,4.], zeros(3))) do x
- z = 1 ./ x
- (loss=sum(z), aux=string(z))
- end
- @test g4.grad[1] ≈ [-1.0, -0.25, -0.0625]
- @test g4.val.loss ≈ 1.75
- @test g4.val.aux == "[1.0, 0.5, 0.25]"
+ ## Auxiliary outputs not supported at the moment
+ # g3 = Flux.withgradient(Duplicated([1,2,4.], zeros(3))) do x
+ # z = 1 ./ x
+ # sum(z), z # here z is an auxiliary output
+ # end
+ # @test g3.grad[1] ≈ [-1.0, -0.25, -0.0625]
+ # @test g3.val[1] ≈ 1.75
+ # @test g3.val[2] ≈ [1.0, 0.5, 0.25]
+ # g4 = Flux.withgradient(Duplicated([1,2,4.], zeros(3))) do x
+ # z = 1 ./ x
+ # (loss=sum(z), aux=string(z))
+ # end
+ # @test g4.grad[1] ≈ [-1.0, -0.25, -0.0625]
+ # @test g4.val.loss ≈ 1.75
+ # @test g4.val.aux == "[1.0, 0.5, 0.25]"
# setup understands Duplicated:
@test Flux.setup(Adam(), m1) == Flux.setup(Adam(), m1.val)
@@ -93,11 +51,11 @@ end
m1.val.weight .= 0
@test Flux.loadmodel!(m1, oldpar).val.weight ≈ oldpar.weight
- # At least one Duplicated is required:
- @test_throws ArgumentError Flux.gradient(m -> sum(m.bias), Const(m1.val))
- @test_throws ArgumentError Flux.gradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0])
- @test_throws ArgumentError Flux.withgradient(m -> sum(m.bias), Const(m1.val))
- @test_throws ArgumentError Flux.withgradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0])
+ # Only Const args are supported
+ @test Flux.gradient(m -> sum(m.bias), Const(m1.val))[1] === nothing
+ @test Flux.gradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0]) isa Tuple{Nothing,Vector{Float32}}
+ @test Flux.withgradient(m -> sum(m.bias), Const(m1.val)).grad[1] === nothing
+ @test Flux.withgradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0]).grad isa Tuple{Nothing,Vector{Float32}}
# Active is disallowed:
@test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1, Active(3f0))
@test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1.val, Active(3f0))
@@ -115,7 +73,6 @@ end
@test Flux.gradient(sum ∘ LayerNorm(3), z)[1] ≈ [0.0, 0.0, 0.0]
@test Flux.gradient(|>, z, _duplicated(sum ∘ LayerNorm(3)))[1] ≈ [0.0, 0.0, 0.0]
@test Flux.gradient(|>, z, Const(sum ∘ LayerNorm(3)))[2] === nothing
-
- @test_broken Flux.withgradient(sum ∘ LayerNorm(3), z).grad[1] ≈ [0.0, 0.0, 0.0] # AssertionError: Base.allocatedinline(actualRetType) returns false: actualRetType = Any, rettype = Active{Any}
- @test_broken Flux.withgradient(|>, z, _duplicated(sum ∘ LayerNorm(3))).grad[1] ≈ [0.0, 0.0, 0.0]
+ @test Flux.withgradient(sum ∘ LayerNorm(3), z).grad[1] ≈ [0.0, 0.0, 0.0]
+ @test Flux.withgradient(|>, z, _duplicated(sum ∘ LayerNorm(3))).grad[1] ≈ [0.0, 0.0, 0.0]
end
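A minimal sketch of the relaxed `Const` handling tested above, using a hypothetical `Dense` layer and input; it assumes `Flux` and `Enzyme` are loaded as in this test file:

```julia
using Flux, Enzyme  # `Const` comes from Enzyme

m = Dense(3 => 2)        # hypothetical model
x = Float32[1, 2, 3]     # hypothetical input

# A Const-wrapped argument is treated as non-differentiated: its slot in the
# returned gradient tuple is `nothing`, while the plain array still gets one.
g = Flux.gradient((m, x) -> sum(m(x)), Const(m), x)
g[1] === nothing          # true
g[2] isa Vector{Float32}  # true
```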
diff --git a/test/ext_metal/basic.jl b/test/ext_metal/basic.jl
index 9febd8e455..b041a9272f 100644
--- a/test/ext_metal/basic.jl
+++ b/test/ext_metal/basic.jl
@@ -21,5 +21,15 @@ end
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
x = rand(Float32, 10, 10)
@test (m|>gpu)(x|>gpu) isa MtlArray{Float32, 2}
- test_gradients(m, x, test_gpu=true, compare_finite_diff=false)
+ test_gradients(m, x, test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
+end
+
+@testset "gradients" begin
+ broken_models = ["Conv", "Chain(Conv, Conv)", "Chain(Conv, MeanPool)", "ConvTranspose","Bilinear","MultiHeadAttention"]
+ # Bilinear and MultiHeadAttention will be fixed by https://github.com/FluxML/NNlib.jl/pull/614
+ for (model, x, name) in TEST_MODELS
+ @testset "Zygote grad check $name" begin
+ @test test_gradients(model, x; test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing) broken=(name ∈ broken_models)
+ end
+ end
end
diff --git a/test/ext_mooncake.jl b/test/ext_mooncake.jl
new file mode 100644
index 0000000000..f60633e981
--- /dev/null
+++ b/test/ext_mooncake.jl
@@ -0,0 +1,7 @@
+@testset "mooncake gradient" begin
+ for (model, x, name) in TEST_MODELS
+ @testset "grad check $name" begin
+ @test test_gradients(model, x; reference=AutoZygote(), compare=AutoMooncake())
+ end
+ end
+end
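The new file checks Mooncake via `test_gradients`; underneath, that goes through the backend-aware `Flux.withgradient`. A minimal direct-call sketch with a hypothetical model and input (imports written out in case the backend types are not re-exported):

```julia
using Flux, Mooncake
using ADTypes: AutoMooncake

model = Dense(2 => 4)    # hypothetical model
x = randn(Float32, 2)    # hypothetical input

# Same loss and gradients as the default (Zygote) path, but differentiated
# with Mooncake; `gs` holds one entry per differentiated argument.
val, gs = Flux.withgradient((m, x) -> sum(m(x)), AutoMooncake(), model, x)
```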
diff --git a/test/ext_reactant/reactant.jl b/test/ext_reactant/reactant.jl
index e6fd9f48c9..820c0dd579 100644
--- a/test/ext_reactant/reactant.jl
+++ b/test/ext_reactant/reactant.jl
@@ -1,78 +1,8 @@
-function scalarfirst(x)
- Reactant.@allowscalar first(x)
-end
-
@testset "Reactant Models" begin
- function loss(model, x)
- mean(model(x))
- end
-
- models_xs = [
- (Dense(2=>4), randn(Float32, 2), "Dense"),
-
- (Chain(Dense(2=>4, tanh), Dense(4=>3)), randn(Float32, 2), "Chain(Dense, Dense)"),
-
- (f64(Chain(Dense(2=>4), Dense(4=>2))), randn(Float64, 2, 1), "f64(Chain(Dense, Dense))"),
-
- (Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"),
-
- (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"),
-
- (Chain(Conv((3, 3), 2 => 3, ), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"),
-
- (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
-
- (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"),
-
- (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
-
- (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
-
- (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
-
- (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"),
- ]
-
- for (model, x, name) in models_xs
- @testset "Enzyme grad check $name" begin
- println("testing $name with Reactant")
- test_gradients(model, x; loss, compare_finite_diff=false, test_reactant=true)
- end
- end
-
- models_xs = [
- (LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"), # Zygote comparison test fails on the GPUArraysCore.@allowscalar in scalarfirst, so we globally allow scalar
-
- (first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"), # Zygote comparison test fails on the GPUArraysCore.@allowscalar in scalarfirst, so we globally allow scalar
- ]
-
- Reactant.allowscalar(true)
- for (model, x, name) in models_xs
- @testset "Enzyme grad check $name" begin
- println("testing $name with Reactant")
- test_gradients(model, x; loss, compare_finite_diff=false, test_reactant=true)
- end
- end
- Reactant.allowscalar(false)
-end
-
-@testset "Reactant Recurrent Layers" begin
- function loss(model, x)
- mean(model(x))
- end
-
- models_xs = [
- (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
- (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
- (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
- (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
- (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
- ]
-
- for (model, x, name) in models_xs
- @testset "check grad $name" begin
- println("testing $name with Reactant")
- test_gradients(model, x; loss, compare_finite_diff=false, test_reactant=true)
+ broken_models = ()
+ for (model, x, name) in TEST_MODELS
+ @testset "Reactant grad check $name" begin
+ @test test_gradients(model, x; reference=AutoZygote(), test_reactant=true, test_cpu=false) broken=(name ∈ broken_models)
end
end
end
diff --git a/test/ext_reactant/test_utils_reactant.jl b/test/ext_reactant/test_utils_reactant.jl
index 6b3cec53d8..60e879d2bb 100644
--- a/test/ext_reactant/test_utils_reactant.jl
+++ b/test/ext_reactant/test_utils_reactant.jl
@@ -2,7 +2,7 @@
# because Reactant is only optionally loaded and the macros fail when it is not loaded.
function reactant_withgradient(f, x...)
- y, g = Reactant.@jit enzyme_withgradient(f, x...)
+ y, g = Reactant.@jit Flux.withgradient(f, AutoEnzyme(), x...)
return y, g
end
diff --git a/test/runtests.jl b/test/runtests.jl
index 301d46dc1f..f4a0b7facf 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -6,9 +6,10 @@ using Flux: OneHotArray, OneHotMatrix, OneHotVector,
using Flux.Losses: xlogx, xlogy
using Flux.Losses
using ForwardDiff: ForwardDiff
-using Functors: Functors, fmapstructure_with_path
+using Functors: Functors, fmapstructure_with_path, fmap
using IterTools: ncycle
using LinearAlgebra
+using Mooncake: Mooncake
using MLUtils: MLUtils, batch, unstack, unsqueeze,
unbatch, getobs, numobs, flatten, DataLoader
using Optimisers: Optimisers
@@ -18,8 +19,6 @@ using SparseArrays
using Statistics
using Test
using Zygote: Zygote
-# const gradient = Flux.gradient # both Flux & Zygote export this on 0.15
-# const withgradient = Flux.withgradient
## Uncomment below to change the default test settings
# ENV["FLUX_TEST_AMDGPU"] = "true"
@@ -104,6 +103,10 @@ end
include("functors.jl")
end
+ @testset "mooncake" begin
+ include("ext_mooncake.jl")
+ end
+
@testset "deprecations" begin
include("deprecations.jl")
end
diff --git a/test/test_utils.jl b/test/test_utils.jl
index 3576d314e2..d4cc62e9f0 100644
--- a/test/test_utils.jl
+++ b/test/test_utils.jl
@@ -12,27 +12,27 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
Flux.Losses.siamese_contrastive_loss]
-function finitediff_withgradient(f, x...)
- y = f(x...)
- # We set a range to avoid domain errors
- fdm = FiniteDifferences.central_fdm(5, 1, max_range=1e-2)
- return y, FiniteDifferences.grad(fdm, f, x...)
-end
-
-function enzyme_withgradient(f, x...)
- args = []
- for x in x
- if x isa Number
- push!(args, Enzyme.Active(x))
- else
- push!(args, Enzyme.Duplicated(x, Enzyme.make_zero(x)))
- end
- end
- ad = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal)
- ret = Enzyme.autodiff(ad, Enzyme.Const(f), Enzyme.Active, args...)
- g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
- return ret[2], g
-end
+const TEST_MODELS = [
+ (Dense(2=>4), randn(Float32, 2), "Dense"),
+ (Chain(Dense(2=>4, tanh), Dense(4=>3)), randn(Float32, 2), "Chain(Dense, Dense)"),
+ (f64(Chain(Dense(2=>4), Dense(4=>2))), randn(Float64, 2, 1), "f64(Chain(Dense, Dense))"),
+ (Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"),
+ (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"),
+ (Chain(Conv((3, 3), 2 => 3, ), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"),
+ (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
+ (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"),
+ (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
+ (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
+ (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
+ (LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"),
+ (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"),
+ (first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"),
+ (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
+ (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
+ (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
+ (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
+ (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
+]
function _contains_no_numerical(kp, x)
count = 0
@@ -46,132 +46,108 @@ function _contains_no_numerical(kp, x)
end
function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
+ # Since Zygote can use `nothing` for an entire subtree, we prune
+ # the tree using `_contains_no_numerical`.
fmapstructure_with_path(a, b, exclude=_contains_no_numerical) do kp, x, y
- if y isa Nothing
- return
- end
# @show kp
- if x isa AbstractArray
- @test x ≈ y rtol=rtol atol=atol
- elseif x isa Number
+ if x isa AbstractArray{<:AbstractFloat}
@test x ≈ y rtol=rtol atol=atol
end
end
+ return true
+end
+
+function _contains_no_numerical(kp, x)
+ count = 0
+ fmap(x) do y
+ if y isa AbstractArray{<:AbstractFloat}
+ count += 1
+ end
+ return y
+ end
+ return count == 0
end
-# By default, this computes the gradients on cpu using the default AD (Zygote)
-# and compares them with finite differences.
-# Changing the arguments, you can assume the cpu Zygote gradients as the ground truth
-# and test other scenarios.
+_default_fdm() = FiniteDifferences.central_fdm(5, 1, max_range=1e-2)
+
+"""
+Compare the `reference` and `compare` AD backends on the gradients of `f` at `xs...`.
+The loss function can be customized (default is mean over outputs).
+
+- If `test_gpu` is true, the `compare` backend is tested on GPU.
+- If `test_cpu` is true, the `compare` backend is tested on CPU.
+- If `test_reactant` is true, the Enzyme backend is tested with Reactant.
+ Depending on the platform, this may run on CPU or GPU.
+"""
function test_gradients(
f,
xs...;
rtol=1e-4, atol=1e-4,
test_gpu = false,
+ test_cpu = true,
test_reactant = false,
- test_enzyme = false,
- test_grad_f = true,
- test_grad_x = true,
- compare_finite_diff = true,
+ reference = AutoFiniteDifferences(; fdm = _default_fdm()),
+ compare = AutoZygote(),
loss = (f, xs...) -> mean(f(xs...)),
+ test_mode = false,
)
- if !test_gpu && !compare_finite_diff && !test_enzyme && !test_reactant
- error("You should either compare numerical gradients methods or CPU vs GPU.")
- end
+ @assert reference !== nothing "reference AD backend must be provided"
+ @assert compare !== nothing || test_gpu "compare AD backend must be provided if test_gpu=false"
+ compare = compare === nothing ? reference : compare
- Flux.trainmode!(f) # for layers like BatchNorm
+ if test_mode
+ Flux.testmode!(f)
+ else
+ Flux.trainmode!(f)
+ end
- ## Let's make sure first that the forward pass works.
- l = loss(f, xs...)
- @test l isa Number
+ cpu_dev = cpu_device()
+
if test_gpu
gpu_dev = gpu_device(force=true)
cpu_dev = cpu_device()
xs_gpu = xs |> gpu_dev
f_gpu = f |> gpu_dev
- l_gpu = loss(f_gpu, xs_gpu...)
- @test l_gpu isa Number
end
-
+
if test_reactant
reactant_dev = MLDataDevices.reactant_device(force=true)
- cpu_dev = cpu_device()
xs_re = xs |> reactant_dev
f_re = f |> reactant_dev
- l_re = reactant_loss(loss, f_re, xs_re...)
- @test l ≈ l_re rtol=rtol atol=atol
end
- if test_grad_x
- # Zygote gradient with respect to input.
- y, g = Zygote.withgradient((xs...) -> loss(f, xs...), xs...)
-
- if compare_finite_diff
- # Cast to Float64 to avoid precision issues.
- f64 = f |> Flux.f64
- xs64 = xs .|> Flux.f64
- y_fd, g_fd = finitediff_withgradient((xs...) -> loss(f64, xs...), xs64...)
- @test y ≈ y_fd rtol=rtol atol=atol
- check_equal_leaves(g, g_fd; rtol, atol)
- end
-
- if test_enzyme
- y_ez, g_ez = enzyme_withgradient((xs...) -> loss(f, xs...), xs...)
- @test y ≈ y_ez rtol=rtol atol=atol
- check_equal_leaves(g, g_ez; rtol, atol)
- end
+ ## First, let's make sure the forward pass works.
+ l = loss(f, xs...)
+ @test l isa Number
- if test_gpu
- # Zygote gradient with respect to input on GPU.
- y_gpu, g_gpu = Zygote.withgradient((xs...) -> loss(f_gpu, xs...), xs_gpu...)
- @test get_device(g_gpu) == get_device(xs_gpu)
- @test y_gpu ≈ y rtol=rtol atol=atol
- check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol)
- end
+ # Compute reference gradients in f64 precision
+ y, gs = Flux.withgradient(loss, reference, Flux.f64(f), Flux.f64(xs)...)
+ @test l ≈ y rtol=rtol atol=atol
- if test_reactant
- # Enzyme gradient with respect to input on Reactant.
- y_re, g_re = reactant_withgradient(Base.Fix1(loss, f_re), xs_re...)
- @test y ≈ y_re rtol=rtol atol=atol
- check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
- end
+ if test_cpu
+ y2, gs2 = Flux.withgradient(loss, compare, f, xs...)
+ @test l ≈ y2 rtol=rtol atol=atol
+ check_equal_leaves(gs, gs2; rtol, atol)
end
- if test_grad_f
- # Zygote gradient with respect to f.
- y, g = Zygote.withgradient(f -> loss(f, xs...), f)
-
- if compare_finite_diff
- # Cast to Float64 to avoid precision issues.
- f64 = f |> Flux.f64
- ps, re = Flux.destructure(f64)
- y_fd, g_fd = finitediff_withgradient(ps -> loss(re(ps), xs...), ps)
- g_fd = (re(g_fd[1]),)
- @test y ≈ y_fd rtol=rtol atol=atol
- check_equal_leaves(g, g_fd; rtol, atol)
- end
+ if test_gpu
+ l_gpu = loss(f_gpu, xs_gpu...)
+ @test l_gpu isa Number
- if test_enzyme
- y_ez, g_ez = enzyme_withgradient(f -> loss(f, xs...), f)
- @test y ≈ y_ez rtol=rtol atol=atol
- check_equal_leaves(g, g_ez; rtol, atol)
- end
+ y_gpu, gs_gpu = Flux.withgradient(loss, compare, f_gpu, xs_gpu...)
+ @test l_gpu ≈ y_gpu rtol=rtol atol=atol
+ check_equal_leaves(gs, gs_gpu |> cpu_dev; rtol, atol)
+ end
- if test_gpu
- # Zygote gradient with respect to f on GPU.
- y_gpu, g_gpu = Zygote.withgradient(f -> loss(f, xs_gpu...), f_gpu)
- # @test get_device(g_gpu) == get_device(xs_gpu)
- @test y_gpu ≈ y rtol=rtol atol=atol
- check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol)
- end
+ if test_reactant
+ l_re = reactant_loss(loss, f_re, xs_re...)
+ @test l ≈ l_re rtol=rtol atol=atol
- if test_reactant
- # Enzyme gradient with respect to input on Reactant.
- y_re, g_re = reactant_withgradient(Base.Fix2(loss, xs_re[1]), f_re)
- @test y ≈ y_re rtol=rtol atol=atol
- check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
- end
+ y_re, g_re = reactant_withgradient(loss, f_re, xs_re...)
+ @test y ≈ y_re rtol=rtol atol=atol
+ check_equal_leaves(gs, g_re |> cpu_dev; rtol, atol)
end
+
return true
end
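For reference, a sketch of the three ways the updated call sites above drive the refactored helper, using a hypothetical layer and input:

```julia
# Default: finite differences as the reference, Zygote checked against it on CPU.
test_gradients(Dense(2 => 4), randn(Float32, 2))

# Zygote as the reference, Enzyme checked against it on CPU.
test_gradients(Dense(2 => 4), randn(Float32, 2);
               reference=AutoZygote(), compare=AutoEnzyme())

# Zygote as the reference, the same backend re-checked on GPU only.
test_gradients(Dense(2 => 4), randn(Float32, 2);
               test_gpu=true, test_cpu=false,
               reference=AutoZygote(), compare=nothing)
```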