diff --git a/.gitignore b/.gitignore
index c289756be2..dd9aa9ad8e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,7 +6,7 @@
 docs/build/
 docs/site/
 deps
 .vscode
-Manifest.toml
+Manifest*.toml
 LocalPreferences.toml
 .DS_Store
 docs/mymodel.bson
diff --git a/NEWS.md b/NEWS.md
index 4e2a3316c1..c9056742df 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -2,11 +2,71 @@
 
 See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.
 
-## v0.16.1 (25 December 2025)
+## v0.16.8 (January 2026)
 
-The default `init_score` value for `early_stopping` has been set to `Inf` (instead of `0`) in order to prevent unexpected behavior if the defaults were not modified. Documentation has been updated to explain that, if the user needs to track a metric where improvement is shown by increasing values, then the `init_score` needs to be adjusted, for example, to `-Inf`). Tests corresponding to `early_stopping` have been reorganized and extended to be more detailed and illustrative of `early_stopping`'s behavior.
+This release includes the following changes:
+- Added support in `Flux.gradient` and `Flux.withgradient` for alternative AD backends such as `AutoEnzyme()` and `AutoMooncake()`.
+- The default `init_score` value for `early_stopping` has been set to `Inf` (instead of `0`) in order to prevent unexpected behavior if the defaults were not modified.
 
-## v0.16.0 (15 December 2025)
+## v0.16.7 (10 December 2025)
+
+This patch release includes:
+
+* Minor documentation fixes and housekeeping commits.
+* Compatibility updates for downstream packages.
+
+## v0.16.6 (8 December 2025)
+
+This patch release includes:
+
+* Minor dependency bumps and CI updates.
+* Preparatory changes ahead of v0.16.7.
+
+## v0.16.5 (23 July 2025)
+
+This release includes:
+
+* Fix typos in legacy tutorials documentation.
+* Bump compatibility for `AMDGPU` in weak dependencies.
+* **Fix** for `unsafe_free!` failure with certain `CuArray` configurations.
+
+## v0.16.4 (2 June 2025)
+
+This release includes:
+
+* Fix missing imports in `FluxMPIExt`.
+* Add shape validation for convolution weight tensors.
+* Disable and fix intermittent Reactant tests.
+* Fix recurrent docstrings and pooling layer loading.
+* Small test updates and miscellaneous doc fixes.
+
+## v0.16.3 (6 February 2025)
+
+This release includes:
+
+* **Fix** for `cpu(dataloader)` behavior.
+* Addressed data loading and preprocessing pipeline issues.
+* Resolved “Infinite time of gradient” edge case.
+
+## v0.16.2 (21 January 2025)
+
+This release includes:
+
+* Updated dependencies and bumped to v0.16.1 as a base.
+* **Fixes** around new gradients, precompilation on Julia 1.12, and export issues.
+
+## v0.16.1 (13 January 2025)
+
+This release includes:
+
+* Added references to recurrent layers in `ecosystem.md`.
+* Fixed typo in recurrence documentation.
+* Added “return state” option to recurrent layers.
+* Updated schedulers docs, collapsed docstrings in layers docs.
+* Test fixes for Enzyme and Reactant forward/reverse passes.
+
+## v0.16.0 (15 December 2024)
 
 This release has a single **breaking change**:
 - The recurrent cells `RNNCell`, `LSTMCell`, and `GRUCell` forward has been changed to
diff --git a/Project.toml b/Project.toml
index 23fe826646..0445c4d43f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -2,10 +2,8 @@ name = "Flux"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 version = "0.16.7"
 
-[workspace]
-projects = ["test", "docs"]
-
 [deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -33,7 +31,9 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
@@ -42,10 +42,13 @@ FluxAMDGPUExt = "AMDGPU"
 FluxCUDAExt = "CUDA"
 FluxCUDAcuDNNExt = ["CUDA", "cuDNN"]
 FluxEnzymeExt = "Enzyme"
+FluxFiniteDifferencesExt = "FiniteDifferences"
 FluxMPIExt = "MPI"
 FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
+FluxMooncakeExt = "Mooncake"
 
 [compat]
+ADTypes = "1"
 AMDGPU = "1, 2"
 Adapt = "4"
 CUDA = "5"
@@ -53,12 +56,14 @@ ChainRulesCore = "1.12"
 Compat = "4.10.0"
 Enzyme = "0.13"
 EnzymeCore = "0.7.7, 0.8.4"
+FiniteDifferences = "0.12"
 Functors = "0.5"
 MLCore = "1.0.0"
 MLDataDevices = "1.4.2"
 MLUtils = "0.4"
 MPI = "0.20.19"
 MacroTools = "0.5"
+Mooncake = "0.4"
 NCCL = "0.1.1"
 NNlib = "0.9.22"
 OneHotArrays = "0.2.4"
@@ -72,3 +77,6 @@ Statistics = "1"
 Zygote = "0.6.67, 0.7"
 cuDNN = "1"
 julia = "1.10"
+
+[workspace]
+projects = ["test", "docs"]
diff --git a/docs/Project.toml b/docs/Project.toml
index 0acc886a07..e9675ae73e 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -10,6 +10,7 @@ MLCore = "c2834f40-e789-41da-a90e-33b280584a8c"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
@@ -17,5 +18,8 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
+[sources]
+Flux = {path = ".."}
+
 [compat]
 Documenter = "1.3"
diff --git a/docs/make.jl b/docs/make.jl
index 75f30494eb..f8d7d6cfb0 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -1,6 +1,7 @@
-using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers,
+using Documenter
+using Flux, NNlib, Functors, MLUtils, BSON, Optimisers,
       OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics,
-      DataFrames, JLD2, MLDataDevices, MLCore
+      DataFrames, JLD2, MLDataDevices, MLCore, Mooncake
 
 using MLCore: numobs, getobs, getobs!
 
 ENV["DATADEPS_ALWAYS_ACCEPT"] = true
@@ -21,7 +22,8 @@ makedocs(
         sidebar_sitename = false,
         analytics = "UA-36890222-9",
         assets = ["assets/flux.css"],
-        prettyurls = get(ENV, "CI", nothing) == "true"
+        prettyurls = get(ENV, "CI", nothing) == "true",
+        size_threshold=1_000_000,
     ),
     pages = [
         "Welcome" => "index.md",
@@ -50,8 +52,7 @@
             "Shape Inference" => "reference/outputsize.md",
             "Flat vs. Nested" => "reference/destructure.md",
             "Callback Helpers" => "reference/training/callbacks.md",
-            "Gradients -- Zygote.jl" => "reference/training/zygote.md",
-            "Gradients -- Enzyme.jl" => "reference/training/enzyme.md",
+            "Gradients" => "reference/training/gradients.md",
             "Transfer Data to GPU -- MLDataDevices.jl" => "reference/data/mldatadevices.md",
             "Batching Data -- MLUtils.jl" => "reference/data/mlutils.md",
             "OneHotArrays.jl" => "reference/data/onehot.md",
diff --git a/docs/src/guide/models/basics.md b/docs/src/guide/models/basics.md
index 27a958074c..9750c034d9 100644
--- a/docs/src/guide/models/basics.md
+++ b/docs/src/guide/models/basics.md
@@ -181,9 +181,7 @@ These matching nested structures are at the core of how Flux works.
 
 This method of `gradient` takes a zero-argument function, which only *implicitly* depends on `θ`.
 
-```@raw html
-
-  Zygote.jl
-
-```
+## Automatic Differentiation
 
 Flux's [`gradient`](@ref Flux.gradient) function by default calls a companion packages called [Zygote](https://github.com/FluxML/Zygote.jl).
 Zygote performs source-to-source automatic differentiation, meaning that `gradient(f, x)`
@@ -198,7 +196,7 @@ Flux can also be used with other automatic differentiation (AD) packages.
 It was originally written using [Tracker](https://github.com/FluxML/Tracker.jl), a more traditional operator-overloading approach.
 The future might be [Enzyme](https://github.com/EnzymeAD/Enzyme.jl), and Flux now builds in an easy way to use this instead, turned on by wrapping the model in `Duplicated`.
 (For details, see the [Enzyme page](@ref autodiff-enzyme) in the manual.)
 
-```julia
+```julia-repl
 julia> using Enzyme: Const, Duplicated
 
 julia> grad3e = Flux.gradient((x,p) -> p(x), Const(5.0), Duplicated(poly3s))
@@ -210,13 +208,33 @@ Here, this is because `Const(5.0)` is explicitly constant.
 Below, we will see an example where `nothing` shows up because the model struct has fields containing things other than parameters, such as an activation function.
 (It also adopts the convention that `gradient(f, x, y)` returns a tuple `(∂f/∂x, ∂f/∂y)`, without a "`∂f/∂f`" term for the function. This is why we had to write `gradient(|>, 5, poly4)` above, not just `gradient(poly4, 5)`.)
 
-Finally, the function [`withgradient`](@ref) works the same way, but also returns the value of the function:
+The function [`withgradient`](@ref) works the same way, but also returns the value of the function:
 
 ```jldoctest poly
 julia> Flux.withgradient((x,p) -> p(x), 5.0, poly3s)
 (val = 17.5, grad = (2.0, (θ3 = [1.0, 5.0, 25.0],)))
 ```
 
+One can also directly specify which AD backend to use, by passing an adtype among the supported ones
+(`AutoMooncake`, `AutoEnzyme`, `AutoZygote`, `AutoFiniteDifferences`) as the second argument.
+The corresponding AD package has to be loaded first.
+
+Here is an example using [Mooncake](https://github.com/chalk-lab/Mooncake.jl):
+```jldoctest poly
+julia> using Mooncake
+
+julia> Flux.withgradient((x,p) -> p(x), AutoMooncake(), 5.0, poly3s)
+(val = 17.5, grad = (2.0, Poly3{Vector{Float64}}([1.0, 5.0, 25.0])))
+```
+
+and here is the same example using Enzyme:
+```julia-repl
+julia> using Enzyme
+
+julia> Flux.withgradient((x,p) -> p(x), AutoEnzyme(), 5.0, poly3s)
+(val = 17.5, grad = (2.0, Poly3{Vector{Float64}}([1.0, 5.0, 25.0])))
+```
+
 ## Simple Neural Networks
 
 The polynomial functions above send a number `x` to another a number `y`.
diff --git a/docs/src/reference/data/mlutils.md b/docs/src/reference/data/mlutils.md
index 3d4c56f0d0..12dbd1c00f 100644
--- a/docs/src/reference/data/mlutils.md
+++ b/docs/src/reference/data/mlutils.md
@@ -25,6 +25,7 @@ these functions help create inputs for your models or batch your dataset.
 MLUtils.batch
 MLUtils.batchsize
 MLUtils.batchseq
+MLUtils.batch_sequence
 MLUtils.BatchView
 MLUtils.chunk
 MLUtils.eachobs
diff --git a/docs/src/reference/training/enzyme.md b/docs/src/reference/training/gradients.md
similarity index 66%
rename from docs/src/reference/training/enzyme.md
rename to docs/src/reference/training/gradients.md
index 77875c9da3..bdda5535d9 100644
--- a/docs/src/reference/training/enzyme.md
+++ b/docs/src/reference/training/gradients.md
@@ -1,5 +1,64 @@
+```@meta
+CollapsedDocStrings = true
+```
+
+# Automatic Differentiation in Flux
+
+Flux's `gradient` function uses [Zygote](https://github.com/FluxML/Zygote.jl) by default, and also uses this function within [`train!`](@ref Flux.train!) to differentiate the model.
+Zygote has its own [documentation](https://fluxml.ai/Zygote.jl/dev/), in particular listing some [important limitations](https://fluxml.ai/Zygote.jl/dev/limitations/).
 
-# [Automatic Differentiation using Enzyme.jl](@id autodiff-enzyme)
+Flux also has support for Enzyme.jl, documented [below](@ref autodiff-enzyme), and for Mooncake.jl.
+
+## Generic Gradient Interface
+
+```@docs
+Flux.gradient(f, adtype::AbstractADType, args::Any...)
+Flux.withgradient(f, adtype::AbstractADType, args::Any...)
+```
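+
+For example, with Mooncake loaded, the generic interface looks like this (a brief
+sketch re-using the toy example from the `Flux.gradient` docstring; any other
+supported adtype works the same way):
+
+```julia-repl
+julia> using Flux, Mooncake
+
+julia> f(x) = sum(2 .* x);
+
+julia> Flux.gradient(f, AutoMooncake(), [1.0, 2.0, 3.0])
+([2.0, 2.0, 2.0],)
+```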
+
+## [Automatic Differentiation using Zygote.jl](@id autodiff-zygote)
+
+The default AD backend in Flux is Zygote. Besides gradient calculation, Zygote also supports
+higher-order derivatives, Jacobians, Hessians, and pullbacks.
+
+```@docs
+Zygote.jacobian(f, args...)
+Zygote.withjacobian(f, args...)
+Zygote.hessian
+Zygote.hessian_reverse
+Zygote.diaghessian
+Zygote.pullback
+```
+
+## ChainRules for Zygote
+
+Zygote uses [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) to define how to differentiate functions.
+
+Sometimes it is necessary to exclude some code, or a whole function, from automatic differentiation.
+This can be done using the following methods:
+
+```@docs
+ChainRulesCore.ignore_derivatives
+ChainRulesCore.@non_differentiable
+```
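+
+For instance, a (hypothetical) loss that logs a diagnostic without differentiating
+the logging step could be written as:
+
+```julia
+using Flux, ChainRulesCore
+
+function loss(model, x)
+    y = model(x)
+    ChainRulesCore.ignore_derivatives() do
+        @info "forward pass" size(y)   # executed, but excluded from the gradient
+    end
+    return sum(abs2, y)
+end
+```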
+
+To manually supply the gradient for one function, you should define a method of `rrule`. ChainRules has [detailed documentation](https://juliadiff.org/ChainRulesCore.jl/stable/) on how this works.
+
+```@docs
+ChainRulesCore.rrule
+ChainRulesCore.frule
+ChainRulesCore.@scalar_rule
+ChainRulesCore.NoTangent
+ChainRulesCore.ZeroTangent
+ChainRulesCore.RuleConfig
+ChainRulesCore.Tangent
+ChainRulesCore.canonicalize
+```
+
+Gradient customization for other AD packages such as Enzyme and Mooncake has to be done according to their own documentation.
+
+## [Automatic Differentiation using Enzyme.jl](@id autodiff-enzyme)
 
 [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) is a new package for automatic differentiation.
 Like Zygote.jl, calling `gradient(f, x)` causes it to hooks into the compiler and transform code that is executed while calculating `f(x)`, in order to produce code for `∂f/∂x`.
@@ -71,25 +130,16 @@ true
 
 Note that what `Enzyme.gradient` returns is an object like `deepcopy(model)` of the same type, `grads_e[1] isa Chain`.
 But its fields contain the same gradient.
 
-There is also a method of `train!` which similarly takes `Duplicated(model)`:
-
-```julia-repl
-julia> opt_state = Flux.setup(Adam(0), model);
-
-julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_state)
-```
-
-## Second-order AD
-
-If you calculate a gradient within the loss function, then training will involve 2nd derivatives.
-While this is in principle supported by Zygote.jl, there are many bugs, and Enzyme.jl is probably a better choice.
-
-## Listing
 
 ```@docs
 Flux.gradient(f, args::Union{Flux.EnzymeCore.Const, Flux.EnzymeCore.Duplicated}...)
 Flux.withgradient(f, args::Union{Flux.EnzymeCore.Const, Flux.EnzymeCore.Duplicated}...)
-Flux.train!(loss, model::Flux.EnzymeCore.Duplicated, data, opt)
 ```
 
 Enzyme.jl has [its own extensive documentation](https://enzymead.github.io/Enzyme.jl/stable/).
+
+
+## Second-order AD
+
+If you calculate a gradient within the loss function, then training will involve 2nd derivatives.
+While this is in principle supported by Zygote.jl, there are many bugs, and Enzyme.jl is probably a better choice.
diff --git a/docs/src/reference/training/reference.md b/docs/src/reference/training/reference.md
index a9ee35e4c0..6c825709e5 100644
--- a/docs/src/reference/training/reference.md
+++ b/docs/src/reference/training/reference.md
@@ -28,6 +28,19 @@ Optimisers.setup
 To see one in a terminal, you will need to install [TerminalLoggers.jl](https://github.com/JuliaLogging/TerminalLoggers.jl) and follow its setup instructions.
 
+There is also a method of `train!` which similarly takes `Duplicated(model)` and uses Enzyme.jl for differentiation (see the [Enzyme section](@ref autodiff-enzyme)):
+
+```julia-repl
+julia> opt_state = Flux.setup(Adam(0), model);
+
+julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_state)
+```
+
+```@docs
+Flux.train!(loss, model::Flux.EnzymeCore.Duplicated, data, opt)
+```
+
 ## Optimisation Modifiers
 
 The state returned by `setup` can be modified to temporarily prevent training of
diff --git a/docs/src/reference/training/zygote.md b/docs/src/reference/training/zygote.md
deleted file mode 100644
index ddf65917a1..0000000000
--- a/docs/src/reference/training/zygote.md
+++ /dev/null
@@ -1,48 +0,0 @@
-```@meta
-CollapsedDocStrings = true
-```
-
-# [Automatic Differentiation using Zygote.jl](@id autodiff-zygote)
-
-Flux's `gradient` function uses [Zygote](https://github.com/FluxML/Zygote.jl) by default, and also uses this function within [`train!`](@ref Flux.train!) to differentiate the model.
-Zygote has its own [documentation](https://fluxml.ai/Zygote.jl/dev/), in particular listing some [important limitations](https://fluxml.ai/Zygote.jl/dev/limitations/).
-
-Flux also has support for Enzyme.jl, documented [on its own page](@ref autodiff-enzyme).
-
-## Explicit style
-
-The preferred way of using Zygote, and the only way of using most other AD packages,
-is to explicitly provide a function and its arguments.
-
-```@docs
-Zygote.gradient(f, args...)
-Zygote.withgradient(f, args...)
-Zygote.jacobian(f, args...)
-Zygote.withjacobian(f, args...)
-Zygote.hessian
-Zygote.hessian_reverse
-Zygote.diaghessian
-Zygote.pullback
-```
-
-## ChainRules
-
-Sometimes it is necessary to exclude some code, or a whole function, from automatic differentiation. This can be done using [ChainRules](https://github.com/JuliaDiff/ChainRules.jl):
-
-```@docs
-ChainRulesCore.ignore_derivatives
-ChainRulesCore.@non_differentiable
-```
-
-To manually supply the gradient for one function, you should define a method of `rrule`. ChainRules has [detailed documentation](https://juliadiff.org/ChainRulesCore.jl/stable/) on how this works.
-
-```@docs
-ChainRulesCore.rrule
-ChainRulesCore.frule
-ChainRulesCore.@scalar_rule
-ChainRulesCore.NoTangent
-ChainRulesCore.ZeroTangent
-ChainRulesCore.RuleConfig
-ChainRulesCore.Tangent
-ChainRulesCore.canonicalize
-```
diff --git a/ext/FluxEnzymeExt.jl b/ext/FluxEnzymeExt.jl
new file mode 100644
index 0000000000..0969c854d8
--- /dev/null
+++ b/ext/FluxEnzymeExt.jl
@@ -0,0 +1,115 @@
+module FluxEnzymeExt
+
+using Flux
+import Flux.Train: _enzyme_train!
+
+import Optimisers
+import Functors
+import Enzyme
+using Enzyme: EnzymeCore, EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal, DuplicatedNoNeed
+using Enzyme: autodiff_thunk, Reverse, ReverseSplitWithPrimal
+using ProgressLogging: @withprogress, @logprogress
+
+EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true
+
+### gradient & withgradient
+
+function Flux.gradient(f::F, adtype::AutoEnzyme, x::Vararg{Any,N}; zero::Bool=true) where {F,N}
+    return _enzyme_gradient(f, map(_trymake_duplicated, x)...; zero)
+end
+
+function Flux.withgradient(f::F, adtype::AutoEnzyme, x::Vararg{Any,N}; zero::Bool=true) where {F,N}
+    return _enzyme_withgradient(f, map(_trymake_duplicated, x)...; zero)
+end
+
+_trymake_duplicated(x::EnzymeCore.Duplicated) = x
+_trymake_duplicated(x::EnzymeCore.Const) = x
+_trymake_duplicated(x::EnzymeCore.Active) = throw(ArgumentError("Enzyme's `Active` type is not supported in `Flux.gradient` or `Flux.withgradient`."))
+_trymake_duplicated(x) = EnzymeCore.Duplicated(x, EnzymeCore.make_zero(x))
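+
+# Plain (unwrapped) arguments are promoted to `Duplicated` above, so they are
+# differentiated too; callers wrap an argument in `Const` to exclude it.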
+
+function _enzyme_gradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
+    for x in args
+        zero && x isa Duplicated && EnzymeCore.remake_zero!(x.dval)
+        _check_mutable(x)
+    end
+    ad = Enzyme.set_runtime_activity(Reverse)
+    Enzyme.autodiff(ad, Const(f), Active, args...)
+    return map(_grad_or_nothing, args)
+end
+
+_check_mutable(x::Const) = nothing
+_check_mutable(x::Duplicated) = Functors.anymutable(x) || error(
+    """`Flux.gradient(f, Duplicated(x), ...)` expects `x` to contain mutable parameter arrays."""
+)
+
+# This function strips the returned gradient to be Zygote-like:
+_grad_or_nothing(dup::Duplicated) = Flux.fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
+_grad_or_nothing(::Const) = nothing
+_grad_or_nothing(x) = Optimisers.isnumeric(x) ? x : nothing
+
+function _enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
+    for x in args
+        zero && x isa Duplicated && EnzymeCore.remake_zero!(x.dval)
+        _check_mutable(x)
+    end
+
+    # In order to support auxiliary outputs, we try different ways.
+
+    ## Take I, doesn't allow for aux at all.
+    ad = Enzyme.set_runtime_activity(ReverseWithPrimal)
+    _, result = Enzyme.autodiff(ad, Const(f), Active, args...)
+
+    ## Take II, using split mode.
+    ## This fails with RNNs https://github.com/EnzymeAD/Enzyme.jl/issues/2897
+    # forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
+    # tape, result, shadow_result = forward(Const(f), args...)
+    # reverse(Const(f), args..., _sensitivity(result), tape)
+
+    ## Take III, it may be more efficient to have the function write the loss into Ref(0.0)?
+    ## This doesn't work with Reactant
+    # dup_loss = DuplicatedNoNeed(Ref(0f0), Ref(1f0))
+    # ad = Enzyme.set_runtime_activity(ReverseWithPrimal)
+    # _, result = autodiff(ad, Const(_ref_loss!), Const, dup_loss, Const(f), args...)
+
+    return (; val = result, grad = map(_grad_or_nothing, args))
+end
+
+## for Take II above
+# @inline _sensitivity(y::Real) = one(y)
+# @inline _sensitivity(ys::Tuple{Real,Vararg}) = (one(ys[1]), Enzyme.make_zero(Base.tail(ys))...)
+# @inline _sensitivity(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = NamedTuple{S}(_sensitivity(Tuple(ys)))
+# _sensitivity(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
+#                            or else a Tuple or NamedTuple whose first element is a real number.""")
+
+# for Take III above
+# function _ref_loss!(out::Ref, f, args...)
+#     val = f(args...)
+#     out[] = _get_loss(val)  # saves loss by mutation
+#     val  # returns the whole thing
+# end
+# @inline _get_loss(y::Number) = y
+# @inline _get_loss(ys::Tuple{Number,Vararg}) = ys[1]
+# @inline _get_loss(ys::NamedTuple{S, <:Tuple{Number,Vararg}}) where S = ys[1]
+# _get_loss(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
+#                         or else a Tuple or NamedTuple whose first element is a real number.""")
+
+### Flux.Train, for train!
+
+function _enzyme_train!(loss, model::Duplicated, data, opt; cb = nothing)
+    isnothing(cb) || error("""train! does not support callback functions.
+                              For more control use a loop with `gradient` and `update!`.""")
+    @withprogress for (i,d) in enumerate(data)
+        d_splat = d isa Tuple ? d : (d,)
+        l, gs = Flux.withgradient(loss, AutoEnzyme(), model, map(Const, d_splat)...)
+        if !isfinite(l)
+            throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
+        end
+        opt, model2 = Optimisers.update!(opt, model.val, model.dval)
+        model = Duplicated(model2, model.dval)
+
+        @logprogress Base.haslength(data) ? i/length(data) : nothing
+    end
+end
+
+end # FluxEnzymeExt
diff --git a/ext/FluxEnzymeExt/FluxEnzymeExt.jl b/ext/FluxEnzymeExt/FluxEnzymeExt.jl
deleted file mode 100644
index b124a4a973..0000000000
--- a/ext/FluxEnzymeExt/FluxEnzymeExt.jl
+++ /dev/null
@@ -1,129 +0,0 @@
-module FluxEnzymeExt
-
-using Flux
-import Flux.Train: _enzyme_train!
-
-import Optimisers
-import Functors
-import Enzyme
-using Enzyme: EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal, DuplicatedNoNeed
-using Enzyme: autodiff_thunk, Reverse, ReverseSplitWithPrimal
-using ProgressLogging: @withprogress, @logprogress
-
-EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true
-
-### gradient & withgradient
-
-# We can't use Enzyme.make_zero! to reset Duplicated, as it complains about e.g. LayerNorm having immutable differentiable fields
-# After https://github.com/EnzymeAD/Enzyme.jl/pull/1961 probably this can be `make_zero!(Ref(dup.dval))`
-_make_zero!(model) = Functors.fmapstructure(_make_zero_inner!, model)
-function _make_zero_inner!(x::AbstractArray{<:Number})
-    Optimisers.isnumeric(x) || return
-    Optimisers.maywrite(x) || error("can't handle this")
-    fill!(x, zero(eltype(x)))
-    nothing
-end
-_make_zero_inner!(x) = nothing  # any other Functors leaf type
-
-#= # This _make_zero! matches what Flux allows elsewhere:
-julia> Flux.setup(Adam(), (1:3.)')
-ERROR: model must be fully mutable for `train!` to work, got `x::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}`.
-If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}) = true`
-=#
-# Perhaps canonical way for Enzyme is more like this:
-# function _make_zero!(x::AbstractArray{<:Number})
-#     if Enzyme.guess_activity(typeof(x), Reverse) <: Duplicated
-#         fill!(x, zero(eltype(x)))
-#     elseif Enzyme.guess_activity(typeof(x), Reverse) <: Const
-#         # that's OK
-#     else
-#         error("not sure what it should do for Active?")
-#     end
-# end
-
-function Flux._enzyme_gradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
-    for x in args
-        zero && x isa Duplicated && _make_zero!(x.dval)
-        _check_mutable(x)
-    end
-    ad = Enzyme.set_runtime_activity(Reverse)
-    Enzyme.autodiff(ad, Const(f), Active, args...)
-    map(_grad_or_nothing, args)
-end
-
-_check_mutable(x::Const) = nothing
-_check_mutable(x::Duplicated) = Functors.anymutable(x) || error(
-    """`Flux.gradient(f, Duplicatged(x), ...)` expects `x` to contain mutable parameter arrays."""
-)
-
-# This function strips the returned gradient to be Zygote-like:
-_grad_or_nothing(dup::Duplicated) = Flux.fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
-_grad_or_nothing(::Const) = nothing
-_grad_or_nothing(x) = Optimisers.isnumeric(x) ? x : nothing
-
-function Flux._enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
-    for x in args
-        zero && x isa Duplicated && _make_zero!(x.dval)
-        _check_mutable(x)
-    end
-
-    # Take I, doesn't allow for aux at all.
-    # _, val = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
-
-    # Take II, using split mode.
-    forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
-    tape, result, shadow_result = forward(Const(f), args...)
-    reverse(Const(f), args..., _sensitivity(result), tape)
-
-    # Take III, it may be more efficient to have the function write the loss into Ref(0.0)?
-    # dup_loss = DuplicatedNoNeed(Ref(0f0), Ref(1f0))
-    # # result = autodiff(Reverse, Const(_ref_loss!), Const, dup_loss, Const(f), args...)
-    # _, result = autodiff(ReverseWithPrimal, Const(_ref_loss!), Const, dup_loss, Const(f), args...)
-
-    (; val = result, grad = map(_grad_or_nothing, args))
-end
-
-@inline _sensitivity(y::Real) = one(y)
-@inline _sensitivity(ys::Tuple{Real,Vararg}) = (one(ys[1]), Enzyme.make_zero(Base.tail(ys))...)
-@inline _sensitivity(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = NamedTuple{S}(_sensitivity(Tuple(ys)))
-_sensitivity(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real numnber,
-    or else a Tuple or NamedTuple whose first element is a real number.""")
-
-function _ref_loss!(out::Ref, f, args...)  # for Take III above
-    val = f(args...)
-    out[] = _get_loss(val)  # saves loss by mutation
-    val  # returns the whole thing
-end
-
-@inline _get_loss(y::Real) = y
-@inline _get_loss(ys::Tuple{Real,Vararg}) = ys[1]
-@inline _get_loss(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = ys[1]
-_get_loss(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real numnber,
-    or else a Tuple or NamedTuple whose first element is a real number.""")
-
-### Flux.Train, for train!
-
-_applyloss(loss, model, d...) = loss(model, d...)
-
-function _enzyme_train!(loss, model::Duplicated, data, opt; cb = nothing)
-    isnothing(cb) || error("""train! does not support callback functions.
-                              For more control use a loop with `gradient` and `update!`.""")
-    @withprogress for (i,d) in enumerate(data)
-        d_splat = d isa Tuple ? d : (d,)
-
-        _make_zero!(model.dval)
-        ad = Enzyme.set_runtime_activity(ReverseWithPrimal)
-        _, l = Enzyme.autodiff(ad, _applyloss,
-                               Active, Const(loss), model, map(Const, d_splat)...)
-
-        if !isfinite(l)
-            throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
-        end
-        opt, model2 = Optimisers.update!(opt, model.val, model.dval)
-        model = Duplicated(model2, model.dval)
-
-        @logprogress Base.haslength(data) ? i/length(data) : nothing
-    end
-end
-
-end # FluxEnzymeExt
diff --git a/ext/FluxFiniteDifferencesExt.jl b/ext/FluxFiniteDifferencesExt.jl
new file mode 100644
index 0000000000..836f401201
--- /dev/null
+++ b/ext/FluxFiniteDifferencesExt.jl
@@ -0,0 +1,33 @@
+module FluxFiniteDifferencesExt
+
+using Flux
+using ADTypes: AutoFiniteDifferences
+using FiniteDifferences
+
+function Flux.gradient(f::F, adtype::AutoFiniteDifferences, x) where F
+    ps, re = Flux.destructure(x)
+    gs = FiniteDifferences.grad(adtype.fdm, p -> f(re(p)), ps)[1]
+    return (re(gs),)
+end
+
+function Flux.gradient(f::F, adtype::AutoFiniteDifferences, x::Vararg{Any,N}) where {F, N}
+    ps, re = Flux.destructure(x)
+    gs = FiniteDifferences.grad(adtype.fdm, p -> f(re(p)...), ps)[1]
+    return re(gs)
+end
+
+function Flux.withgradient(f::F, adtype::AutoFiniteDifferences, x) where F
+    ps, re = Flux.destructure(x)
+    y = f(re(ps))
+    gs = FiniteDifferences.grad(adtype.fdm, p -> f(re(p)), ps)[1]
+    return (; val = y, grad = (re(gs),))
+end
+
+function Flux.withgradient(f::F, adtype::AutoFiniteDifferences, x::Vararg{Any,N}) where {F, N}
+    ps, re = Flux.destructure(x)
+    y = f(re(ps)...)
+    gs = FiniteDifferences.grad(adtype.fdm, p -> f(re(p)...), ps)[1]
+    return (; val = y, grad = re(gs))
+end
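+
+# A quick usage sketch (hypothetical REPL session; `central_fdm` constructs a
+# finite-difference method in FiniteDifferences.jl):
+#
+#   julia> using Flux, FiniteDifferences
+#
+#   julia> Flux.gradient(x -> sum(abs2, x), AutoFiniteDifferences(; fdm=central_fdm(5, 1)), [1.0, 2.0])
+#   ([2.0, 4.0],)     # matches the exact gradient up to finite-difference error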
+
+end # module
diff --git a/ext/FluxMooncakeExt.jl b/ext/FluxMooncakeExt.jl
new file mode 100644
index 0000000000..d371bfb952
--- /dev/null
+++ b/ext/FluxMooncakeExt.jl
@@ -0,0 +1,17 @@
+module FluxMooncakeExt
+
+using ADTypes: AutoMooncake
+using Mooncake: Mooncake
+import Flux
+
+function Flux.gradient(f::F, adtype::AutoMooncake, args::Vararg{Any,N}) where {F,N}
+    return Flux.withgradient(f, adtype, args...)[2]
+end
+
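+# `Mooncake.value_and_gradient!!` also returns the gradient with respect to `f`
+# itself, in the first slot; `grads[2:end]` drops it so the result matches Flux's
+# convention of one gradient per argument (no `∂f/∂f` entry).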
+function Flux.withgradient(f::F, adtype::AutoMooncake, args::Vararg{Any,N}) where {F,N}
+    cache = Mooncake.prepare_gradient_cache(f, args...; friendly_tangents=true)
+    val, grads = Mooncake.value_and_gradient!!(cache, f, args...)
+    return (val=val, grad=grads[2:end])
+end
+
+end # module
diff --git a/src/Flux.jl b/src/Flux.jl
index 1fb4f0ef6a..657d9b9a1c 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -25,6 +25,8 @@ using Zygote: @adjoint, pullback
 using Zygote.ForwardDiff: value
 using EnzymeCore: EnzymeCore
 
+@reexport using ADTypes  # AutoZygote, AutoMooncake, etc...
+
 @reexport using MLDataDevices: MLDataDevices, supported_gpu_backends, reset_gpu_device!,
                                default_device_rng,
                                gpu_device, cpu_device, xla_device,
@@ -64,8 +66,6 @@ export Chain, Dense, Embedding, EmbeddingBag,
   freeze!, thaw!, adjust!, update!, trainable,
   # from Zygote.jl
   hessian, diaghessian, jacobian, withjacobian, pullback,
-  # AD functions
-  withgradient,
   # init
   glorot_uniform,
   glorot_normal,
@@ -99,7 +99,7 @@
 ))
 
 include("gradient.jl")
-export gradient
+export gradient, withgradient
 
 include("train.jl")
 using .Train
diff --git a/src/gradient.jl b/src/gradient.jl
index aa88828545..8ee4d1eebf 100644
--- a/src/gradient.jl
+++ b/src/gradient.jl
@@ -1,6 +1,7 @@
+const SUPPORTED_AD_BACKENDS = (:Zygote, :Enzyme, :Mooncake, :FiniteDifferences)
 
 """
-    gradient(f, args...)
+    gradient(f, [adtype,] args...)
 
 Returns a tuple containing `∂f/∂x` for each argument `x`,
 the derivative (for scalar `x`) or the gradient.
@@ -8,13 +9,21 @@ If no gradient is defined, `∂f/∂x` will be `nothing`.
 
 `f(args...)` must be a real number, see [`Zygote.jacobian`](@ref) for array output.
 
-By default, `Flux.gradient` calls Zygote. If you load Enzyme, then other methods become available.
+The optional argument `adtype` allows specifying the automatic differentiation backend.
+
+We provide specific support and testing for the following backends:
+`AutoZygote`, `AutoEnzyme`, `AutoMooncake`, and `AutoFiniteDifferences`.
+
+The package corresponding to any chosen backend (except Zygote) must be loaded in advance.
+
+If no `adtype` is given, then Zygote.jl is used by default, unless at least one argument
+is of type `Duplicated` from Enzyme.jl, in which case Enzyme.jl is used.
 
 See also [`withgradient`](@ref) to keep the value `f(args...)`.
 
 # Examples
 
-```
+```julia-repl
 julia> Flux.gradient(*, 2.0, 3.0, 5.0)
 (15.0, 10.0, 6.0)
 
 julia> Flux.gradient([7, 11], 0, 1) do x, y, d
          p = size(x, d)
         sum(x .^ p .+ y)
        end
 ([14.0, 22.0], 2.0, nothing)
 ```
+
+Specifying other AD backends:
+
+```julia-repl
+julia> using Mooncake
+
+julia> f(x) = sum(2 .* x)
+f (generic function with 1 method)
+
+julia> Flux.gradient(f, AutoMooncake(), [1.0, 2.0, 3.0])
+([2.0, 2.0, 2.0],)
+```
 """
+function gradient(f, adtype::ADTypes.AbstractADType, args...)
+    error("AD backend has to be loaded to use `gradient(f, AutoXXX(), args...)`.
+           Make sure to `using` the corresponding package, e.g. `using Mooncake` for `AutoMooncake()`.
+           Supported backends are $SUPPORTED_AD_BACKENDS.")
+end
+
+# Default gradient using Zygote
 function gradient(f, args...; zero::Bool=true)
   for a in args
-    a isa EnzymeCore.Duplicated && return _enzyme_gradient(f, map(_ensure_enzyme, args)...; zero)
+    a isa Union{EnzymeCore.Duplicated, EnzymeCore.Const} && return gradient(f, AutoEnzyme(), args...; zero)
   end
   for a in args
     _ensure_noenzyme(a)
   end
@@ -41,17 +69,9 @@ function gradient(f, args...; zero::Bool=true)
       If you are writing new code, then Zygote over Zygote is heavily discouraged.
       """)
   end
-  Zygote.gradient(f, args...)
+  return Zygote.gradient(f, args...)
 end
 
-# Given one Duplicated, we wrap everything else in Const before calling Enzyme
-_ensure_enzyme(x::EnzymeCore.Duplicated) = x
-_ensure_enzyme(x::EnzymeCore.Const) = x
-_ensure_enzyme(x) = EnzymeCore.Const(x)
-_ensure_enzyme(x::EnzymeCore.Active) = throw(ArgumentError(
-    "The method `gradient(f, xs...)` using Enzyme.jl does not support `Active`, only `Duplicated` and ``Const`."
-))
-
 # Without any Duplicated, check for no stray Enzyme types before calling Zygote
 _ensure_noenzyme(::EnzymeCore.Const) = throw(ArgumentError(
     "The method `gradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`."
 ))
 _ensure_noenzyme(::EnzymeCore.Active) = throw(ArgumentError(
@@ -62,7 +82,7 @@ _ensure_noenzyme(_) = nothing
 
 """
-    gradient(f, args::Union{Const,Duplicated}...)
+    gradient(f, args::Union{Any,EnzymeCore.Duplicated}...)
 
 This should return the same answer as `gradient(f, args...)`,
 but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.
 
 Only available when Enzyme is loaded!
 
 This method is used when at least one argument is of type `Duplicated`,
-and all unspecified aguments are wrapped in `Const`.
+and all non-duplicated arguments are treated as `Const`.
 Note that Enzyme's `Active` is not supported.
 
 Besides returning the gradient, this is also stored within the `Duplicated` object.
 Calling `Enzyme.Duplicated(model)` allocates space for the gradient,
 which is zero'd befor use when calling `gradient`.
 With the keyword `zero=false`, the new gradient will instead be added to what is already stored.
 
-!!! warning "Experimental"
-    Enzyme support like this is new and somewhat experimental.
-    This method was added in Flux 0.15.
+# Examples
 
-# Example
-```
+```julia-repl
 julia> using Flux
 
 julia> model = Chain(Dense([3.0;;]));
 
@@ -119,28 +136,26 @@ julia> Flux.gradient(dup_model, [1]; zero=false) do m, x  # implict Const([1]),
 ((layers = ((weight = [12.0;;], bias = [12.0], σ = nothing),),), nothing)
 ```
 """
-gradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...; zero::Bool=true) = _enzyme_gradient(f, args...; zero)
-
-gradient(f, args::EnzymeCore.Const...; zero::Bool=true) = throw(ArgumentError(
-    "The method `gradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`."
-))
-
-# FluxEnzymeExt defines more specific _enzyme_gradient(f, args::Union{Const, Duplicated}...; zero)
-_enzyme_gradient(f, args...; zero) = throw(ArgumentError(
-    "Methods like `gradient(f, x::Duplicated)` are only available when Enzyme is loaded."
-))
+gradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...; zero::Bool=true) = gradient(f, AutoEnzyme(), args...; zero)
 
 """
-    withgradient(f, args...)
+    withgradient(f, [adtype,] args...)
 
 Returns both the value of the function and the [`gradient`](@ref), as a named tuple.
 
-By default, `Flux.withgradient` calls Zygote. If you load Enzyme, then other methods become available.
+The optional argument `adtype` allows specifying the automatic differentiation backend
+among the supported ones: `AutoZygote`, `AutoEnzyme`, `AutoMooncake`, and `AutoFiniteDifferences`.
+The package corresponding to the chosen backend must be loaded in advance.
 
-# Example
+If no `adtype` is given, then Zygote.jl is used by default, unless at least one argument
+is of type `Duplicated` from Enzyme.jl, in which case Enzyme.jl is used.
 
-```
+See also [`gradient`](@ref) to get just the gradient.
+
+# Examples
+
+```jldoctest
 julia> y, ∇ = withgradient(/, 1, 2)
 (val = 0.5, grad = (0.5, -0.25))
 
 julia> ∇ == gradient(/, 1, 2)
 true
 ```
 
-Allows you to capture auxillary outputs, in addition to the scalar
+`withgradient` allows you to capture auxiliary outputs, in addition to the scalar
 used by `gradient`. To do this, `f` must return a Tuple or NamedTuple. Then
 it calculates `grad = gradient(first∘f, args...) but returns the whole `val = f(args...)`:
 
-```jldoctest; setup=:(using Zygote)
+```jldoctest
 julia> withgradient([1,2,4]) do x
          z = 1 ./ x
         sum(z), z  # here z is an auxillary output
        end
 (val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],))
 
 julia> withgradient(3.0, 4.0) do x, y
         (div = x/y, mul = x*y)
        end
 (val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875))
 ```
+
+Different AD backends can be specified:
+```julia-repl
+julia> using Mooncake
+
+julia> f(x) = sum(2 .* x)
+f (generic function with 1 method)
+
+julia> Flux.withgradient(f, AutoMooncake(), [1.0, 2.0, 3.0])
+(val = 12.0, grad = ([2.0, 2.0, 2.0],))
+```
 """
+function withgradient(f, adtype::ADTypes.AbstractADType, args...)
+    error("AD backend has to be loaded to use `withgradient(f, AutoXXX(), args...)`.
+           Make sure to `using` the corresponding package, e.g. `using Mooncake` for `AutoMooncake()`.
+           Supported backends are $SUPPORTED_AD_BACKENDS.")
+end
+
+# Default withgradient using Zygote
 function withgradient(f, args...; zero::Bool=true)
   for a in args
-    a isa EnzymeCore.Duplicated && return _enzyme_withgradient(f, map(_ensure_enzyme, args)...; zero)
+    a isa Union{EnzymeCore.Duplicated, EnzymeCore.Const} && return withgradient(f, AutoEnzyme(), args...; zero)
   end
   for a in args
     _ensure_noenzyme(a)
   end
@@ -179,22 +213,28 @@ function withgradient(f, args...; zero::Bool=true)
       If you are writing new code, then Zygote over Zygote is heavily discouraged.
       """)
   end
-  Zygote.withgradient(f, args...)
+  return Zygote.withgradient(f, args...)
+end
+
+## Zygote version, supporting aux output too.
+function withgradient(f::F, adtype::AutoZygote, x::Vararg{Any,N}) where {F,N}
+    return Zygote.withgradient(f, x...)
 end
 
 """
-    withgradient(f, args::Union{Const,Duplicated}...)
+    withgradient(f, args::Union{Any,EnzymeCore.Duplicated}...)
 
 This should return the same answer as `withgradient(f, model, args...)`,
 but it uses Enzyme.jl instead of Zygote.jl to compute the derivative.
 
 Only available when Enzyme is loaded!
 
-!!! warning "Experimental"
-    Enzyme support like this is new and somewhat experimental.
-    This method was added in Flux 0.15.
+This method is used when at least one argument is of type `Duplicated`.
+All non-duplicated arguments will be differentiated as well.
+Mark them as `Const` to avoid this.
+Note that Enzyme's `Active` is not supported.
 
-# Example
+# Examples
 
 ```julia-repl
 julia> using Flux, Enzyme
 
 julia> model = Chain(Embedding([1.1 2.2 3.3]), Dense([4.4;;]), only);
 
 julia> model(3)
 14.52
 
 julia> Flux.withgradient(m -> m(3), model)  # this uses Zygote
 (val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
 
 julia> Flux.withgradient(m -> m(3), Duplicated(model))  # this uses Enzyme
 (val = 14.52, grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
 ```
 """
-withgradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...; zero::Bool=true) = _enzyme_withgradient(f, args...; zero)
+withgradient(f, args::Union{EnzymeCore.Const, EnzymeCore.Duplicated}...; zero::Bool=true) = withgradient(f, AutoEnzyme(), args...; zero)
 
-withgradient(f, args::EnzymeCore.Const...; zero::Bool=true) = throw(ArgumentError(
-    "The method `withgradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`."
-))
+## ADD BACK TO withgradient docstring above when AUX is SUPPORTED
+# The function `f` may return Tuple or NamedTuple, with the loss as the first element.
+# The gradient is then `grad = gradient(first∘f, args...)`
+# but the returned value is `val = f(args...)`:
 
-# FluxEnzymeExt defines more specific _enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero)
-_enzyme_withgradient(f, args...; zero) = throw(ArgumentError(
-    "Methods like `withgradient(f, x::Duplicated)` are only available when Enzyme is loaded."
-))
+# ```julia-repl
+# julia> Flux.withgradient(m -> (m(3), "aux"), Duplicated(model))
+# (val = (14.52, "aux"), grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
+
+# julia> Flux.withgradient(m -> (loss=m(3), aux=round.(m.(1:3); digits=3)), Duplicated(model))
+# (val = (loss = 14.52, aux = [4.84, 9.68, 14.52]), grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),))
+# ```
diff --git a/test/Project.toml b/test/Project.toml
index cffbc43306..83882844d5 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -2,6 +2,7 @@ BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -12,6 +13,7 @@ GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -21,9 +23,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
+[sources]
+Flux = {path = ".."}
+
 [compat]
 FiniteDifferences = "0.12"
 Tracker = "0.2.33"
-
-[sources.Flux]
-path = ".."
diff --git a/test/ext_amdgpu/basic.jl b/test/ext_amdgpu/basic.jl
index 8962c7bedb..6926213af4 100644
--- a/test/ext_amdgpu/basic.jl
+++ b/test/ext_amdgpu/basic.jl
@@ -19,7 +19,7 @@ end
 @testset "Chain of Dense layers" begin
     m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
     x = rand(Float32, 10, 10)
-    test_gradients(m, x, test_gpu=true, compare_finite_diff=false)
+    test_gradients(m, x, test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
 end
 
 @testset "Convolution" begin
@@ -75,14 +75,14 @@ end
 @testset "Chain(Conv)" begin
     m = Chain(Conv((3, 3), 3 => 3))
     x = rand(Float32, 5, 5, 3, 2)
-    test_gradients(m, x, test_gpu=true, compare_finite_diff=false, test_grad_f=false)
+    test_gradients(m, x, test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
 
     md = m |> gpu |> cpu
     @test md[1].weight ≈ m[1].weight atol=1f-3
 
     m = Chain(ConvTranspose((3, 3), 3 => 3))
     x = rand(Float32, 5, 5, 3, 2)
-    test_gradients(m, x, test_gpu=true, compare_finite_diff=false, test_grad_f=false)
+    test_gradients(m, x, test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
 
     md = m |> gpu |> cpu
     @test md[1].weight ≈ m[1].weight atol=1f-3
@@ -91,7 +91,7 @@ end
 @testset "Cross-correlation" begin
     m = CrossCor((2, 2), 3 => 4)
     x = rand(Float32, 5, 5, 3, 2)
-    test_gradients(m, x, test_gpu=true, compare_finite_diff=false)
+    test_gradients(m, x, test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
 end
 
 @testset "Restructure" begin
@@ -131,7 +131,7 @@ end
     bn = BatchNorm(3, σ)
     for nd in 1:3
         x = rand(Float32, fill(2, nd - 1)..., 3, 4)
-        test_gradients(bn, x; test_gpu=true, compare_finite_diff=false)
+        test_gradients(bn, x; test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
     end
 end
diff --git a/test/ext_common/recurrent_gpu_ad.jl b/test/ext_common/recurrent_gpu_ad.jl
index 35fa983806..cdbd67c33e 100644
--- a/test/ext_common/recurrent_gpu_ad.jl
+++ b/test/ext_common/recurrent_gpu_ad.jl
@@ -15,12 +15,12 @@ end
     x = [randn(Float32, d_in, batch_size) for _ in 1:len]
     h = zeros(Float32, d_out)
     # Single Step
-    @test test_gradients(r, x[1], h; test_gpu=true,
-                         compare_finite_diff=false,
+    @test test_gradients(r, x[1], h; test_gpu=true, test_cpu=false,
+                         reference=AutoZygote(), compare=nothing,
                          loss=cell_loss) broken = :rnncell_single ∈ BROKEN_TESTS
     # Multiple Steps
-    @test test_gradients(r, x, h; test_gpu=true,
-                         compare_finite_diff=false,
+    @test test_gradients(r, x, h; test_gpu=true, test_cpu=false,
+                         reference=AutoZygote(), compare=nothing,
                          loss=recurrent_cell_loss) broken = :rnncell_multiple ∈ BROKEN_TESTS
 end
 
@@ -37,9 +37,11 @@ end
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     model = ModelRNN(RNN(d_in => d_out), zeros(Float32, d_out))
     x_nobatch = randn(Float32, d_in, len)
-    @test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false) broken = :rnn_nobatch ∈ BROKEN_TESTS
+    @test test_gradients(model, x_nobatch; test_gpu=true, test_cpu=false,
+                         reference=AutoZygote(), compare=nothing) broken = :rnn_nobatch ∈ BROKEN_TESTS
     x = randn(Float32, d_in, batch_size)
-    @test test_gradients(model, x, test_gpu=true, compare_finite_diff=false) broken = :rnn ∈ BROKEN_TESTS
+    @test test_gradients(model, x, test_gpu=true, test_cpu=false,
+                         reference=AutoZygote(), compare=nothing) broken = :rnn ∈ BROKEN_TESTS
 end
 
 @testset "LSTMCell" begin
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     x = [randn(Float32, d_in, batch_size) for _ in 1:len]
     h = zeros(Float32, d_out)
     c = zeros(Float32, d_out)
     # Single Step
-    @test test_gradients(cell, x[1], (h, c); test_gpu=true, compare_finite_diff=false,
+    @test test_gradients(cell, x[1], (h, c); test_gpu=true, test_cpu=false,
+                         reference=AutoZygote(), compare=nothing,
                          loss = cell_loss) broken = :lstmcell_single ∈ BROKEN_TESTS
     # Multiple Steps
-    @test test_gradients(cell, x, (h, c); test_gpu=true,
-                         compare_finite_diff = false,
+    @test test_gradients(cell, x, (h, c); test_gpu=true, test_cpu=false,
+                         reference=AutoZygote(), compare=nothing,
                          loss = recurrent_cell_loss) broken = :lstmcell_multiple ∈ BROKEN_TESTS
 end
 
@@ -71,11 +74,11 @@ end
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     model = ModelLSTM(LSTM(d_in => d_out), zeros(Float32, d_out), zeros(Float32, d_out))
     x_nobatch = randn(Float32, d_in, len)
-    @test test_gradients(model, x_nobatch; test_gpu=true,
-                         compare_finite_diff=false) broken = :lstm_nobatch ∈ BROKEN_TESTS
+    @test test_gradients(model, x_nobatch; test_gpu=true, test_cpu=false,
+                         reference=AutoZygote(), compare=nothing) broken = :lstm_nobatch ∈ BROKEN_TESTS
     x = randn(Float32, d_in, len, batch_size)
-    @test test_gradients(model, x; test_gpu=true,
-                         compare_finite_diff=false) broken = :lstm ∈ BROKEN_TESTS
+    @test test_gradients(model, x; test_gpu=true, test_cpu=false,
+                         reference=AutoZygote(), compare=nothing) broken = :lstm ∈ BROKEN_TESTS
 end
 
 @testset "GRUCell" begin
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     r = GRUCell(d_in => d_out)
     x = [randn(Float32, d_in, batch_size) for _ in 1:len]
     h = zeros(Float32, d_out)
-    @test test_gradients(r, x[1], h; test_gpu=true,
-                         compare_finite_diff=false,
+    @test test_gradients(r, x[1], h; test_gpu=true, test_cpu=false,
+                         reference=AutoZygote(), compare=nothing,
                          loss = cell_loss) broken = :grucell_single ∈ BROKEN_TESTS
-    @test test_gradients(r, x, h; test_gpu=true,
-                         compare_finite_diff = false,
+    @test test_gradients(r, x, h; test_gpu=true, test_cpu=false,
+                         reference=AutoZygote(), compare=nothing,
                          loss = recurrent_cell_loss) broken = :grucell_multiple ∈ BROKEN_TESTS
 end
 
@@ -104,11 +107,11 @@ end
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     model = ModelGRU(GRU(d_in => d_out), zeros(Float32, d_out))
     x_nobatch = randn(Float32, d_in, len)
-    @test test_gradients(model, x_nobatch; test_gpu=true,
-                         compare_finite_diff=false) broken = :gru_nobatch ∈ BROKEN_TESTS
+    @test test_gradients(model, x_nobatch; test_gpu=true, test_cpu=false,
+                         reference=AutoZygote(), compare=nothing) broken = :gru_nobatch ∈ BROKEN_TESTS
     x = randn(Float32, d_in, len, batch_size)
-    @test test_gradients(model, x; test_gpu=true,
-                         compare_finite_diff=false) broken = :gru ∈ BROKEN_TESTS
+    @test test_gradients(model, x; test_gpu=true, test_cpu=false,
+                         reference=AutoZygote(), compare=nothing) broken = :gru ∈ BROKEN_TESTS
 end
 
 @testset "GRUv3Cell GPU AD" begin
@@ -116,11 +119,11 @@
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     r = GRUv3Cell(d_in => d_out)
     x = [randn(Float32, d_in, batch_size) for _ in 1:len]
     h = zeros(Float32, d_out)
-    @test test_gradients(r, x[1], h; test_gpu=true,
-                         compare_finite_diff=false,
+    @test test_gradients(r, x[1], h; test_gpu=true, test_cpu=false,
+                         reference=AutoZygote(), compare=nothing,
                          loss=cell_loss) broken = :gruv3cell_single ∈ BROKEN_TESTS
-    @test test_gradients(r, x, h; test_gpu=true,
-                         compare_finite_diff=false,
+    @test test_gradients(r, x, h; test_gpu=true, test_cpu=false,
+                         reference=AutoZygote(), compare=nothing,
                          loss = recurrent_cell_loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS
 end
 
@@ -137,9 +140,9 @@ end
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     model = ModelGRUv3(GRUv3(d_in => d_out), zeros(Float32, d_out))
     x_nobatch = randn(Float32, d_in, len)
-    @test test_gradients(model, x_nobatch; test_gpu=true,
-                         compare_finite_diff=false) broken = :gruv3_nobatch ∈ BROKEN_TESTS
+    @test test_gradients(model, x_nobatch; test_gpu=true, test_cpu=false,
+                         reference=AutoZygote(), compare=nothing) broken = :gruv3_nobatch ∈ BROKEN_TESTS
     x = randn(Float32, d_in, len, batch_size)
-    @test test_gradients(model, x; test_gpu=true,
-                         compare_finite_diff=false) broken = :gruv3 ∈ BROKEN_TESTS
+    @test test_gradients(model, x; test_gpu=true, test_cpu=false,
+                         reference=AutoZygote(), compare=nothing) broken = :gruv3 ∈ BROKEN_TESTS
 end
diff --git a/test/ext_cuda/layers.jl b/test/ext_cuda/layers.jl
index b7b456bcf1..f6b44e74e1 100644
--- a/test/ext_cuda/layers.jl
+++ b/test/ext_cuda/layers.jl
@@ -13,19 +13,14 @@ end
 
 const ACTIVATIONS = [identity, tanh]
 
 function gpu_gradtest(name::String, layers::Vector, x_cpu, args...;
-                      test_mode=false, test_grad_x=true,
+                      test_mode=false,
                       atol=1e-4, rtol=1e-4)
     @testset "$name GPU grad tests" begin
         for layer in layers
             @testset "$layer Layer GPU grad test" begin
-
-                # compute output and grad of parameters
                 l_cpu = layer(args...)
-                if test_mode
-                    testmode!(l_cpu)
-                end
-
-                test_gradients(l_cpu, x_cpu; test_gpu=true, compare_finite_diff=false, test_grad_x, atol, rtol)
+                test_gradients(l_cpu, x_cpu; test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing,
+                               atol, rtol, test_mode)
             end
         end
     end
@@ -90,19 +85,22 @@ gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
 gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
 
 embedding = [Flux.Embedding]
-gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2, test_grad_x=false)
-gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2, test_grad_x=false)
-gpu_gradtest("Embedding integer index", embedding, 1, 5, 2, test_grad_x=false)
-gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2, test_grad_x=false)
-gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2, test_grad_x=false)
-gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2, test_grad_x=false)
-gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2, test_grad_x=false)
+gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2)
+gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2)
+gpu_gradtest("Embedding integer index", embedding, 1, 5, 2)
+gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2)
+gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2)
+gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2)
+gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2)
 
 @testset "function layers" begin
     x = rand(Float32, 3, 3)
-    test_gradients(x -> sum(Flux.normalise(x; dims=1)), x, test_gpu=true, compare_finite_diff=false)
-    test_gradients(x -> sum(Flux.normalise(x; dims=2)), x, test_gpu=true, compare_finite_diff=false)
-    test_gradients(x -> sum(Flux.normalise(x)), x, test_gpu=true, compare_finite_diff=false)
+    test_gradients(x -> sum(Flux.normalise(x; dims=1)), x, test_gpu=true, test_cpu=false,
+                   reference=AutoZygote(), compare=nothing)
+    test_gradients(x -> sum(Flux.normalise(x; dims=2)), x, test_gpu=true, test_cpu=false,
+                   reference=AutoZygote(), compare=nothing)
+    test_gradients(x -> sum(Flux.normalise(x)), x, test_gpu=true, test_cpu=false,
+                   reference=AutoZygote(), compare=nothing)
 end
 
 @testset "Zeros mapped for $cl" for cl in (Conv, ConvTranspose, CrossCor, DepthwiseConv)
@@ -183,7 +181,7 @@ end
     @test size(b(x, y)) == (3,9)
     @test sum(abs2, b(x, y)) ≈ 0f0
     test_gradients(b |> cpu, x |> cpu, y |> cpu,
-        test_gpu=true, compare_finite_diff=false, loss=(m, x, y) -> mean(abs2, m(x, y)))
+        test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
 end
 
 @testset "Two-streams Bilinear" begin
@@ -193,7 +191,7 @@ end
     @test size(b(x, y)) == (3,9)
     @test sum(abs2, b(x, y)) ≈ 0f0
     test_gradients(b |> cpu, x |> cpu, y |> cpu,
-        test_gpu=true, compare_finite_diff=false, loss=(m, x, y) -> mean(abs2, m(x, y)))
+        test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
 end
 
 @testset "Parallel" begin
@@ -213,7 +211,7 @@ end
     @testset "gradient" begin
         layer_cpu = Parallel(+, x -> zero(x), identity)
         test_gradients(layer_cpu, randn(2, 2, 2, 2),
-            test_gpu=true, compare_finite_diff=false, loss=(m, x) -> mean(abs2, m(x)))
+            test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
     end
 end
 
@@ -294,5 +292,5 @@ end
         return sum(y.^2) + sum(α.^2)
     end
     test_gradients(mha_cpu, x_cpu; loss,
-        test_gpu=true, compare_finite_diff=false)
+        test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
 end
diff --git a/test/ext_cuda/losses.jl b/test/ext_cuda/losses.jl
index 11e14981b7..514b71634c 100644
--- a/test/ext_cuda/losses.jl
+++ b/test/ext_cuda/losses.jl
@@ -30,7 +30,7 @@ y = [1 0 0 0 1
 
         y = 0.1f0 .+ 0.8f0 .* rand(Float32, 3, 4)
         @test loss(x, y) ≈ loss(gpu(x), gpu(y))
-        test_gradients(loss, x, y, test_gpu=true, test_grad_f=false, compare_finite_diff=false)
+        test_gradients(loss, x, y, test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
 
         # Float16 tests
         @test loss(f16(x), f16(y)) ≈ loss(gpu(f16(x)), gpu(f16(y)))
diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl
index e93568afe8..68645445f7 100644
--- a/test/ext_enzyme/enzyme.jl
+++ b/test/ext_enzyme/enzyme.jl
@@ -1,52 +1,9 @@
 # ENZYME CPU TESTS
 
-@testset "Models" begin
-    function loss(model, x)
-        mean(model(x))
-    end
-
-    models_xs = [
-        (Dense(2=>4), randn(Float32, 2), "Dense"),
-        (Chain(Dense(2=>4, tanh), Dense(4=>3)), randn(Float32, 2), "Chain(Dense, Dense)"),
-        (f64(Chain(Dense(2=>4), Dense(4=>2))), randn(Float64, 2, 1), "f64(Chain(Dense, Dense))"),
-        (Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"),
-        (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"),
-        (Chain(Conv((3, 3), 2 => 3, ), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"),
-        (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
-        (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"),
-        (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
-        (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
-        (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
-        (first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"),
-        (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"),
-        (first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"),
-    ]
-
-    for (model, x, name) in models_xs
+@testset "enzyme gradients" begin
+    for (model, x, name) in TEST_MODELS
         @testset "Enzyme grad check $name" begin
-            println("testing $name with Enzyme")
-            @test test_gradients(model, x; loss, compare_finite_diff=false, test_enzyme=true)
-        end
-    end
-end
-
-@testset "Recurrent Layers" begin
-    function loss(model, x)
-        mean(model(x))
-    end
-
-    models_xs = [
-        (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
-        (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
-        (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
-        (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
-        (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
-    ]
-
-    for (model, x, name) in models_xs
-        @testset "check grad $name" begin
-            println("testing $name")
-            test_gradients(model, x; loss, compare_finite_diff=false, test_enzyme=true)
+            @test test_gradients(model, x; reference=AutoZygote(), compare=AutoEnzyme())
         end
     end
 end
@@ -61,25 +18,26 @@ end
     @test g1.bias == [1, 1]
     @test m1.dval.bias == [1, 1]
 
-    g2 = Flux.withgradient((m,x) -> sum(m(x)), m1, [1,2,3f0])
+    g2 = Flux.withgradient((m,x) -> sum(m(x)), m1, Const([1,2,3f0]))
     @test g2.val ≈ sum(m1([1,2,3f0]))
     @test g2.grad[1].weight ≈ [1 2 3; 1 2 3]
-    @test g2.grad[2] === nothing  # implicitly Const
+    @test g2.grad[2] === nothing
 
-    g3 = Flux.withgradient(Duplicated([1,2,4.], zeros(3))) do x
-        z = 1 ./ x
-        sum(z), z  # here z is an auxillary output
-    end
-0.25, -0.0625]
-    @test g3.val[1] ≈ 1.75
-    @test g3.val[2] ≈ [1.0, 0.5, 0.25]
-    g4 = Flux.withgradient(Duplicated([1,2,4.], zeros(3))) do x
-        z = 1 ./ x
-        (loss=sum(z), aux=string(z))
-    end
-    @test g4.grad[1] ≈ [-1.0, -0.25, -0.0625]
-    @test g4.val.loss ≈ 1.75
-    @test g4.val.aux == "[1.0, 0.5, 0.25]"
+    ## Auxiliary outputs are not supported at the moment
+    # g3 = Flux.withgradient(Duplicated([1,2,4.], zeros(3))) do x
+    #     z = 1 ./ x
+    #     sum(z), z  # here z is an auxiliary output
+    # end
+    # @test g3.grad[1] ≈ [-1.0, -0.25, -0.0625]
+    # @test g3.val[1] ≈ 1.75
+    # @test g3.val[2] ≈ [1.0, 0.5, 0.25]
+    # g4 = Flux.withgradient(Duplicated([1,2,4.], zeros(3))) do x
+    #     z = 1 ./ x
+    #     (loss=sum(z), aux=string(z))
+    # end
+    # @test g4.grad[1] ≈ [-1.0, -0.25, -0.0625]
+    # @test g4.val.loss ≈ 1.75
+    # @test g4.val.aux == "[1.0, 0.5, 0.25]"

     # setup understands Duplicated:
     @test Flux.setup(Adam(), m1) == Flux.setup(Adam(), m1.val)
@@ -93,11 +51,11 @@
     m1.val.weight .= 0
     @test Flux.loadmodel!(m1, oldpar).val.weight ≈ oldpar.weight

-    # At least one Duplicated is required:
-    @test_throws ArgumentError Flux.gradient(m -> sum(m.bias), Const(m1.val))
-    @test_throws ArgumentError Flux.gradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0])
-    @test_throws ArgumentError Flux.withgradient(m -> sum(m.bias), Const(m1.val))
-    @test_throws ArgumentError Flux.withgradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0])
+    # Const arguments are now supported and receive `nothing` gradients:
+    @test Flux.gradient(m -> sum(m.bias), Const(m1.val))[1] === nothing
+    @test Flux.gradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0]) isa Tuple{Nothing,Vector{Float32}}
+    @test Flux.withgradient(m -> sum(m.bias), Const(m1.val)).grad[1] === nothing
+    @test Flux.withgradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0]).grad isa Tuple{Nothing,Vector{Float32}}

     # Active is disallowed:
     @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1, Active(3f0))
     @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1.val, Active(3f0))
@@ -115,7 +73,6 @@
     @test Flux.gradient(sum ∘ LayerNorm(3), z)[1] ≈ [0.0, 0.0, 0.0]
     @test Flux.gradient(|>, z, _duplicated(sum ∘ LayerNorm(3)))[1] ≈ [0.0, 0.0, 0.0]
     @test Flux.gradient(|>, z, Const(sum ∘ LayerNorm(3)))[2] === nothing
-
-    @test_broken Flux.withgradient(sum ∘ LayerNorm(3), z).grad[1] ≈ [0.0, 0.0, 0.0]  # AssertionError: Base.allocatedinline(actualRetType) returns false: actualRetType = Any, rettype = Active{Any}
-    @test_broken Flux.withgradient(|>, z, _duplicated(sum ∘ LayerNorm(3))).grad[1] ≈ [0.0, 0.0, 0.0]
+    @test Flux.withgradient(sum ∘ LayerNorm(3), z).grad[1] ≈ [0.0, 0.0, 0.0]
+    @test Flux.withgradient(|>, z, _duplicated(sum ∘ LayerNorm(3))).grad[1] ≈ [0.0, 0.0, 0.0]
 end
diff --git a/test/ext_metal/basic.jl b/test/ext_metal/basic.jl
index 9febd8e455..b041a9272f 100644
--- a/test/ext_metal/basic.jl
+++ b/test/ext_metal/basic.jl
@@ -21,5 +21,15 @@ end
     m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
     x = rand(Float32, 10, 10)
     @test (m|>gpu)(x|>gpu) isa MtlArray{Float32, 2}
-    test_gradients(m, x, test_gpu=true, compare_finite_diff=false)
+    test_gradients(m, x, test_gpu=true, test_cpu=false, reference=AutoZygote(), compare=nothing)
+end
+
+@testset "gradients" begin
+    broken_models = ["Conv", "Chain(Conv, Conv)", "Chain(Conv, MeanPool)", "ConvTranspose", "Bilinear", "MultiHeadAttention"]
+    # Bilinear and MultiHeadAttention will be fixed by https://github.com/FluxML/NNlib.jl/pull/614
+    for (model, x, name) in TEST_MODELS
+        @testset "Zygote grad check $name" begin
+            @test test_gradients(model, x; test_gpu=true,
test_cpu=false, reference=AutoZygote(), compare=nothing) broken=(name ∈ broken_models) + end + end end diff --git a/test/ext_mooncake.jl b/test/ext_mooncake.jl new file mode 100644 index 0000000000..f60633e981 --- /dev/null +++ b/test/ext_mooncake.jl @@ -0,0 +1,7 @@ +@testset "mooncake gradient" begin + for (model, x, name) in TEST_MODELS + @testset "grad check $name" begin + @test test_gradients(model, x; reference=AutoZygote(), compare=AutoMooncake()) + end + end +end diff --git a/test/ext_reactant/reactant.jl b/test/ext_reactant/reactant.jl index e6fd9f48c9..820c0dd579 100644 --- a/test/ext_reactant/reactant.jl +++ b/test/ext_reactant/reactant.jl @@ -1,78 +1,8 @@ -function scalarfirst(x) - Reactant.@allowscalar first(x) -end - @testset "Reactant Models" begin - function loss(model, x) - mean(model(x)) - end - - models_xs = [ - (Dense(2=>4), randn(Float32, 2), "Dense"), - - (Chain(Dense(2=>4, tanh), Dense(4=>3)), randn(Float32, 2), "Chain(Dense, Dense)"), - - (f64(Chain(Dense(2=>4), Dense(4=>2))), randn(Float64, 2, 1), "f64(Chain(Dense, Dense))"), - - (Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"), - - (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"), - - (Chain(Conv((3, 3), 2 => 3, ), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"), - - (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"), - - (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"), - - (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"), - - (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), - - (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"), - - (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"), - ] - - for (model, x, name) in models_xs - @testset "Enzyme grad check $name" begin - println("testing $name with Reactant") - test_gradients(model, x; loss, compare_finite_diff=false, test_reactant=true) - end - end - - models_xs = [ - (LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"), # Zygote comparison test fails on the GPUArraysCore.@allowscalar in scalarfirst, so we globally allow scalar - - (first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"), # Zygote comparison test fails on the GPUArraysCore.@allowscalar in scalarfirst, so we globally allow scalar - ] - - Reactant.allowscalar(true) - for (model, x, name) in models_xs - @testset "Enzyme grad check $name" begin - println("testing $name with Reactant") - test_gradients(model, x; loss, compare_finite_diff=false, test_reactant=true) - end - end - Reactant.allowscalar(false) -end - -@testset "Reactant Recurrent Layers" begin - function loss(model, x) - mean(model(x)) - end - - models_xs = [ - (RNN(3 => 2), randn(Float32, 3, 2), "RNN"), - (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"), - (GRU(3 => 5), randn(Float32, 3, 10), "GRU"), - (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"), - (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"), - ] - - for (model, x, name) in models_xs - @testset "check grad $name" begin - println("testing $name with Reactant") - test_gradients(model, x; loss, compare_finite_diff=false, test_reactant=true) + broken_models = () + for (model, x, name) in TEST_MODELS + @testset "Reactant grad check $name" begin + @test test_gradients(model, x; reference=AutoZygote(), test_reactant=true, test_cpu=false) broken=(name ∈ broken_models) 
end end end diff --git a/test/ext_reactant/test_utils_reactant.jl b/test/ext_reactant/test_utils_reactant.jl index 6b3cec53d8..60e879d2bb 100644 --- a/test/ext_reactant/test_utils_reactant.jl +++ b/test/ext_reactant/test_utils_reactant.jl @@ -2,7 +2,7 @@ # because Reactant is only optionally loaded and the macros fail when it is not loaded. function reactant_withgradient(f, x...) - y, g = Reactant.@jit enzyme_withgradient(f, x...) + y, g = Reactant.@jit Flux.withgradient(f, AutoEnzyme(), x...) return y, g end diff --git a/test/runtests.jl b/test/runtests.jl index 301d46dc1f..f4a0b7facf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,9 +6,10 @@ using Flux: OneHotArray, OneHotMatrix, OneHotVector, using Flux.Losses: xlogx, xlogy using Flux.Losses using ForwardDiff: ForwardDiff -using Functors: Functors, fmapstructure_with_path +using Functors: Functors, fmapstructure_with_path, fmap using IterTools: ncycle using LinearAlgebra +using Mooncake: Mooncake using MLUtils: MLUtils, batch, unstack, unsqueeze, unbatch, getobs, numobs, flatten, DataLoader using Optimisers: Optimisers @@ -18,8 +19,6 @@ using SparseArrays using Statistics using Test using Zygote: Zygote -# const gradient = Flux.gradient # both Flux & Zygote export this on 0.15 -# const withgradient = Flux.withgradient ## Uncomment below to change the default test settings # ENV["FLUX_TEST_AMDGPU"] = "true" @@ -104,6 +103,10 @@ end include("functors.jl") end + @testset "mooncake" begin + include("ext_mooncake.jl") + end + @testset "deprecations" begin include("deprecations.jl") end diff --git a/test/test_utils.jl b/test/test_utils.jl index 3576d314e2..d4cc62e9f0 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -12,27 +12,27 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle, Flux.Losses.siamese_contrastive_loss] -function finitediff_withgradient(f, x...) - y = f(x...) - # We set a range to avoid domain errors - fdm = FiniteDifferences.central_fdm(5, 1, max_range=1e-2) - return y, FiniteDifferences.grad(fdm, f, x...) -end - -function enzyme_withgradient(f, x...) - args = [] - for x in x - if x isa Number - push!(args, Enzyme.Active(x)) - else - push!(args, Enzyme.Duplicated(x, Enzyme.make_zero(x))) - end - end - ad = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal) - ret = Enzyme.autodiff(ad, Enzyme.Const(f), Enzyme.Active, args...) - g = ntuple(i -> x[i] isa Number ? 
ret[1][i] : args[i].dval, length(x))
-    return ret[2], g
-end
+const TEST_MODELS = [
+    (Dense(2=>4), randn(Float32, 2), "Dense"),
+    (Chain(Dense(2=>4, tanh), Dense(4=>3)), randn(Float32, 2), "Chain(Dense, Dense)"),
+    (f64(Chain(Dense(2=>4), Dense(4=>2))), randn(Float64, 2, 1), "f64(Chain(Dense, Dense))"),
+    (Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"),
+    (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"),
+    (Chain(Conv((3, 3), 2 => 3), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"),
+    (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
+    (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"),
+    (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
+    (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
+    (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
+    (LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"),
+    (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"),
+    (first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"),
+    (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
+    (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
+    (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
+    (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
+    (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
+]

 function _contains_no_numerical(kp, x)
     count = 0
@@ -46,132 +46,108 @@ function _contains_no_numerical(kp, x)
 end

 function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
+    # Since Zygote can return `nothing` for an entire subtree, we prune
+    # the tree using `_contains_no_numerical`.
     fmapstructure_with_path(a, b, exclude=_contains_no_numerical) do kp, x, y
-        if y isa Nothing
-            return
-        end
         # @show kp
-        if x isa AbstractArray
-            @test x ≈ y rtol=rtol atol=atol
-        elseif x isa Number
+        if x isa AbstractArray{<:AbstractFloat}
             @test x ≈ y rtol=rtol atol=atol
         end
     end
+    return true
+end
+
+function _contains_no_numerical(kp, x)
+    count = 0
+    fmap(x) do y
+        if y isa AbstractArray{<:AbstractFloat}
+            count += 1
+        end
+        return y
+    end
+    return count == 0
 end

-# By default, this computes the gradients on cpu using the default AD (Zygote)
-# and compares them with finite differences.
-# Changing the arguments, you can assume the cpu Zygote gradients as the ground truth
-# and test other scenarios.
+_default_fdm() = FiniteDifferences.central_fdm(5, 1, max_range=1e-2)
+
+"""
+Compare the `reference` and `compare` AD backends on the gradients of `f` at `xs...`.
+The loss function can be customized via the `loss` keyword (the default takes the mean of the outputs).
+
+- If `test_gpu` is true, the `compare` backend is tested on GPU.
+- If `test_cpu` is true, the `compare` backend is tested on CPU.
+- If `test_reactant` is true, the Enzyme backend is tested with Reactant.
+  Depending on the platform, this may run on CPU or GPU.
+"""
 function test_gradients(
     f,
     xs...;
     rtol=1e-4,
     atol=1e-4,
     test_gpu = false,
+    test_cpu = true,
     test_reactant = false,
-    test_enzyme = false,
-    test_grad_f = true,
-    test_grad_x = true,
-    compare_finite_diff = true,
+    reference = AutoFiniteDifferences(; fdm = _default_fdm()),
+    compare = AutoZygote(),
    loss = (f, xs...)
-> mean(f(xs...)), + test_mode = false, ) - if !test_gpu && !compare_finite_diff && !test_enzyme && !test_reactant - error("You should either compare numerical gradients methods or CPU vs GPU.") - end + @assert reference !== nothing "reference AD backend must be provided" + @assert compare !== nothing || test_gpu "compare AD backend must be provided if test_gpu=false" + compare = compare === nothing ? reference : compare - Flux.trainmode!(f) # for layers like BatchNorm + if test_mode + Flux.testmode!(f) + else + Flux.trainmode!(f) + end - ## Let's make sure first that the forward pass works. - l = loss(f, xs...) - @test l isa Number + cpu_dev = cpu_device() + if test_gpu gpu_dev = gpu_device(force=true) cpu_dev = cpu_device() xs_gpu = xs |> gpu_dev f_gpu = f |> gpu_dev - l_gpu = loss(f_gpu, xs_gpu...) - @test l_gpu isa Number end - + if test_reactant reactant_dev = MLDataDevices.reactant_device(force=true) - cpu_dev = cpu_device() xs_re = xs |> reactant_dev f_re = f |> reactant_dev - l_re = reactant_loss(loss, f_re, xs_re...) - @test l ≈ l_re rtol=rtol atol=atol end - if test_grad_x - # Zygote gradient with respect to input. - y, g = Zygote.withgradient((xs...) -> loss(f, xs...), xs...) - - if compare_finite_diff - # Cast to Float64 to avoid precision issues. - f64 = f |> Flux.f64 - xs64 = xs .|> Flux.f64 - y_fd, g_fd = finitediff_withgradient((xs...) -> loss(f64, xs...), xs64...) - @test y ≈ y_fd rtol=rtol atol=atol - check_equal_leaves(g, g_fd; rtol, atol) - end - - if test_enzyme - y_ez, g_ez = enzyme_withgradient((xs...) -> loss(f, xs...), xs...) - @test y ≈ y_ez rtol=rtol atol=atol - check_equal_leaves(g, g_ez; rtol, atol) - end + ## Let's make sure first that the forward pass works. + l = loss(f, xs...) + @test l isa Number - if test_gpu - # Zygote gradient with respect to input on GPU. - y_gpu, g_gpu = Zygote.withgradient((xs...) -> loss(f_gpu, xs...), xs_gpu...) - @test get_device(g_gpu) == get_device(xs_gpu) - @test y_gpu ≈ y rtol=rtol atol=atol - check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol) - end + # Compute reference gradients in f64 precision + y, gs = Flux.withgradient(loss, reference, Flux.f64(f), Flux.f64(xs)...) + @test l ≈ y rtol=rtol atol=atol - if test_reactant - # Enzyme gradient with respect to input on Reactant. - y_re, g_re = reactant_withgradient(Base.Fix1(loss, f_re), xs_re...) - @test y ≈ y_re rtol=rtol atol=atol - check_equal_leaves(g_re |> cpu_dev, g; rtol, atol) - end + if test_cpu + y2, gs2 = Flux.withgradient(loss, compare, f, xs...) + @test l ≈ y2 rtol=rtol atol=atol + check_equal_leaves(gs, gs2; rtol, atol) end - if test_grad_f - # Zygote gradient with respect to f. - y, g = Zygote.withgradient(f -> loss(f, xs...), f) - - if compare_finite_diff - # Cast to Float64 to avoid precision issues. - f64 = f |> Flux.f64 - ps, re = Flux.destructure(f64) - y_fd, g_fd = finitediff_withgradient(ps -> loss(re(ps), xs...), ps) - g_fd = (re(g_fd[1]),) - @test y ≈ y_fd rtol=rtol atol=atol - check_equal_leaves(g, g_fd; rtol, atol) - end + if test_gpu + l_gpu = loss(f_gpu, xs_gpu...) + @test l_gpu isa Number - if test_enzyme - y_ez, g_ez = enzyme_withgradient(f -> loss(f, xs...), f) - @test y ≈ y_ez rtol=rtol atol=atol - check_equal_leaves(g, g_ez; rtol, atol) - end + y_gpu, gs_gpu = Flux.withgradient(loss, compare, f_gpu, xs_gpu...) + @test l_gpu ≈ y_gpu rtol=rtol atol=atol + check_equal_leaves(gs, gs_gpu |> cpu_dev; rtol, atol) + end - if test_gpu - # Zygote gradient with respect to f on GPU. 
- y_gpu, g_gpu = Zygote.withgradient(f -> loss(f, xs_gpu...), f_gpu) - # @test get_device(g_gpu) == get_device(xs_gpu) - @test y_gpu ≈ y rtol=rtol atol=atol - check_equal_leaves(g_gpu |> cpu_dev, g; rtol, atol) - end + if test_reactant + l_re = reactant_loss(loss, f_re, xs_re...) + @test l ≈ l_re rtol=rtol atol=atol - if test_reactant - # Enzyme gradient with respect to input on Reactant. - y_re, g_re = reactant_withgradient(Base.Fix2(loss, xs_re[1]), f_re) - @test y ≈ y_re rtol=rtol atol=atol - check_equal_leaves(g_re |> cpu_dev, g; rtol, atol) - end + y_re, g_re = reactant_withgradient(loss, f_re, xs_re...) + @test y ≈ y_re rtol=rtol atol=atol + check_equal_leaves(gs, g_re |> cpu_dev; rtol, atol) end + return true end
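For context, here is a minimal sketch of how the refactored `test_gradients` helper is called after this change. The keyword names (`reference`, `compare`, `test_gpu`, `test_cpu`) and the `AutoZygote()`/`AutoMooncake()` backend values are taken from the diff above; the concrete model, input, and imports are illustrative placeholders and assume the test-suite context (`test/test_utils.jl` included, with the FiniteDifferences and Mooncake extensions loaded).

```julia
using Flux, Test
using ADTypes: AutoZygote, AutoMooncake, AutoFiniteDifferences
using FiniteDifferences, Mooncake  # trigger the Flux AD extensions

# Illustrative model/input; any (model, x, name) entry of TEST_MODELS works the same way.
model = Dense(2 => 4, tanh)
x = randn(Float32, 2)

# Default: the `compare` backend (Zygote) is checked against the
# finite-differences `reference` on CPU, with f64 reference gradients.
test_gradients(model, x)

# Check an alternative AD backend against Zygote, as test/ext_mooncake.jl does.
test_gradients(model, x; reference=AutoZygote(), compare=AutoMooncake())

# GPU comparison with CPU Zygote as ground truth: `compare=nothing` falls back
# to `reference`, and `test_cpu=false` skips the redundant CPU check,
# matching the CUDA and Metal test suites above.
test_gradients(model, x; test_gpu=true, test_cpu=false,
               reference=AutoZygote(), compare=nothing)
```

Under the hood these keywords drive the backend-as-argument form `Flux.withgradient(loss, backend, f, xs...)`, the same mechanism the `test_utils_reactant.jl` change uses with `AutoEnzyme()`.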